From 45bcba0067668c745b983bd20c62d5d303c9a217 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 1 Aug 2025 17:46:00 +0200 Subject: [PATCH 01/47] dual state integration --- adapters/p2p2core/state.go | 9 +- blockchain/blockchain.go | 345 +++++++++++++--- blockchain/blockchain_test.go | 13 +- builder/builder.go | 3 +- core/accessors.go | 4 +- core/mocks/mock_commonstate_reader.go | 148 +++++++ core/receipt.go | 1 + core/state/cache.go | 4 +- core/state/cache_test.go | 3 +- core/state/commonstate/state.go | 174 ++++++++- core/state/commontrie/trie.go | 34 +- core/state/history.go | 11 +- core/state/history_test.go | 6 - core/state/state.go | 12 +- core/trie2/triedb/database.go | 22 ++ core/trie2/triedb/pathdb/buffer.go | 20 +- core/trie2/triedb/pathdb/database.go | 6 + core/trie2/triedb/pathdb/difflayer.go | 7 +- core/trie2/triedb/pathdb/disklayer.go | 7 +- core/trie2/triedb/pathdb/nodeset.go | 1 + core/trie2/triedb/pathdb/reader.go | 2 +- core/trie2/trieutils/accessors.go | 2 +- genesis/genesis.go | 23 +- mempool/mempool.go | 2 +- mempool/mempool_test.go | 6 +- mocks/mock_blockchain.go | 21 +- ...e_reader.go => mock_commonstate_reader.go} | 18 +- ...ock_gateway_handler.go => mock_gateway.go} | 5 +- mocks/mock_synchronizer.go | 9 +- mocks/mock_vm.go | 5 +- node/node.go | 7 + node/throttled_vm.go | 5 +- plugin/plugin_test.go | 1 + rpc/v6/class.go | 2 +- rpc/v6/class_test.go | 8 +- rpc/v6/contract.go | 4 +- rpc/v6/contract_test.go | 4 +- rpc/v6/estimate_fee_test.go | 5 +- rpc/v6/handlers_test.go | 6 +- rpc/v6/helpers.go | 7 +- rpc/v6/pending_data_wrapper.go | 3 +- rpc/v6/pending_data_wrapper_test.go | 2 +- rpc/v6/simulation_test.go | 4 +- rpc/v6/trace.go | 7 +- rpc/v6/trace_test.go | 12 +- rpc/v6/transaction_test.go | 2 +- rpc/v7/compiled_casm_test.go | 4 +- rpc/v7/estimate_fee_test.go | 2 +- rpc/v7/handlers_test.go | 6 +- rpc/v7/helpers.go | 7 +- rpc/v7/pending_data_wrapper.go | 3 +- rpc/v7/pending_data_wrapper_test.go | 2 +- rpc/v7/simulation_test.go | 4 +- rpc/v7/storage.go | 6 +- rpc/v7/storage_test.go | 2 +- rpc/v7/trace.go | 7 +- rpc/v7/trace_test.go | 12 +- rpc/v8/compiled_casm_test.go | 6 +- rpc/v8/estimate_fee_test.go | 2 +- rpc/v8/handlers_test.go | 6 +- rpc/v8/helpers.go | 7 +- rpc/v8/l1.go | 4 +- rpc/v8/pending_data_wrapper.go | 3 +- rpc/v8/pending_data_wrapper_test.go | 2 +- rpc/v8/simulation_test.go | 16 +- rpc/v8/storage.go | 160 ++++++-- rpc/v8/storage_test.go | 12 +- rpc/v8/subscriptions_test.go | 10 +- rpc/v8/trace.go | 7 +- rpc/v8/trace_test.go | 12 +- rpc/v9/class.go | 2 +- rpc/v9/class_test.go | 8 +- rpc/v9/compiled_casm_test.go | 6 +- rpc/v9/estimate_fee_test.go | 2 +- rpc/v9/handlers_test.go | 6 +- rpc/v9/helpers.go | 13 +- rpc/v9/l1.go | 4 +- rpc/v9/nonce.go | 2 +- rpc/v9/nonce_test.go | 2 +- rpc/v9/pending_data_wrapper.go | 3 +- rpc/v9/pending_data_wrapper_test.go | 2 +- rpc/v9/simulation_test.go | 16 +- rpc/v9/storage.go | 157 ++++++-- rpc/v9/storage_test.go | 12 +- rpc/v9/subscriptions_test.go | 10 +- rpc/v9/trace.go | 7 +- rpc/v9/trace_test.go | 10 +- sequencer/sequencer.go | 5 +- sync/pending.go | 59 +-- sync/pending_test.go | 2 +- sync/sync.go | 31 +- sync/sync_test.go | 2 + vm/state.go | 22 +- vm/vm.go | 11 +- vm/vm_test.go | 369 ++++++++++-------- 95 files changed, 1482 insertions(+), 585 deletions(-) create mode 100644 core/mocks/mock_commonstate_reader.go rename mocks/{mock_state_reader.go => mock_commonstate_reader.go} (89%) rename mocks/{mock_gateway_handler.go => mock_gateway.go} (87%) diff --git a/adapters/p2p2core/state.go b/adapters/p2p2core/state.go index f6ddbd633b..0c1aea95d8 100644 --- a/adapters/p2p2core/state.go +++ b/adapters/p2p2core/state.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" "github.com/starknet-io/starknet-p2pspecs/p2p/proto/class" @@ -13,7 +14,7 @@ import ( "github.com/starknet-io/starknet-p2pspecs/p2p/proto/sync/state" ) -func AdaptStateDiff(reader core.StateReader, contractDiffs []*state.ContractDiff, classes []*class.Class) *core.StateDiff { +func AdaptStateDiff(reader commonstate.StateReader, contractDiffs []*state.ContractDiff, classes []*class.Class) *core.StateDiff { var ( declaredV0Classes []*felt.Felt declaredV1Classes = make(map[felt.Felt]*felt.Felt) @@ -53,16 +54,16 @@ func AdaptStateDiff(reader core.StateReader, contractDiffs []*state.ContractDiff classHash: diff.ClassHash, } - var stateClassHash *felt.Felt + var stateClassHash felt.Felt if reader == nil { // zero block - stateClassHash = &felt.Zero + stateClassHash = felt.Zero } else { var err error stateClassHash, err = reader.ContractClassHash(address) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { - stateClassHash = &felt.Zero + stateClassHash = felt.Zero } else { panic(err) } diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index e85b1a16d2..eab92860a7 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -11,7 +11,6 @@ import ( "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" "github.com/NethermindEth/juno/utils" "github.com/ethereum/go-ethereum/common" @@ -23,10 +22,11 @@ type L1HeadSubscription struct { //go:generate mockgen -destination=../mocks/mock_blockchain.go -package=mocks github.com/NethermindEth/juno/blockchain Reader type Reader interface { + StateProvider Height() (height uint64, err error) Head() (head *core.Block, err error) - L1Head() (*core.L1Head, error) + L1Head() (core.L1Head, error) SubscribeL1Head() L1HeadSubscription BlockByNumber(number uint64) (block *core.Block, err error) BlockByHash(hash *felt.Felt) (block *core.Block, err error) @@ -40,11 +40,7 @@ type Reader interface { Receipt(hash *felt.Felt) (receipt *core.TransactionReceipt, blockHash *felt.Felt, blockNumber uint64, err error) StateUpdateByNumber(number uint64) (update *core.StateUpdate, err error) StateUpdateByHash(hash *felt.Felt) (update *core.StateUpdate, err error) - L1HandlerTxnHash(msgHash *common.Hash) (l1HandlerTxnHash *felt.Felt, err error) - - HeadState() (core.StateReader, StateCloser, error) - StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, StateCloser, error) - StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) + L1HandlerTxnHash(msgHash *common.Hash) (l1HandlerTxnHash felt.Felt, err error) BlockCommitmentsByNumber(blockNumber uint64) (*core.BlockCommitments, error) @@ -53,6 +49,12 @@ type Reader interface { Network() *utils.Network } +type StateProvider interface { + HeadState() (commonstate.StateReader, StateCloser, error) + StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, StateCloser, error) + StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, StateCloser, error) +} + var ( ErrParentDoesNotMatchHead = errors.New("block's parent hash does not match head block hash") SupportedStarknetVersion = semver.MustParse("0.14.0") @@ -87,21 +89,23 @@ var _ Reader = (*Blockchain)(nil) type Blockchain struct { network *utils.Network database db.KeyValueStore + trieDB *triedb.Database + StateDB *state.StateDB // TODO(weiihann): not sure if it's a good idea to expose this listener EventListener l1HeadFeed *feed.Feed[*core.L1Head] cachedFilters *AggregatedBloomFilterCache runningFilter *core.RunningEventFilter - stateFactory *commonstate.StateFactory + StateFactory *commonstate.StateFactory } -func New(database db.KeyValueStore, network *utils.Network) *Blockchain { +func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) *Blockchain { trieDB, err := triedb.New(database, nil) // TODO: handle hashdb if err != nil { panic(err) } stateDB := state.NewStateDB(database, trieDB) - stateFactory, err := commonstate.NewStateFactory(true, trieDB, stateDB) + StateFactory, err := commonstate.NewStateFactory(stateVersion, trieDB, stateDB) if err != nil { panic(err) } @@ -116,12 +120,14 @@ func New(database db.KeyValueStore, network *utils.Network) *Blockchain { return &Blockchain{ database: database, + trieDB: trieDB, + StateDB: stateDB, network: network, listener: &SelectiveListener{}, l1HeadFeed: feed.New[*core.L1Head](), cachedFilters: &cachedFilters, runningFilter: runningFilter, - stateFactory: stateFactory, + StateFactory: StateFactory, } } @@ -134,14 +140,6 @@ func (b *Blockchain) Network() *utils.Network { return b.network } -// StateCommitment returns the latest block state commitment. -// If blockchain is empty zero felt is returned. -func (b *Blockchain) StateCommitment() (*felt.Felt, error) { - b.listener.OnRead("StateCommitment") - batch := b.database.NewIndexedBatch() // this is a hack because we don't need to write to the db - return core.NewState(batch).Root() -} - // Height returns the latest block height. If blockchain is empty nil is returned. func (b *Blockchain) Height() (uint64, error) { b.listener.OnRead("Height") @@ -215,10 +213,9 @@ func (b *Blockchain) StateUpdateByHash(hash *felt.Felt) (*core.StateUpdate, erro return core.GetStateUpdateByHash(b.database, hash) } -func (b *Blockchain) L1HandlerTxnHash(msgHash *common.Hash) (*felt.Felt, error) { +func (b *Blockchain) L1HandlerTxnHash(msgHash *common.Hash) (felt.Felt, error) { b.listener.OnRead("L1HandlerTxnHash") - txnHash, err := core.GetL1HandlerTxnHashByMsgHash(b.database, msgHash.Bytes()) - return &txnHash, err // TODO: return felt value + return core.GetL1HandlerTxnHashByMsgHash(b.database, msgHash.Bytes()) // TODO: return felt value } // TransactionByBlockNumberAndIndex gets the transaction for a given block number and index. @@ -258,10 +255,9 @@ func (b *Blockchain) SubscribeL1Head() L1HeadSubscription { return L1HeadSubscription{b.l1HeadFeed.Subscribe()} } -func (b *Blockchain) L1Head() (*core.L1Head, error) { +func (b *Blockchain) L1Head() (core.L1Head, error) { b.listener.OnRead("L1Head") - l1Head, err := core.GetL1Head(b.database) - return &l1Head, err // TODO: this should return a value + return core.GetL1Head(b.database) } func (b *Blockchain) SetL1Head(update *core.L1Head) error { @@ -272,13 +268,29 @@ func (b *Blockchain) SetL1Head(update *core.L1Head) error { // Store takes a block and state update and performs sanity checks before putting in the database. func (b *Blockchain) Store(block *core.Block, blockCommitments *core.BlockCommitments, stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.Class, +) error { + // old state + // TODO(maksymmalick): remove this once we have a new state implementation + if !b.StateFactory.UseNewState { + return b.storeCoreState(block, blockCommitments, stateUpdate, newClasses) + } + + return b.store(block, blockCommitments, stateUpdate, newClasses) +} + +func (b *Blockchain) storeCoreState( + block *core.Block, + blockCommitments *core.BlockCommitments, + stateUpdate *core.StateUpdate, + newClasses map[felt.Felt]core.Class, ) error { err := b.database.Update(func(txn db.IndexedBatch) error { if err := verifyBlock(txn, block); err != nil { return err } - if err := core.NewState(txn).Update(block.Number, stateUpdate, newClasses, false); err != nil { + state := core.NewState(txn) + if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { return err } if err := core.WriteBlockHeader(txn, block.Header); err != nil { @@ -316,6 +328,56 @@ func (b *Blockchain) Store(block *core.Block, blockCommitments *core.BlockCommit ) } +func (b *Blockchain) store( + block *core.Block, + blockCommitments *core.BlockCommitments, + stateUpdate *core.StateUpdate, + newClasses map[felt.Felt]core.Class, +) error { + // TODO(weiihann): handle unexpected shutdown + if err := verifyBlock(b.database, block); err != nil { + return err + } + state, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil) + if err != nil { + return err + } + if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { + return err + } + batch := b.database.NewBatch() + if err := core.WriteBlockHeader(batch, block.Header); err != nil { + return err + } + for i, tx := range block.Transactions { + if err := core.WriteTxAndReceipt(batch, block.Number, uint64(i), tx, + block.Receipts[i]); err != nil { + return err + } + } + if err := core.WriteStateUpdateByBlockNum(batch, block.Number, stateUpdate); err != nil { + return err + } + + if err := core.WriteBlockCommitment(batch, block.Number, blockCommitments); err != nil { + return err + } + + if err := core.WriteL1HandlerMsgHashes(batch, block.Transactions); err != nil { + return err + } + + if err := core.WriteChainHeight(batch, block.Number); err != nil { + return err + } + + if err := b.runningFilter.Insert(block.EventsBloom, block.Number); err != nil { + return err + } + + return batch.Write() +} + // VerifyBlock assumes the block has already been sanity-checked. func (b *Blockchain) VerifyBlock(block *core.Block) error { return verifyBlock(b.database, block) @@ -375,39 +437,46 @@ type StateCloser = func() error var noopStateCloser = func() error { return nil } // TODO: remove this once we refactor the state // HeadState returns a StateReader that provides a stable view to the latest state -func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) { +func (b *Blockchain) HeadState() (commonstate.StateReader, StateCloser, error) { b.listener.OnRead("HeadState") txn := b.database.NewIndexedBatch() - _, err := core.GetChainHeight(txn) + height, err := core.GetChainHeight(txn) + if err != nil { + return nil, nil, err + } + + header, err := core.GetBlockHeaderByNumber(txn, height) if err != nil { return nil, nil, err } - return core.NewState(txn), noopStateCloser, nil + state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn) + + return state, noopStateCloser, err } // StateAtBlockNumber returns a StateReader that provides a stable view to the state at the given block number -func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) { +func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") txn := b.database.NewIndexedBatch() - _, err := core.GetBlockHeaderByNumber(txn, blockNumber) + header, err := core.GetBlockHeaderByNumber(txn, blockNumber) if err != nil { return nil, nil, err } - return core.NewStateSnapshot(core.NewState(txn), blockNumber), noopStateCloser, nil + stateReader, err := b.StateFactory.NewStateReader(header.GlobalStateRoot, txn, blockNumber) + + return stateReader, noopStateCloser, err } // StateAtBlockHash returns a StateReader that provides a stable view to the state at the given block hash -func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, StateCloser, error) { +func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockHash") if blockHash.IsZero() { - memDB := memory.New() - txn := memDB.NewIndexedBatch() - emptyState := core.NewState(txn) - return emptyState, noopStateCloser, nil + emptyState, err := b.StateFactory.EmptyState() + return emptyState, noopStateCloser, err } txn := b.database.NewIndexedBatch() @@ -415,8 +484,8 @@ func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, S if err != nil { return nil, nil, err } - - return core.NewStateSnapshot(core.NewState(txn), header.Number), noopStateCloser, nil + stateReader, err := b.StateFactory.NewStateReader(header.GlobalStateRoot, txn, header.Number) + return stateReader, noopStateCloser, err } // EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain @@ -432,10 +501,26 @@ func (b *Blockchain) EventFilter(from *felt.Felt, keys [][]felt.Felt, pendingBlo // RevertHead reverts the head block func (b *Blockchain) RevertHead() error { - return b.database.Update(b.revertHead) + if !b.StateFactory.UseNewState { + return b.database.Update(b.revertHeadCoreState) + } + return b.revertHead() +} + +func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) { + if !b.StateFactory.UseNewState { + reverseStateDiff, err := b.getReverseStateDiffCoreState() + if err != nil { + return core.StateDiff{}, err + } + return *reverseStateDiff, nil + } + + return b.getReverseStateDiff() } -func (b *Blockchain) GetReverseStateDiff() (*core.StateDiff, error) { +// TODO(maksymmalick): remove this once we have a new state integrated +func (b *Blockchain) getReverseStateDiffCoreState() (*core.StateDiff, error) { var reverseStateDiff *core.StateDiff txn := b.database.NewIndexedBatch() @@ -458,7 +543,31 @@ func (b *Blockchain) GetReverseStateDiff() (*core.StateDiff, error) { return reverseStateDiff, nil } -func (b *Blockchain) revertHead(txn db.IndexedBatch) error { +func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) { + var ret core.StateDiff + + blockNum, err := core.GetChainHeight(b.database) + if err != nil { + return ret, err + } + stateUpdate, err := core.GetStateUpdateByBlockNum(b.database, blockNum) + if err != nil { + return ret, err + } + state, err := state.New(stateUpdate.NewRoot, b.StateDB) + if err != nil { + return ret, err + } + + ret, err = state.GetReverseStateDiff(blockNum, stateUpdate.StateDiff) + if err != nil { + return ret, err + } + + return ret, nil +} + +func (b *Blockchain) revertHeadCoreState(txn db.IndexedBatch) error { blockNumber, err := core.GetChainHeight(txn) if err != nil { return err @@ -493,7 +602,7 @@ func (b *Blockchain) revertHead(txn db.IndexedBatch) error { } } - if err = core.DeleteTxsAndReceipts(txn, blockNumber, header.TransactionCount); err != nil { + if err = core.DeleteTxsAndReceipts(txn, txn, blockNumber, header.TransactionCount); err != nil { return err } @@ -516,6 +625,58 @@ func (b *Blockchain) revertHead(txn db.IndexedBatch) error { return b.runningFilter.OnReorg() } +// RevertHead reverts the head block +func (b *Blockchain) revertHead() error { + blockNumber, err := core.GetChainHeight(b.database) + if err != nil { + return err + } + stateUpdate, err := core.GetStateUpdateByBlockNum(b.database, blockNumber) + if err != nil { + return err + } + state, err := state.New(stateUpdate.NewRoot, b.StateDB) + if err != nil { + return err + } + // revert state + if err = state.Revert(blockNumber, stateUpdate); err != nil { + return err + } + header, err := core.GetBlockHeaderByNumber(b.database, blockNumber) + if err != nil { + return err + } + genesisBlock := blockNumber == 0 + + batch := b.database.NewBatch() + for _, key := range [][]byte{ + db.BlockHeaderByNumberKey(header.Number), + db.BlockHeaderNumbersByHashKey(header.Hash), + db.BlockCommitmentsKey(header.Number), + } { + if err = batch.Delete(key); err != nil { + return err + } + } + if err = core.DeleteTxsAndReceipts(b.database, batch, blockNumber, header.TransactionCount); err != nil { + return err + } + if err = core.DeleteStateUpdateByBlockNum(batch, blockNumber); err != nil { + return err + } + if genesisBlock { + if err := core.DeleteChainHeight(batch); err != nil { + return err + } + } else { + if err := core.WriteChainHeight(batch, blockNumber-1); err != nil { + return err + } + } + return batch.Write() +} + type SimulateResult struct { BlockCommitments *core.BlockCommitments ConcatCount felt.Felt @@ -566,8 +727,31 @@ func (b *Blockchain) Finalise( newClasses map[felt.Felt]core.Class, sign utils.BlockSignFunc, ) error { - err := b.database.Update(func(txn db.IndexedBatch) error { - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { + if !b.StateFactory.UseNewState { + err := b.database.Update(func(txn db.IndexedBatch) error { + if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { + return err + } + commitments, err := b.updateBlockHash(block, stateUpdate) + if err != nil { + return err + } + if err := b.signBlock(block, stateUpdate, sign); err != nil { + return err + } + if err := b.storeBlockData(txn, block, stateUpdate, commitments); err != nil { + return err + } + return core.WriteChainHeight(txn, block.Number) + }) + if err != nil { + return err + } + + return b.runningFilter.Insert(block.EventsBloom, block.Number) + } else { + batch := b.database.NewBatch() + if err := b.updateStateRoots(nil, block, stateUpdate, newClasses); err != nil { return err } commitments, err := b.updateBlockHash(block, stateUpdate) @@ -577,16 +761,17 @@ func (b *Blockchain) Finalise( if err := b.signBlock(block, stateUpdate, sign); err != nil { return err } - if err := b.storeBlockData(txn, block, stateUpdate, commitments); err != nil { + if err := b.storeBlockData(batch, block, stateUpdate, commitments); err != nil { return err } - return core.WriteChainHeight(txn, block.Number) - }) - if err != nil { - return err + if err := core.WriteChainHeight(batch, block.Number); err != nil { + return err + } + if err := batch.Write(); err != nil { + return err + } + return b.runningFilter.Insert(block.EventsBloom, block.Number) } - - return b.runningFilter.Insert(block.EventsBloom, block.Number) } // updateStateRoots computes and updates state roots in the block and state update @@ -596,14 +781,31 @@ func (b *Blockchain) updateStateRoots( stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.Class, ) error { - state := core.NewState(txn) + var height uint64 + var err error + if height, err = core.GetChainHeight(b.database); err != nil { + height = 0 + } + + header, _ := core.GetBlockHeaderByNumber(b.database, height) + var stateRoot *felt.Felt + if header != nil { + stateRoot = header.GlobalStateRoot + } else { + stateRoot = &felt.Zero + } + + state, err := b.StateFactory.NewState(stateRoot, txn) + if err != nil { + return err + } // Get old state root - oldStateRoot, err := state.Root() + oldStateRoot, err := state.Commitment() if err != nil { return err } - stateUpdate.OldRoot = oldStateRoot + stateUpdate.OldRoot = &oldStateRoot // Apply state update if err = state.Update(block.Number, stateUpdate, newClasses, true); err != nil { @@ -611,12 +813,12 @@ func (b *Blockchain) updateStateRoots( } // Get new state root - newStateRoot, err := state.Root() + newStateRoot, err := state.Commitment() if err != nil { return err } - block.GlobalStateRoot = newStateRoot + block.GlobalStateRoot = &newStateRoot stateUpdate.NewRoot = block.GlobalStateRoot return nil @@ -659,35 +861,35 @@ func (b *Blockchain) signBlock( // storeBlockData persists all block-related data to the database func (b *Blockchain) storeBlockData( - txn db.IndexedBatch, + w db.KeyValueWriter, block *core.Block, stateUpdate *core.StateUpdate, commitments *core.BlockCommitments, ) error { // Store block header - if err := core.WriteBlockHeader(txn, block.Header); err != nil { + if err := core.WriteBlockHeader(w, block.Header); err != nil { return err } // Store transactions and receipts for i, tx := range block.Transactions { - if err := core.WriteTxAndReceipt(txn, block.Number, uint64(i), tx, block.Receipts[i]); err != nil { + if err := core.WriteTxAndReceipt(w, block.Number, uint64(i), tx, block.Receipts[i]); err != nil { return err } } // Store state update - if err := core.WriteStateUpdateByBlockNum(txn, block.Number, stateUpdate); err != nil { + if err := core.WriteStateUpdateByBlockNum(w, block.Number, stateUpdate); err != nil { return err } // Store block commitments - if err := core.WriteBlockCommitment(txn, block.Number, commitments); err != nil { + if err := core.WriteBlockCommitment(w, block.Number, commitments); err != nil { return err } // Store L1 handler message hashes - if err := core.WriteL1HandlerMsgHashes(txn, block.Transactions); err != nil { + if err := core.WriteL1HandlerMsgHashes(w, block.Transactions); err != nil { return err } @@ -729,3 +931,20 @@ func (b *Blockchain) StoreGenesis( func (b *Blockchain) WriteRunningEventFilter() error { return b.runningFilter.Write() } + +func (b *Blockchain) Stop() error { + if b.trieDB.Scheme() == triedb.PathScheme { + head, err := b.HeadsHeader() + if err != nil { + return err + } + + stateUpdate, err := b.StateUpdateByNumber(head.Number) + if err != nil { + return err + } + + return b.trieDB.Journal(stateUpdate.NewRoot) + } + return nil +} diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index 0a108c2242..d8d0a68608 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -239,10 +239,6 @@ func TestStore(t *testing.T) { require.NoError(t, err) assert.Equal(t, block0, headBlock) - root, err := chain.StateCommitment() - require.NoError(t, err) - assert.Equal(t, stateUpdate0.NewRoot, root) - got0Block, err := chain.BlockByNumber(0) require.NoError(t, err) assert.Equal(t, block0, got0Block) @@ -267,10 +263,6 @@ func TestStore(t *testing.T) { require.NoError(t, err) assert.Equal(t, block1, headBlock) - root, err := chain.StateCommitment() - require.NoError(t, err) - assert.Equal(t, stateUpdate1.NewRoot, root) - got1Block, err := chain.BlockByNumber(1) require.NoError(t, err) assert.Equal(t, block1, got1Block) @@ -296,7 +288,8 @@ func TestStoreL1HandlerTxnHash(t *testing.T) { l1HandlerMsgHash := common.HexToHash("0x42e76df4e3d5255262929c27132bd0d295a8d3db2cfe63d2fcd061c7a7a7ab34") l1HandlerTxnHash, err := chain.L1HandlerTxnHash(&l1HandlerMsgHash) require.NoError(t, err) - require.Equal(t, utils.HexToFelt(t, "0x785c2ada3f53fbc66078d47715c27718f92e6e48b96372b36e5197de69b82b5"), l1HandlerTxnHash) + expectedL1HandlerTxnHash := utils.HexToFelt(t, "0x785c2ada3f53fbc66078d47715c27718f92e6e48b96372b36e5197de69b82b5") + require.Equal(t, *expectedL1HandlerTxnHash, l1HandlerTxnHash) } func TestBlockCommitments(t *testing.T) { @@ -681,7 +674,7 @@ func TestL1Update(t *testing.T) { require.NoError(t, chain.SetL1Head(head)) got, err := chain.L1Head() require.NoError(t, err) - assert.Equal(t, head, got) + assert.Equal(t, head, &got) }) } } diff --git a/builder/builder.go b/builder/builder.go index 0454675f6c..11ce7f7a52 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/sync" "github.com/NethermindEth/juno/utils" @@ -124,7 +125,7 @@ func (b *Builder) getRevealedBlockHash(blockHeight uint64) (*felt.Felt, error) { return header.Hash, nil } -func (b *Builder) PendingState(buildState *BuildState) (core.StateReader, func() error, error) { +func (b *Builder) PendingState(buildState *BuildState) (commonstate.StateReader, func() error, error) { if buildState.Preconfirmed == nil { return nil, nil, sync.ErrPendingBlockNotFound } diff --git a/core/accessors.go b/core/accessors.go index 3a347185bd..4a98ef3d31 100644 --- a/core/accessors.go +++ b/core/accessors.go @@ -225,10 +225,10 @@ func GetReceiptByHash(r db.KeyValueReader, hash *felt.Felt) (*TransactionReceipt return GetReceiptByBlockNumIndexBytes(r, val) } -func DeleteTxsAndReceipts(batch db.IndexedBatch, blockNum, numTxs uint64) error { +func DeleteTxsAndReceipts(r db.KeyValueReader, batch db.KeyValueWriter, blockNum, numTxs uint64) error { // remove txs and receipts for i := range numTxs { - txn, err := GetTxByBlockNumIndex(batch, blockNum, i) + txn, err := GetTxByBlockNumIndex(r, blockNum, i) if err != nil { return err } diff --git a/core/mocks/mock_commonstate_reader.go b/core/mocks/mock_commonstate_reader.go new file mode 100644 index 0000000000..d6d0c40e89 --- /dev/null +++ b/core/mocks/mock_commonstate_reader.go @@ -0,0 +1,148 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/NethermindEth/juno/core/state/commonstate (interfaces: StateReader) +// +// Generated by this command: +// +// mockgen -destination=../../mocks/mock_commonstate_reader.go -package=mocks github.com/NethermindEth/juno/core/state/commonstate StateReader +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + core "github.com/NethermindEth/juno/core" + felt "github.com/NethermindEth/juno/core/felt" + commontrie "github.com/NethermindEth/juno/core/state/commontrie" + gomock "go.uber.org/mock/gomock" +) + +// MockStateReader is a mock of StateReader interface. +type MockStateReader struct { + ctrl *gomock.Controller + recorder *MockStateReaderMockRecorder + isgomock struct{} +} + +// MockStateReaderMockRecorder is the mock recorder for MockStateReader. +type MockStateReaderMockRecorder struct { + mock *MockStateReader +} + +// NewMockStateReader creates a new mock instance. +func NewMockStateReader(ctrl *gomock.Controller) *MockStateReader { + mock := &MockStateReader{ctrl: ctrl} + mock.recorder = &MockStateReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateReader) EXPECT() *MockStateReaderMockRecorder { + return m.recorder +} + +// Class mocks base method. +func (m *MockStateReader) Class(classHash *felt.Felt) (*core.DeclaredClass, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Class", classHash) + ret0, _ := ret[0].(*core.DeclaredClass) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Class indicates an expected call of Class. +func (mr *MockStateReaderMockRecorder) Class(classHash any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Class", reflect.TypeOf((*MockStateReader)(nil).Class), classHash) +} + +// ClassTrie mocks base method. +func (m *MockStateReader) ClassTrie() (commontrie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClassTrie") + ret0, _ := ret[0].(commontrie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ClassTrie indicates an expected call of ClassTrie. +func (mr *MockStateReaderMockRecorder) ClassTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClassTrie", reflect.TypeOf((*MockStateReader)(nil).ClassTrie)) +} + +// ContractClassHash mocks base method. +func (m *MockStateReader) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractClassHash", addr) + ret0, _ := ret[0].(felt.Felt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractClassHash indicates an expected call of ContractClassHash. +func (mr *MockStateReaderMockRecorder) ContractClassHash(addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractClassHash", reflect.TypeOf((*MockStateReader)(nil).ContractClassHash), addr) +} + +// ContractNonce mocks base method. +func (m *MockStateReader) ContractNonce(addr *felt.Felt) (felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractNonce", addr) + ret0, _ := ret[0].(felt.Felt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractNonce indicates an expected call of ContractNonce. +func (mr *MockStateReaderMockRecorder) ContractNonce(addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractNonce", reflect.TypeOf((*MockStateReader)(nil).ContractNonce), addr) +} + +// ContractStorage mocks base method. +func (m *MockStateReader) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractStorage", addr, key) + ret0, _ := ret[0].(felt.Felt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractStorage indicates an expected call of ContractStorage. +func (mr *MockStateReaderMockRecorder) ContractStorage(addr, key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractStorage", reflect.TypeOf((*MockStateReader)(nil).ContractStorage), addr, key) +} + +// ContractStorageTrie mocks base method. +func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractStorageTrie", addr) + ret0, _ := ret[0].(commontrie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractStorageTrie indicates an expected call of ContractStorageTrie. +func (mr *MockStateReaderMockRecorder) ContractStorageTrie(addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractStorageTrie", reflect.TypeOf((*MockStateReader)(nil).ContractStorageTrie), addr) +} + +// ContractTrie mocks base method. +func (m *MockStateReader) ContractTrie() (commontrie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractTrie") + ret0, _ := ret[0].(commontrie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractTrie indicates an expected call of ContractTrie. +func (mr *MockStateReaderMockRecorder) ContractTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractTrie", reflect.TypeOf((*MockStateReader)(nil).ContractTrie)) +} diff --git a/core/receipt.go b/core/receipt.go index 92f3e85f44..0852f2c5ce 100644 --- a/core/receipt.go +++ b/core/receipt.go @@ -74,6 +74,7 @@ func receiptCommitment(receipts []*TransactionReceipt) (*felt.Felt, error) { ) } +// TODO(maksymmalick): change this to trie2 after integration done type ( onTempTrieFunc func(uint8, func(*trie.Trie) error) error processFunc[T any] func(T) *felt.Felt diff --git a/core/state/cache.go b/core/state/cache.go index e20a840950..a0d839f43f 100644 --- a/core/state/cache.go +++ b/core/state/cache.go @@ -81,7 +81,9 @@ func (c *stateCache) PopLayer(stateRoot, parentRoot *felt.Felt) error { node, ok := c.rootMap[*stateRoot] if !ok { - return fmt.Errorf("layer with state root %v not found", stateRoot) + // There should be no error when layer is not found in the cache. + // The layer might not be cached (i. e. after node shutdown). + return nil } if node.child != nil { diff --git a/core/state/cache_test.go b/core/state/cache_test.go index 74c37204f6..924a9ed874 100644 --- a/core/state/cache_test.go +++ b/core/state/cache_test.go @@ -125,8 +125,7 @@ func TestStateCache(t *testing.T) { root := new(felt.Felt).SetUint64(1) err := cache.PopLayer(root, &felt.Zero) - require.Error(t, err) - assert.Contains(t, err.Error(), "layer with state root") + require.NoError(t, err) }) t.Run("push and pop multiple layers with no changes", func(t *testing.T) { diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 860ac65178..288fedcdc7 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -7,9 +7,11 @@ import ( "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" ) -type CommonState interface { +//go:generate mockgen -destination=../../mocks/mock_commonstate_reader.go -package=mocks github.com/NethermindEth/juno/core/state/commonstate StateReader +type State interface { StateReader ContractStorageAt(addr, key *felt.Felt, blockNumber uint64) (felt.Felt, error) @@ -19,6 +21,7 @@ type CommonState interface { Update(blockNum uint64, update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool) error Revert(blockNum uint64, update *core.StateUpdate) error + Commitment() (felt.Felt, error) } type StateReader interface { @@ -27,9 +30,9 @@ type StateReader interface { ContractStorage(addr, key *felt.Felt) (felt.Felt, error) Class(classHash *felt.Felt) (*core.DeclaredClass, error) - ClassTrie() (commontrie.CommonTrie, error) - ContractTrie() (commontrie.CommonTrie, error) - ContractStorageTrie(addr *felt.Felt) (commontrie.CommonTrie, error) + ClassTrie() (commontrie.Trie, error) + ContractTrie() (commontrie.Trie, error) + ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) } // CoreStateAdapter wraps core.State to implement CommonState @@ -97,7 +100,7 @@ func (csa *CoreStateAdapter) Class(classHash *felt.Felt) (*core.DeclaredClass, e return csa.State.Class(classHash) } -func (csa *CoreStateAdapter) ClassTrie() (commontrie.CommonTrie, error) { +func (csa *CoreStateAdapter) ClassTrie() (commontrie.Trie, error) { t, err := csa.State.ClassTrie() if err != nil { return nil, err @@ -105,7 +108,7 @@ func (csa *CoreStateAdapter) ClassTrie() (commontrie.CommonTrie, error) { return commontrie.NewTrieAdapter(t), nil } -func (csa *CoreStateAdapter) ContractTrie() (commontrie.CommonTrie, error) { +func (csa *CoreStateAdapter) ContractTrie() (commontrie.Trie, error) { t, err := csa.State.ContractTrie() if err != nil { return nil, err @@ -113,7 +116,7 @@ func (csa *CoreStateAdapter) ContractTrie() (commontrie.CommonTrie, error) { return commontrie.NewTrieAdapter(t), nil } -func (csa *CoreStateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.CommonTrie, error) { +func (csa *CoreStateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { t, err := csa.State.ContractStorageTrie(addr) if err != nil { return nil, err @@ -121,6 +124,11 @@ func (csa *CoreStateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.Co return commontrie.NewTrieAdapter(t), nil } +func (csa *CoreStateAdapter) Commitment() (felt.Felt, error) { + root, err := csa.State.Root() + return *root, err +} + // StateAdapter wraps state.State to implement CommonState type StateAdapter struct { *state.State @@ -130,7 +138,7 @@ func NewStateAdapter(s *state.State) *StateAdapter { return &StateAdapter{State: s} } -func (sa *StateAdapter) ClassTrie() (commontrie.CommonTrie, error) { +func (sa *StateAdapter) ClassTrie() (commontrie.Trie, error) { t, err := sa.State.ClassTrie() if err != nil { return nil, err @@ -138,7 +146,7 @@ func (sa *StateAdapter) ClassTrie() (commontrie.CommonTrie, error) { return commontrie.NewTrie2Adapter(t), nil } -func (sa *StateAdapter) ContractTrie() (commontrie.CommonTrie, error) { +func (sa *StateAdapter) ContractTrie() (commontrie.Trie, error) { t, err := sa.State.ContractTrie() if err != nil { return nil, err @@ -146,7 +154,7 @@ func (sa *StateAdapter) ContractTrie() (commontrie.CommonTrie, error) { return commontrie.NewTrie2Adapter(t), nil } -func (sa *StateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.CommonTrie, error) { +func (sa *StateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { t, err := sa.State.ContractStorageTrie(addr) if err != nil { return nil, err @@ -154,26 +162,134 @@ func (sa *StateAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.CommonT return commontrie.NewTrie2Adapter(t), nil } +type CoreStateReaderAdapter struct { + core.StateReader +} + +func NewCoreStateReaderAdapter(s core.StateReader) *CoreStateReaderAdapter { + return &CoreStateReaderAdapter{StateReader: s} +} + +func (ssa *CoreStateReaderAdapter) Class(classHash *felt.Felt) (*core.DeclaredClass, error) { + return ssa.StateReader.Class(classHash) +} + +func (ssa *CoreStateReaderAdapter) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { + classHash, err := ssa.StateReader.ContractClassHash(addr) + if err != nil { + return felt.Zero, err + } + return *classHash, nil +} + +func (ssa *CoreStateReaderAdapter) ContractNonce(addr *felt.Felt) (felt.Felt, error) { + nonce, err := ssa.StateReader.ContractNonce(addr) + if err != nil { + return felt.Zero, err + } + return *nonce, nil +} + +func (ssa *CoreStateReaderAdapter) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) { + value, err := ssa.StateReader.ContractStorage(addr, key) + if err != nil { + return felt.Zero, err + } + return *value, nil +} + +func (ssa *CoreStateReaderAdapter) ClassTrie() (commontrie.Trie, error) { + t, err := ssa.StateReader.ClassTrie() + if err != nil { + return nil, err + } + return commontrie.NewTrieAdapter(t), nil +} + +func (ssa *CoreStateReaderAdapter) ContractTrie() (commontrie.Trie, error) { + t, err := ssa.StateReader.ContractTrie() + if err != nil { + return nil, err + } + return commontrie.NewTrieAdapter(t), nil +} + +func (ssa *CoreStateReaderAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { + t, err := ssa.StateReader.ContractStorageTrie(addr) + if err != nil { + return nil, err + } + return commontrie.NewTrieAdapter(t), nil +} + +type StateReaderAdapter struct { + state.StateReader +} + +func NewStateReaderAdapter(s state.StateReader) *StateReaderAdapter { + return &StateReaderAdapter{StateReader: s} +} + +func (sha *StateReaderAdapter) Class(classHash *felt.Felt) (*core.DeclaredClass, error) { + return sha.StateReader.Class(classHash) +} + +func (sha *StateReaderAdapter) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { + return sha.StateReader.ContractClassHash(addr) +} + +func (sha *StateReaderAdapter) ContractNonce(addr *felt.Felt) (felt.Felt, error) { + return sha.StateReader.ContractNonce(addr) +} + +func (sha *StateReaderAdapter) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) { + return sha.StateReader.ContractStorage(addr, key) +} + +func (sha *StateReaderAdapter) ClassTrie() (commontrie.Trie, error) { + t, err := sha.StateReader.ClassTrie() + if err != nil { + return nil, err + } + return commontrie.NewTrie2Adapter(t), nil +} + +func (sha *StateReaderAdapter) ContractTrie() (commontrie.Trie, error) { + t, err := sha.StateReader.ContractTrie() + if err != nil { + return nil, err + } + return commontrie.NewTrie2Adapter(t), nil +} + +func (sha *StateReaderAdapter) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { + t, err := sha.StateReader.ContractStorageTrie(addr) + if err != nil { + return nil, err + } + return commontrie.NewTrie2Adapter(t), nil +} + type StateFactory struct { - useNewState bool + UseNewState bool triedb *triedb.Database stateDB *state.StateDB } func NewStateFactory(useNewState bool, triedb *triedb.Database, stateDB *state.StateDB) (*StateFactory, error) { if !useNewState { - return &StateFactory{useNewState: false}, nil + return &StateFactory{UseNewState: false}, nil } return &StateFactory{ - useNewState: true, + UseNewState: true, triedb: triedb, stateDB: stateDB, }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (CommonState, error) { - if !sf.useNewState { +func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (State, error) { + if !sf.UseNewState { coreState := core.NewState(txn) return NewCoreStateAdapter(coreState), nil } @@ -184,3 +300,31 @@ func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (Com } return NewStateAdapter(stateState), nil } + +func (sf *StateFactory) NewStateReader(stateRoot *felt.Felt, txn db.IndexedBatch, blockNumber uint64) (StateReader, error) { + if !sf.UseNewState { + coreState := core.NewState(txn) + snapshot := core.NewStateSnapshot(coreState, blockNumber) + return NewCoreStateReaderAdapter(snapshot), nil + } + + history, err := state.NewStateHistory(blockNumber, stateRoot, sf.stateDB) + if err != nil { + return nil, err + } + return NewStateReaderAdapter(&history), nil +} + +func (sf *StateFactory) EmptyState() (StateReader, error) { + if !sf.UseNewState { + memDB := memory.New() + txn := memDB.NewIndexedBatch() + emptyState := core.NewState(txn) + return NewCoreStateReaderAdapter(emptyState), nil + } + state, err := state.New(&felt.Zero, sf.stateDB) + if err != nil { + return nil, err + } + return NewStateReaderAdapter(state), nil +} diff --git a/core/state/commontrie/trie.go b/core/state/commontrie/trie.go index 42471b9a27..48b812c5c8 100644 --- a/core/state/commontrie/trie.go +++ b/core/state/commontrie/trie.go @@ -7,65 +7,65 @@ import ( "github.com/NethermindEth/juno/core/trie2" ) -type CommonTrie interface { +type Trie interface { Update(key, value *felt.Felt) error Get(key *felt.Felt) (felt.Felt, error) - Hash() felt.Felt + Hash() (felt.Felt, error) HashFn() crypto.HashFn } // TrieAdapter wraps trie.Trie to implement commontrie.CommonTrie type TrieAdapter struct { - trie *trie.Trie + Trie *trie.Trie } func NewTrieAdapter(t *trie.Trie) *TrieAdapter { - return &TrieAdapter{trie: t} + return &TrieAdapter{Trie: t} } func (ta *TrieAdapter) Update(key, value *felt.Felt) error { - _, err := ta.trie.Put(key, value) + _, err := ta.Trie.Put(key, value) return err } func (ta *TrieAdapter) Get(key *felt.Felt) (felt.Felt, error) { - value, err := ta.trie.Get(key) + value, err := ta.Trie.Get(key) if err != nil { return felt.Zero, err } return *value, nil } -func (ta *TrieAdapter) Hash() felt.Felt { - root, _ := ta.trie.Root() - return *root +func (ta *TrieAdapter) Hash() (felt.Felt, error) { + root, err := ta.Trie.Root() + return *root, err } func (ta *TrieAdapter) HashFn() crypto.HashFn { - return ta.trie.HashFn() + return ta.Trie.HashFn() } // Trie2Adapter wraps trie2.Trie to implement commontrie.CommonTrie type Trie2Adapter struct { - trie *trie2.Trie + Trie *trie2.Trie } func NewTrie2Adapter(t *trie2.Trie) *Trie2Adapter { - return &Trie2Adapter{trie: t} + return &Trie2Adapter{Trie: t} } func (ta *Trie2Adapter) Update(key, value *felt.Felt) error { - return ta.trie.Update(key, value) + return ta.Trie.Update(key, value) } func (ta *Trie2Adapter) Get(key *felt.Felt) (felt.Felt, error) { - return ta.trie.Get(key) + return ta.Trie.Get(key) } -func (ta *Trie2Adapter) Hash() felt.Felt { - return ta.trie.Hash() +func (ta *Trie2Adapter) Hash() (felt.Felt, error) { + return ta.Trie.Hash(), nil } func (ta *Trie2Adapter) HashFn() crypto.HashFn { - return ta.trie.HashFn() + return ta.Trie.HashFn() } diff --git a/core/state/history.go b/core/state/history.go index decce06583..af680d686e 100644 --- a/core/state/history.go +++ b/core/state/history.go @@ -16,14 +16,13 @@ type StateHistory struct { } func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (StateHistory, error) { - state, err := New(stateRoot, db) - if err != nil { - return StateHistory{}, err - } - return StateHistory{ blockNum: blockNum, - state: state, + state: &State{ + initRoot: *stateRoot, + db: db, + stateObjects: make(map[felt.Felt]*stateObject), + }, }, nil } diff --git a/core/state/history_test.go b/core/state/history_test.go index bc09a7a6c3..bcfbff70b1 100644 --- a/core/state/history_test.go +++ b/core/state/history_test.go @@ -20,12 +20,6 @@ func TestNewStateHistory(t *testing.T) { assert.Equal(t, uint64(0), history.blockNum) assert.NotNil(t, history.state) }) - - t.Run("invalid state root", func(t *testing.T) { - invalidRoot := new(felt.Felt).SetUint64(999) // Non-existent root - _, err := NewStateHistory(1, invalidRoot, stateDB) - assert.Error(t, err) - }) } func TestStateHistoryContractOperations(t *testing.T) { diff --git a/core/state/state.go b/core/state/state.go index f4fb5ac372..2272868a39 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -25,12 +25,13 @@ var ( noClassContractsClassHash = felt.Zero noClassContracts = map[felt.Felt]struct{}{ *new(felt.Felt).SetUint64(1): {}, + *new(felt.Felt).SetUint64(2): {}, } ) var _ StateReader = &State{} -//go:generate mockgen -destination=../../mocks/mock_state_reader.go -package=mocks github.com/NethermindEth/juno/core/state StateReader +// TODO(maksym): add mock generation after integration complete type StateReader interface { ContractReader ClassReader @@ -218,8 +219,11 @@ func (s *State) Update( } // Check if the new commitment matches the one in state diff - if !newComm.Equal(update.NewRoot) { - return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.NewRoot, &newComm) + // The following check isn't relevant for the centralised Juno sequencer + if !skipVerifyNewRoot { + if !newComm.Equal(update.NewRoot) { + return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.NewRoot, &newComm) + } } s.db.stateCache.PushLayer(&newComm, &stateUpdate.prevComm, &diffCache{ @@ -566,7 +570,7 @@ func (s *State) verifyComm(comm *felt.Felt) error { } if !curComm.Equal(comm) { - return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", comm, curComm) + return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", comm, &curComm) } return nil diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go index 3c9dc72676..6a88c6cd45 100644 --- a/core/trie2/triedb/database.go +++ b/core/trie2/triedb/database.go @@ -12,6 +12,11 @@ import ( "github.com/NethermindEth/juno/db" ) +const ( + PathScheme string = "path" + HashScheme string = "hash" +) + type Config struct { PathConfig *pathdb.Config HashConfig *hashdb.Config @@ -63,6 +68,23 @@ func (d *Database) Update( } } +func (d *Database) Journal(root *felt.Felt) error { + pdb, ok := d.triedb.(*pathdb.Database) + if !ok { + return fmt.Errorf("unsupported trie db type: %T", d.triedb) + } + return pdb.Journal(root) +} + +func (d *Database) Scheme() string { + if d.config == nil { + return PathScheme + } else if d.config.PathConfig != nil { + return PathScheme + } + return HashScheme +} + func (d *Database) NodeReader(id trieutils.TrieID) (database.NodeReader, error) { return d.triedb.NodeReader(id) } diff --git a/core/trie2/triedb/pathdb/buffer.go b/core/trie2/triedb/pathdb/buffer.go index a52b8acd43..a27047c2fa 100644 --- a/core/trie2/triedb/pathdb/buffer.go +++ b/core/trie2/triedb/pathdb/buffer.go @@ -39,7 +39,6 @@ func (b *buffer) commit(nodes *nodeSet) *buffer { func (b *buffer) reset() { b.layers = 0 - b.limit = 0 b.nodes.reset() } @@ -48,7 +47,15 @@ func (b *buffer) isFull() bool { } func (b *buffer) flush(kvs db.KeyValueStore, cleans *cleanCache, id uint64) error { - latestPersistedID, _ := trieutils.ReadPersistedStateID(kvs) + latestPersistedID, err := trieutils.ReadPersistedStateID(kvs) + if err != nil { + if err == db.ErrKeyNotFound { + latestPersistedID = 0 + } else { + return err + } + } + if latestPersistedID+b.layers != id { return fmt.Errorf( "mismatch buffer layers applied: latest state id (%d) + buffer layers (%d) != target state id (%d)", @@ -58,10 +65,17 @@ func (b *buffer) flush(kvs db.KeyValueStore, cleans *cleanCache, id uint64) erro ) } - batch := kvs.NewBatchWithSize(b.nodes.dbSize()) + dbSize := b.nodes.dbSize() + + batch := kvs.NewBatchWithSize(dbSize) + if batch == nil { + return fmt.Errorf("failed to create batch") + } + if err := b.nodes.write(batch, cleans); err != nil { return err } + if err := trieutils.WritePersistedStateID(batch, id); err != nil { return err } diff --git a/core/trie2/triedb/pathdb/database.go b/core/trie2/triedb/pathdb/database.go index 8c178a34ec..d0276501bc 100644 --- a/core/trie2/triedb/pathdb/database.go +++ b/core/trie2/triedb/pathdb/database.go @@ -51,6 +51,12 @@ func New(disk db.KeyValueStore, config *Config) (*Database, error) { } func (d *Database) Close() error { + diskLayerHash := d.tree.diskLayer().rootHash() + err := d.Journal(diskLayerHash) + if err != nil { + return err + } + d.lock.Lock() defer d.lock.Unlock() diff --git a/core/trie2/triedb/pathdb/difflayer.go b/core/trie2/triedb/pathdb/difflayer.go index df36aaa1de..58eba5a81f 100644 --- a/core/trie2/triedb/pathdb/difflayer.go +++ b/core/trie2/triedb/pathdb/difflayer.go @@ -5,7 +5,9 @@ import ( "sync" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" ) var _ layer = (*diffLayer)(nil) @@ -38,10 +40,13 @@ func (dl *diffLayer) node(id trieutils.TrieID, owner *felt.Felt, path *trieutils isClass := id.Type() == trieutils.Class n, ok := dl.nodes.node(owner, path, isClass) if ok { + if _, deleted := n.(*trienode.DeletedNode); deleted { + return nil, db.ErrKeyNotFound + } return n.Blob(), nil } - return dl.parent.node(id, owner, path, isClass) + return dl.parent.node(id, owner, path, isLeaf) } func (dl *diffLayer) rootHash() *felt.Felt { diff --git a/core/trie2/triedb/pathdb/disklayer.go b/core/trie2/triedb/pathdb/disklayer.go index c920337725..029dfc366b 100644 --- a/core/trie2/triedb/pathdb/disklayer.go +++ b/core/trie2/triedb/pathdb/disklayer.go @@ -4,7 +4,9 @@ import ( "sync" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" ) var _ layer = (*diskLayer)(nil) @@ -67,6 +69,9 @@ func (dl *diskLayer) node(id trieutils.TrieID, owner *felt.Felt, path *trieutils isClass := id.Type() == trieutils.Class n, ok := dl.dirties.node(owner, path, isClass) if ok { + if _, deleted := n.(*trienode.DeletedNode); deleted { + return nil, db.ErrKeyNotFound + } return n.Blob(), nil } @@ -77,7 +82,7 @@ func (dl *diskLayer) node(id trieutils.TrieID, owner *felt.Felt, path *trieutils } // Finally, read from disk - blob, err := trieutils.GetNodeByPath(dl.db.disk, id.Bucket(), owner, path, isClass) + blob, err := trieutils.GetNodeByPath(dl.db.disk, id.Bucket(), owner, path, isLeaf) if err != nil { return nil, err } diff --git a/core/trie2/triedb/pathdb/nodeset.go b/core/trie2/triedb/pathdb/nodeset.go index 0bd3c89c3f..70e1ec52f7 100644 --- a/core/trie2/triedb/pathdb/nodeset.go +++ b/core/trie2/triedb/pathdb/nodeset.go @@ -231,6 +231,7 @@ func (s *nodeSet) decode(data []byte) error { } } } + s.computeSize() return nil } diff --git a/core/trie2/triedb/pathdb/reader.go b/core/trie2/triedb/pathdb/reader.go index d0ef2b80ef..a652c85f19 100644 --- a/core/trie2/triedb/pathdb/reader.go +++ b/core/trie2/triedb/pathdb/reader.go @@ -24,7 +24,7 @@ func (d *Database) NodeReader(id trieutils.TrieID) (database.NodeReader, error) stateComm := id.StateComm() l := d.tree.get(&stateComm) if l == nil { - return nil, fmt.Errorf("layer %v not found", id.StateComm()) + return nil, fmt.Errorf("layer %v not found", &stateComm) } return &reader{id: id, l: l}, nil } diff --git a/core/trie2/trieutils/accessors.go b/core/trie2/trieutils/accessors.go index 3efa77de41..ce9d31adc1 100644 --- a/core/trie2/trieutils/accessors.go +++ b/core/trie2/trieutils/accessors.go @@ -78,7 +78,7 @@ func WritePersistedStateID(w db.KeyValueWriter, id uint64) error { func ReadTrieJournal(r db.KeyValueReader) ([]byte, error) { var journal []byte if err := r.Get(db.TrieJournal.Key(), func(value []byte) error { - journal = value + journal = append([]byte(nil), value...) return nil }); err != nil { return nil, err diff --git a/genesis/genesis.go b/genesis/genesis.go index 3e3c2fb210..7024b8d34b 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -9,6 +9,9 @@ import ( "github.com/NethermindEth/juno/adapters/vm2core" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" rpc "github.com/NethermindEth/juno/rpc/v8" "github.com/NethermindEth/juno/starknet" @@ -104,10 +107,26 @@ func GenesisStateDiff( ) (core.StateDiff, map[felt.Felt]core.Class, error) { initialStateDiff := core.EmptyStateDiff() memDB := memory.New() + triedb, err := triedb.New(memDB, nil) + if err != nil { + return core.StateDiff{}, nil, err + } + stateDB := state.NewStateDB(memDB, triedb) + + // TODO(maksymmalick): remove this after integration done + stateFactory, err := commonstate.NewStateFactory(false, triedb, stateDB) + if err != nil { + return core.StateDiff{}, nil, err + } + state, err := stateFactory.NewState(&felt.Zero, memDB.NewIndexedBatch()) + if err != nil { + return core.StateDiff{}, nil, err + } + // genesisState := sync.NewPendingStateWriter( &initialStateDiff, make(map[felt.Felt]core.Class, len(config.Classes)), - core.NewState(memDB.NewIndexedBatch()), + state, ) classhashToSierraVersion, err := declareClasses(config, &genesisState) @@ -250,7 +269,7 @@ func executeFunctionCalls( callInfo := &vm.CallInfo{ ContractAddress: &contractAddress, - ClassHash: classHash, + ClassHash: &classHash, Selector: &entryPointSelector, Calldata: fnCall.Calldata, } diff --git a/mempool/mempool.go b/mempool/mempool.go index 3a2189751a..2d165e2c6b 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -288,7 +288,7 @@ func (p *SequencerMempool) validate(userTxn *BroadcastedTransaction) error { return fmt.Errorf("deploy transactions are not supported") case *core.DeployAccountTransaction: if !t.Nonce.IsZero() { - return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce) + return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce.String()) } case *core.DeclareTransaction: nonce, err := state.ContractNonce(t.SenderAddress) diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index fb38e4207d..2c08048aa6 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -46,7 +46,7 @@ func TestMempool(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) chain := mocks.NewMockReader(mockCtrl) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) require.NoError(t, err) defer dbCloser() @@ -121,7 +121,7 @@ func TestRestoreMempool(t *testing.T) { log := utils.NewNopZapLogger() mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) chain := mocks.NewMockReader(mockCtrl) testDB, dbDeleter, err := setupDatabase("testrestoremempool", true) require.NoError(t, err) @@ -233,7 +233,7 @@ func TestPopBatch(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) chain := mocks.NewMockReader(mockCtrl) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) require.NoError(t, err) defer dbCloser() diff --git a/mocks/mock_blockchain.go b/mocks/mock_blockchain.go index 44a45d4cf6..f21eec21db 100644 --- a/mocks/mock_blockchain.go +++ b/mocks/mock_blockchain.go @@ -15,6 +15,7 @@ import ( blockchain "github.com/NethermindEth/juno/blockchain" core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" + commonstate "github.com/NethermindEth/juno/core/state/commonstate" utils "github.com/NethermindEth/juno/utils" common "github.com/ethereum/go-ethereum/common" gomock "go.uber.org/mock/gomock" @@ -150,10 +151,10 @@ func (mr *MockReaderMockRecorder) Head() *gomock.Call { } // HeadState mocks base method. -func (m *MockReader) HeadState() (core.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) HeadState() (commonstate.StateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadState") - ret0, _ := ret[0].(core.StateReader) + ret0, _ := ret[0].(commonstate.StateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -196,10 +197,10 @@ func (mr *MockReaderMockRecorder) Height() *gomock.Call { } // L1HandlerTxnHash mocks base method. -func (m *MockReader) L1HandlerTxnHash(msgHash *common.Hash) (*felt.Felt, error) { +func (m *MockReader) L1HandlerTxnHash(msgHash *common.Hash) (felt.Felt, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "L1HandlerTxnHash", msgHash) - ret0, _ := ret[0].(*felt.Felt) + ret0, _ := ret[0].(felt.Felt) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -211,10 +212,10 @@ func (mr *MockReaderMockRecorder) L1HandlerTxnHash(msgHash any) *gomock.Call { } // L1Head mocks base method. -func (m *MockReader) L1Head() (*core.L1Head, error) { +func (m *MockReader) L1Head() (core.L1Head, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "L1Head") - ret0, _ := ret[0].(*core.L1Head) + ret0, _ := ret[0].(core.L1Head) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -257,10 +258,10 @@ func (mr *MockReaderMockRecorder) Receipt(hash any) *gomock.Call { } // StateAtBlockHash mocks base method. -func (m *MockReader) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StateAtBlockHash", blockHash) - ret0, _ := ret[0].(core.StateReader) + ret0, _ := ret[0].(commonstate.StateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -273,10 +274,10 @@ func (mr *MockReaderMockRecorder) StateAtBlockHash(blockHash any) *gomock.Call { } // StateAtBlockNumber mocks base method. -func (m *MockReader) StateAtBlockNumber(blockNumber uint64) (core.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StateAtBlockNumber", blockNumber) - ret0, _ := ret[0].(core.StateReader) + ret0, _ := ret[0].(commonstate.StateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 diff --git a/mocks/mock_state_reader.go b/mocks/mock_commonstate_reader.go similarity index 89% rename from mocks/mock_state_reader.go rename to mocks/mock_commonstate_reader.go index 18abcb2f14..5586e499c4 100644 --- a/mocks/mock_state_reader.go +++ b/mocks/mock_commonstate_reader.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/NethermindEth/juno/core/state (interfaces: StateReader) +// Source: github.com/NethermindEth/juno/core/state/commonstate (interfaces: StateReader) // // Generated by this command: // -// mockgen -destination=../../mocks/mock_state_reader.go -package=mocks github.com/NethermindEth/juno/core/state StateReader +// mockgen -destination=mocks/mock_commonstate_reader.go -package=mocks github.com/NethermindEth/juno/core/state/commonstate StateReader // // Package mocks is a generated GoMock package. @@ -14,7 +14,7 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - trie2 "github.com/NethermindEth/juno/core/trie2" + commontrie "github.com/NethermindEth/juno/core/state/commontrie" gomock "go.uber.org/mock/gomock" ) @@ -58,10 +58,10 @@ func (mr *MockStateReaderMockRecorder) Class(classHash any) *gomock.Call { } // ClassTrie mocks base method. -func (m *MockStateReader) ClassTrie() (*trie2.Trie, error) { +func (m *MockStateReader) ClassTrie() (commontrie.Trie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClassTrie") - ret0, _ := ret[0].(*trie2.Trie) + ret0, _ := ret[0].(commontrie.Trie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -118,10 +118,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorage(addr, key any) *gomock.Ca } // ContractStorageTrie mocks base method. -func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (*trie2.Trie, error) { +func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractStorageTrie", addr) - ret0, _ := ret[0].(*trie2.Trie) + ret0, _ := ret[0].(commontrie.Trie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -133,10 +133,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorageTrie(addr any) *gomock.Cal } // ContractTrie mocks base method. -func (m *MockStateReader) ContractTrie() (*trie2.Trie, error) { +func (m *MockStateReader) ContractTrie() (commontrie.Trie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractTrie") - ret0, _ := ret[0].(*trie2.Trie) + ret0, _ := ret[0].(commontrie.Trie) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/mocks/mock_gateway_handler.go b/mocks/mock_gateway.go similarity index 87% rename from mocks/mock_gateway_handler.go rename to mocks/mock_gateway.go index f9bbfc50ea..0ae47df050 100644 --- a/mocks/mock_gateway_handler.go +++ b/mocks/mock_gateway.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/NethermindEth/juno/rpc (interfaces: Gateway) +// Source: github.com/NethermindEth/juno/rpc/rpccore (interfaces: Gateway) // // Generated by this command: // -// mockgen -destination=../mocks/mock_gateway_handler.go -package=mocks github.com/NethermindEth/juno/rpc Gateway +// mockgen -destination=mocks/mock_gateway.go -package=mocks github.com/NethermindEth/juno/rpc/rpccore Gateway // // Package mocks is a generated GoMock package. @@ -21,6 +21,7 @@ import ( type MockGateway struct { ctrl *gomock.Controller recorder *MockGatewayMockRecorder + isgomock struct{} } // MockGatewayMockRecorder is the mock recorder for MockGateway. diff --git a/mocks/mock_synchronizer.go b/mocks/mock_synchronizer.go index 21b139cb56..a0b6d500d6 100644 --- a/mocks/mock_synchronizer.go +++ b/mocks/mock_synchronizer.go @@ -13,6 +13,7 @@ import ( reflect "reflect" core "github.com/NethermindEth/juno/core" + commonstate "github.com/NethermindEth/juno/core/state/commonstate" sync "github.com/NethermindEth/juno/sync" gomock "go.uber.org/mock/gomock" ) @@ -85,10 +86,10 @@ func (mr *MockSyncReaderMockRecorder) PendingData() *gomock.Call { } // PendingState mocks base method. -func (m *MockSyncReader) PendingState() (core.StateReader, func() error, error) { +func (m *MockSyncReader) PendingState() (commonstate.StateReader, func() error, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PendingState") - ret0, _ := ret[0].(core.StateReader) + ret0, _ := ret[0].(commonstate.StateReader) ret1, _ := ret[1].(func() error) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -101,10 +102,10 @@ func (mr *MockSyncReaderMockRecorder) PendingState() *gomock.Call { } // PendingStateBeforeIndex mocks base method. -func (m *MockSyncReader) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (m *MockSyncReader) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PendingStateBeforeIndex", index) - ret0, _ := ret[0].(core.StateReader) + ret0, _ := ret[0].(commonstate.StateReader) ret1, _ := ret[1].(func() error) ret2, _ := ret[2].(error) return ret0, ret1, ret2 diff --git a/mocks/mock_vm.go b/mocks/mock_vm.go index 2c3e1f9bdc..bedaa546cc 100644 --- a/mocks/mock_vm.go +++ b/mocks/mock_vm.go @@ -14,6 +14,7 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" + commonstate "github.com/NethermindEth/juno/core/state/commonstate" utils "github.com/NethermindEth/juno/utils" vm "github.com/NethermindEth/juno/vm" gomock "go.uber.org/mock/gomock" @@ -44,7 +45,7 @@ func (m *MockVM) EXPECT() *MockVMMockRecorder { } // Call mocks base method. -func (m *MockVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, structuredErrStack, returnStateDiff bool) (vm.CallResult, error) { +func (m *MockVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state commonstate.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, structuredErrStack, returnStateDiff bool) (vm.CallResult, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Call", callInfo, blockInfo, state, network, maxSteps, sierraVersion, structuredErrStack, returnStateDiff) ret0, _ := ret[0].(vm.CallResult) @@ -59,7 +60,7 @@ func (mr *MockVMMockRecorder) Call(callInfo, blockInfo, state, network, maxSteps } // Execute mocks base method. -func (m *MockVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool) (vm.ExecutionResults, error) { +func (m *MockVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state commonstate.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool) (vm.ExecutionResults, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Execute", txns, declaredClasses, paidFeesOnL1, blockInfo, state, network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch) ret0, _ := ret[0].(vm.ExecutionResults) diff --git a/node/node.go b/node/node.go index 0b3ea32a1e..2ae9729e4c 100644 --- a/node/node.go +++ b/node/node.go @@ -441,6 +441,13 @@ func (n *Node) Run(ctx context.Context) { } }() + defer func() { + if err := n.blockchain.Stop(); err != nil { + n.log.Errorw("Error while stopping the blockchain", "err", err) + } + n.log.Infow("TrieDB Journal saved") + }() + cfg := make(map[string]any) err := mapstructure.Decode(n.cfg, &cfg) if err != nil { diff --git a/node/throttled_vm.go b/node/throttled_vm.go index 807d376e21..3a08bf4f55 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -3,6 +3,7 @@ package node import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/utils" "github.com/NethermindEth/juno/vm" ) @@ -19,7 +20,7 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott } } -func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state core.StateReader, +func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state commonstate.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, errStack, returnStateDiff bool, ) (vm.CallResult, error) { ret := vm.CallResult{} @@ -31,7 +32,7 @@ func (tvm *ThrottledVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, sta } func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, - blockInfo *vm.BlockInfo, state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, + blockInfo *vm.BlockInfo, state commonstate.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, ) (vm.ExecutionResults, error) { var executionResult vm.ExecutionResults diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index c3256e4f70..9509a73717 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -45,6 +45,7 @@ func TestPlugin(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), timeout) require.NoError(t, synchronizer.Run(ctx)) cancel() + require.NoError(t, bc.Stop()) t.Run("resync to mainnet with the same db", func(t *testing.T) { bc := blockchain.New(testDB, &utils.Mainnet) diff --git a/rpc/v6/class.go b/rpc/v6/class.go index df20bdd5fc..9850ac8774 100644 --- a/rpc/v6/class.go +++ b/rpc/v6/class.go @@ -168,5 +168,5 @@ func (h *Handler) ClassHashAt(id BlockID, address felt.Felt) (*felt.Felt, *jsonr return nil, rpccore.ErrContractNotFound } - return classHash, nil + return &classHash, nil } diff --git a/rpc/v6/class_test.go b/rpc/v6/class_test.go index 45e82af08e..9e4becdd3c 100644 --- a/rpc/v6/class_test.go +++ b/rpc/v6/class_test.go @@ -28,7 +28,7 @@ func TestClass(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(gomock.Any()).DoAndReturn(func(classHash *felt.Felt) (*core.DeclaredClass, error) { class, err := integGw.Class(t.Context(), classHash) @@ -80,7 +80,7 @@ func TestClass(t *testing.T) { t.Run("class hash not found error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) handler := rpc.New(mockReader, nil, nil, n, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(mockState, func() error { @@ -103,7 +103,7 @@ func TestClassAt(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) cairo0ContractAddress, _ := new(felt.Felt).SetRandom() cairo0ClassHash := utils.HexToFelt(t, "0x4631b6b3fa31e140524b7d21ba784cea223e618bffe60b5bbdca44a8b45be04") @@ -181,7 +181,7 @@ func TestClassHashAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v6/contract.go b/rpc/v6/contract.go index 9c78f799d1..b21fcc1aae 100644 --- a/rpc/v6/contract.go +++ b/rpc/v6/contract.go @@ -29,7 +29,7 @@ func (h *Handler) Nonce(id BlockID, address felt.Felt) (*felt.Felt, *jsonrpc.Err return nil, rpccore.ErrContractNotFound } - return nonce, nil + return &nonce, nil } // StorageAt gets the value of the storage at the given address and key for a given block. @@ -61,5 +61,5 @@ func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *js return nil, rpccore.ErrInternal } - return value, nil + return &value, nil } diff --git a/rpc/v6/contract_test.go b/rpc/v6/contract_test.go index a35df6769a..07ae7a2937 100644 --- a/rpc/v6/contract_test.go +++ b/rpc/v6/contract_test.go @@ -50,7 +50,7 @@ func TestNonce(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -136,7 +136,7 @@ func TestStorageAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v6/estimate_fee_test.go b/rpc/v6/estimate_fee_test.go index 1e9adbf708..f7b4393e58 100644 --- a/rpc/v6/estimate_fee_test.go +++ b/rpc/v6/estimate_fee_test.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -46,7 +47,7 @@ func TestEstimateMessageFee(t *testing.T) { Timestamp: 456, L1GasPriceETH: new(felt.Felt).SetUint64(42), } - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(latestHeader, nil) @@ -56,7 +57,7 @@ func TestEstimateMessageFee(t *testing.T) { Header: latestHeader, }, gomock.Any(), &utils.Mainnet, gomock.Any(), false, true, false, true).DoAndReturn( func(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, - state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, + state commonstate.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, ) (vm.ExecutionResults, error) { require.Len(t, txns, 1) assert.NotNil(t, txns[0].(*core.L1HandlerTransaction)) diff --git a/rpc/v6/handlers_test.go b/rpc/v6/handlers_test.go index 1f5f93d9d1..dc964a7e66 100644 --- a/rpc/v6/handlers_test.go +++ b/rpc/v6/handlers_test.go @@ -34,7 +34,7 @@ func TestThrottledVMError(t *testing.T) { throttledVM := node.NewThrottledVM(mockVM, 0, 0) handler := rpc.New(mockReader, mockSyncReader, throttledVM, &utils.Mainnet, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) throttledErr := "VM throughput limit reached" t.Run("call", func(t *testing.T) { @@ -85,9 +85,9 @@ func TestThrottledVMError(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index 9d5b2e8a45..490562bcbd 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -21,7 +22,7 @@ func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } // nil is returned if l1 head doesn't exist - return l1Head, nil + return &l1Head, nil } func isL1Verified(n uint64, l1 *core.L1Head) bool { @@ -149,8 +150,8 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(id *BlockID) (core.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader core.StateReader +func (h *Handler) stateByBlockID(id *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader commonstate.StateReader var closer blockchain.StateCloser var err error switch { diff --git a/rpc/v6/pending_data_wrapper.go b/rpc/v6/pending_data_wrapper.go index 8b8bbcab1d..5bddea4900 100644 --- a/rpc/v6/pending_data_wrapper.go +++ b/rpc/v6/pending_data_wrapper.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/sync" ) @@ -36,7 +37,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (core.StateReader, func() error, error) { +func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { pending, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, sync.ErrPendingBlockNotFound) { diff --git a/rpc/v6/pending_data_wrapper_test.go b/rpc/v6/pending_data_wrapper_test.go index edf9778038..37db5577b0 100644 --- a/rpc/v6/pending_data_wrapper_test.go +++ b/rpc/v6/pending_data_wrapper_test.go @@ -107,7 +107,7 @@ func TestPendingDataWrapper_PendingState(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) handler := rpc.New(mockReader, mockSyncReader, nil, &utils.Sepolia, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("Returns pending state when starknet version < 0.14.0", func(t *testing.T) { mockSyncReader.EXPECT().PendingData().Return( &sync.Pending{}, diff --git a/rpc/v6/simulation_test.go b/rpc/v6/simulation_test.go index 2fc07d4e21..d702fafb9d 100644 --- a/rpc/v6/simulation_test.go +++ b/rpc/v6/simulation_test.go @@ -29,7 +29,7 @@ func TestSimulateTransactions(t *testing.T) { mockVM := mocks.NewMockVM(mockCtrl) handler := rpc.New(mockReader, nil, mockVM, n, utils.NewNopZapLogger()) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).AnyTimes() headsHeader := &core.Header{ SequencerAddress: n.BlockHashMetaInfo.FallBackSequencerAddress, @@ -214,7 +214,7 @@ func TestSimulateTransactionsShouldErrorWithoutSenderAddressOrResourceBounds(t * mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v6/trace.go b/rpc/v6/trace.go index f6f4c67ed9..b0b10cf27e 100644 --- a/rpc/v6/trace.go +++ b/rpc/v6/trace.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -158,7 +159,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState core.StateReader + headState commonstate.StateReader headStateCloser blockchain.StateCloser ) if isPending { @@ -270,7 +271,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json return nil, rpccore.ErrContractNotFound } - declaredClass, err := state.Class(classHash) + declaredClass, err := state.Class(&classHash) if err != nil { return nil, rpccore.ErrClassHashNotFound } @@ -286,7 +287,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json ContractAddress: &funcCall.ContractAddress, Selector: &funcCall.EntryPointSelector, Calldata: funcCall.Calldata, - ClassHash: classHash, + ClassHash: &classHash, }, &vm.BlockInfo{ Header: header, BlockHashToBeRevealed: blockHashToBeRevealed, diff --git a/rpc/v6/trace_test.go b/rpc/v6/trace_test.go index 7d7f4f9d19..89b80b0204 100644 --- a/rpc/v6/trace_test.go +++ b/rpc/v6/trace_test.go @@ -275,7 +275,7 @@ func TestTraceTransaction(t *testing.T) { mockReader.EXPECT().BlockByHash(header.Hash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -340,7 +340,7 @@ func TestTraceTransaction(t *testing.T) { ).Times(2) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) @@ -554,9 +554,9 @@ func TestTraceBlockTransactions(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) @@ -629,7 +629,7 @@ func TestTraceBlockTransactions(t *testing.T) { mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -1116,7 +1116,7 @@ func TestCall(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v6/transaction_test.go b/rpc/v6/transaction_test.go index 467e772ca1..2dc3aff8c8 100644 --- a/rpc/v6/transaction_test.go +++ b/rpc/v6/transaction_test.go @@ -1498,7 +1498,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) handler := rpc.New(mockReader, nil, nil, test.network, nil) diff --git a/rpc/v7/compiled_casm_test.go b/rpc/v7/compiled_casm_test.go index a9de5862ca..c5dfb5ae60 100644 --- a/rpc/v7/compiled_casm_test.go +++ b/rpc/v7/compiled_casm_test.go @@ -36,7 +36,7 @@ func TestCompiledCasm(t *testing.T) { t.Run("class doesn't exist", func(t *testing.T) { classHash := utils.HexToFelt(t, "0x111") - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(nil, db.ErrKeyNotFound) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -65,7 +65,7 @@ func TestCompiledCasm(t *testing.T) { err = json.Unmarshal(program, &cairo0Definition) require.NoError(t, err) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v7/estimate_fee_test.go b/rpc/v7/estimate_fee_test.go index fee4a59313..4fe250e55f 100644 --- a/rpc/v7/estimate_fee_test.go +++ b/rpc/v7/estimate_fee_test.go @@ -30,7 +30,7 @@ func TestEstimateFee(t *testing.T) { log := utils.NewNopZapLogger() handler := rpcv7.New(mockReader, nil, mockVM, n, log) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil).AnyTimes() diff --git a/rpc/v7/handlers_test.go b/rpc/v7/handlers_test.go index 5d8390bcab..67d309e7e3 100644 --- a/rpc/v7/handlers_test.go +++ b/rpc/v7/handlers_test.go @@ -35,7 +35,7 @@ func TestThrottledVMError(t *testing.T) { throttledVM := node.NewThrottledVM(mockVM, 0, 0) handler := rpcv7.New(mockReader, mockSyncReader, throttledVM, &utils.Mainnet, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) throttledErr := "VM throughput limit reached" t.Run("call", func(t *testing.T) { @@ -87,9 +87,9 @@ func TestThrottledVMError(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) diff --git a/rpc/v7/helpers.go b/rpc/v7/helpers.go index 60859bc8b4..b6a21a8b01 100644 --- a/rpc/v7/helpers.go +++ b/rpc/v7/helpers.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -21,7 +22,7 @@ func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } // nil is returned if l1 head doesn't exist - return l1Head, nil + return &l1Head, nil } func isL1Verified(n uint64, l1 *core.L1Head) bool { @@ -156,8 +157,8 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(id *BlockID) (core.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader core.StateReader +func (h *Handler) stateByBlockID(id *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader commonstate.StateReader var closer blockchain.StateCloser var err error switch { diff --git a/rpc/v7/pending_data_wrapper.go b/rpc/v7/pending_data_wrapper.go index e0f306379a..01e6f0dce9 100644 --- a/rpc/v7/pending_data_wrapper.go +++ b/rpc/v7/pending_data_wrapper.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/sync" ) @@ -36,7 +37,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (core.StateReader, func() error, error) { +func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { pending, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, sync.ErrPendingBlockNotFound) { diff --git a/rpc/v7/pending_data_wrapper_test.go b/rpc/v7/pending_data_wrapper_test.go index 58f807652a..1937678ca4 100644 --- a/rpc/v7/pending_data_wrapper_test.go +++ b/rpc/v7/pending_data_wrapper_test.go @@ -108,7 +108,7 @@ func TestPendingDataWrapper_PendingState(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) handler := rpc.New(mockReader, mockSyncReader, nil, &utils.Sepolia, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("Returns pending state when starknet version < 0.14.0", func(t *testing.T) { mockSyncReader.EXPECT().PendingData().Return( &sync.Pending{}, diff --git a/rpc/v7/simulation_test.go b/rpc/v7/simulation_test.go index 329af0e5fa..e1afb3cada 100644 --- a/rpc/v7/simulation_test.go +++ b/rpc/v7/simulation_test.go @@ -31,7 +31,7 @@ func TestSimulateTransactions(t *testing.T) { mockVM := mocks.NewMockVM(mockCtrl) handler := rpcv7.New(mockReader, nil, mockVM, n, utils.NewNopZapLogger()) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).AnyTimes() headsHeader := &core.Header{ SequencerAddress: n.BlockHashMetaInfo.FallBackSequencerAddress, @@ -208,7 +208,7 @@ func TestSimulateTransactionsShouldErrorWithoutSenderAddressOrResourceBounds(t * mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v7/storage.go b/rpc/v7/storage.go index af51f8f8a5..2797e8d1d1 100644 --- a/rpc/v7/storage.go +++ b/rpc/v7/storage.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -28,7 +29,8 @@ func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *js // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(&address) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } h.log.Errorw("Failed to get contract nonce", "err", err) @@ -40,5 +42,5 @@ func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *js return nil, rpccore.ErrInternal } - return value, nil + return &value, nil } diff --git a/rpc/v7/storage_test.go b/rpc/v7/storage_test.go index 0ac88905a1..9a04a251d3 100644 --- a/rpc/v7/storage_test.go +++ b/rpc/v7/storage_test.go @@ -49,7 +49,7 @@ func TestStorageAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v7/trace.go b/rpc/v7/trace.go index 2db79f6a73..d504e93956 100644 --- a/rpc/v7/trace.go +++ b/rpc/v7/trace.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -196,7 +197,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block) defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState core.StateReader + headState commonstate.StateReader headStateCloser blockchain.StateCloser ) if isPending { @@ -323,7 +324,7 @@ func (h *Handler) Call(funcCall FunctionCall, id BlockID) ([]*felt.Felt, *jsonrp return nil, rpccore.ErrContractNotFound } - declaredClass, err := state.Class(classHash) + declaredClass, err := state.Class(&classHash) if err != nil { return nil, rpccore.ErrClassHashNotFound } @@ -339,7 +340,7 @@ func (h *Handler) Call(funcCall FunctionCall, id BlockID) ([]*felt.Felt, *jsonrp ContractAddress: &funcCall.ContractAddress, Selector: &funcCall.EntryPointSelector, Calldata: funcCall.Calldata, - ClassHash: classHash, + ClassHash: &classHash, }, &vm.BlockInfo{ Header: header, BlockHashToBeRevealed: blockHashToBeRevealed, diff --git a/rpc/v7/trace_test.go b/rpc/v7/trace_test.go index f0290c356b..388bb4a50b 100644 --- a/rpc/v7/trace_test.go +++ b/rpc/v7/trace_test.go @@ -311,7 +311,7 @@ func TestTraceTransaction(t *testing.T) { mockReader.EXPECT().BlockByHash(header.Hash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -411,7 +411,7 @@ func TestTraceTransaction(t *testing.T) { nil, ).Times(2) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) @@ -680,9 +680,9 @@ func TestTraceBlockTransactions(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) @@ -786,7 +786,7 @@ func TestTraceBlockTransactions(t *testing.T) { mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -1502,7 +1502,7 @@ func TestCall(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v8/compiled_casm_test.go b/rpc/v8/compiled_casm_test.go index 11560b0dcc..fdfab5199d 100644 --- a/rpc/v8/compiled_casm_test.go +++ b/rpc/v8/compiled_casm_test.go @@ -37,7 +37,7 @@ func TestCompiledCasm(t *testing.T) { t.Run("class doesn't exist", func(t *testing.T) { classHash := utils.HexToFelt(t, "0x111") - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(nil, db.ErrKeyNotFound) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -66,7 +66,7 @@ func TestCompiledCasm(t *testing.T) { err = json.Unmarshal(program, &cairo0Definition) require.NoError(t, err) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -108,7 +108,7 @@ func TestCompiledCasm(t *testing.T) { Compiled: compiledClass, } - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: cairo1Class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v8/estimate_fee_test.go b/rpc/v8/estimate_fee_test.go index f7308c64b8..e48fee7952 100644 --- a/rpc/v8/estimate_fee_test.go +++ b/rpc/v8/estimate_fee_test.go @@ -30,7 +30,7 @@ func TestEstimateFee(t *testing.T) { log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, mockVM, log) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil).AnyTimes() diff --git a/rpc/v8/handlers_test.go b/rpc/v8/handlers_test.go index a3d703a239..1426aff2cc 100644 --- a/rpc/v8/handlers_test.go +++ b/rpc/v8/handlers_test.go @@ -35,7 +35,7 @@ func TestThrottledVMError(t *testing.T) { throttledVM := node.NewThrottledVM(mockVM, 0, 0) handler := rpcv8.New(mockReader, mockSyncReader, throttledVM, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) throttledErr := "VM throughput limit reached" t.Run("call", func(t *testing.T) { @@ -91,9 +91,9 @@ func TestThrottledVMError(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) diff --git a/rpc/v8/helpers.go b/rpc/v8/helpers.go index a68009cbbf..7c7f0f723d 100644 --- a/rpc/v8/helpers.go +++ b/rpc/v8/helpers.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -21,7 +22,7 @@ func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } // nil is returned if l1 head doesn't exist - return l1Head, nil + return &l1Head, nil } func isL1Verified(n uint64, l1 *core.L1Head) bool { @@ -137,8 +138,8 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(blockID *BlockID) (core.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader core.StateReader +func (h *Handler) stateByBlockID(blockID *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader commonstate.StateReader var closer blockchain.StateCloser var err error switch blockID.Type() { diff --git a/rpc/v8/l1.go b/rpc/v8/l1.go index f217ed3d03..6cecdc0aab 100644 --- a/rpc/v8/l1.go +++ b/rpc/v8/l1.go @@ -67,12 +67,12 @@ func (h *Handler) GetMessageStatus(ctx context.Context, l1TxnHash *common.Hash) if err != nil { return nil, jsonrpc.Err(jsonrpc.InternalError, fmt.Errorf("failed to retrieve L1 handler txn %v", err)) } - status, rpcErr := h.TransactionStatus(ctx, *hash) + status, rpcErr := h.TransactionStatus(ctx, hash) if rpcErr != nil { return nil, rpcErr } results[i] = MsgStatus{ - L1HandlerHash: hash, + L1HandlerHash: &hash, FinalityStatus: status.Finality, FailureReason: status.FailureReason, } diff --git a/rpc/v8/pending_data_wrapper.go b/rpc/v8/pending_data_wrapper.go index e616dac023..18eb0a444c 100644 --- a/rpc/v8/pending_data_wrapper.go +++ b/rpc/v8/pending_data_wrapper.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/sync" ) @@ -36,7 +37,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (core.StateReader, func() error, error) { +func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { pending, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, sync.ErrPendingBlockNotFound) { diff --git a/rpc/v8/pending_data_wrapper_test.go b/rpc/v8/pending_data_wrapper_test.go index 7c3c9f23e8..b685c7b179 100644 --- a/rpc/v8/pending_data_wrapper_test.go +++ b/rpc/v8/pending_data_wrapper_test.go @@ -108,7 +108,7 @@ func TestPendingDataWrapper_PendingState(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) handler := rpc.New(mockReader, mockSyncReader, nil, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("Returns pending state when starknet version < 0.14.0", func(t *testing.T) { mockSyncReader.EXPECT().PendingData().Return( &sync.Pending{}, diff --git a/rpc/v8/simulation_test.go b/rpc/v8/simulation_test.go index 0774b42f2e..6ff9ebbbf5 100644 --- a/rpc/v8/simulation_test.go +++ b/rpc/v8/simulation_test.go @@ -36,7 +36,7 @@ func TestSimulateTransactions(t *testing.T) { PriceInFri: &felt.Zero, }, } - defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateReader) { mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) @@ -45,14 +45,14 @@ func TestSimulateTransactions(t *testing.T) { name string stepsUsed uint64 err *jsonrpc.Error - mockBehavior func(*mocks.MockReader, *mocks.MockVM, *mocks.MockStateHistoryReader) + mockBehavior func(*mocks.MockReader, *mocks.MockVM, *mocks.MockStateReader) simulationFlags []rpcv6.SimulationFlag simulatedTxs []rpc.SimulatedTransaction }{ { //nolint:dupl name: "ok with zero values, skip fee", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -71,7 +71,7 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip validate", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -89,7 +89,7 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "transaction execution error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -107,7 +107,7 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "inconsistent lengths error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -135,7 +135,7 @@ func TestSimulateTransactions(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) test.mockBehavior(mockReader, mockVM, mockState) handler := rpc.New(mockReader, nil, mockVM, utils.NewNopZapLogger()) @@ -259,7 +259,7 @@ func TestSimulateTransactionsShouldErrorWithoutSenderAddressOrResourceBounds(t * mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 8167da6626..ee7310d3e5 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -4,9 +4,13 @@ import ( "errors" "fmt" - "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -37,7 +41,8 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(address) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } h.log.Errorw("Failed to get contract nonce", "err", err) @@ -49,7 +54,7 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * return nil, rpccore.ErrInternal } - return value, nil + return &value, nil } type StorageProofResult struct { @@ -67,7 +72,7 @@ func (h *Handler) StorageProof( return nil, rpccore.ErrInternal.CloneWithData(err) } - chainHeight, err := state.ChainHeight() + chainHeight, err := h.bcReader.Height() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } @@ -122,12 +127,12 @@ func (h *Handler) StorageProof( return nil, rpccore.ErrInternal.CloneWithData(err) } - contractTreeRoot, err := contractTrie.Root() + contractTreeRoot, err := contractTrie.Hash() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } - classTreeRoot, err := classTrie.Root() + classTreeRoot, err := classTrie.Hash() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } @@ -137,8 +142,8 @@ func (h *Handler) StorageProof( ContractsProof: contractProof, ContractsStorageProofs: contractStorageProof, GlobalRoots: &GlobalRoots{ - ContractsTreeRoot: contractTreeRoot, - ClassesTreeRoot: classTreeRoot, + ContractsTreeRoot: &contractTreeRoot, + ClassesTreeRoot: &classTreeRoot, BlockHash: head.Hash, }, }, nil @@ -206,20 +211,44 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp return nil } -func getClassProof(tr *trie.Trie, classes []felt.Felt) ([]*HashToNode, error) { - classProof := trie.NewProofNodeSet() - for _, class := range classes { - if err := tr.Prove(&class, classProof); err != nil { - return nil, err +func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { + switch t := tr.(type) { + case *commontrie.TrieAdapter: + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := t.Trie.Prove(&class, classProof); err != nil { + return nil, err + } } + return adaptTrie1ProofNodes(classProof), nil + case *commontrie.Trie2Adapter: + classProof := trie2.NewProofNodeSet() + for _, class := range classes { + if err := t.Trie.Prove(&class, classProof); err != nil { + return nil, err + } + } + return adaptTrie2ProofNodes(classProof), nil + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) } +} - return adaptProofNodes(classProof), nil +func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { + switch t := tr.(type) { + case *commontrie.TrieAdapter: + return getContractProofWithTrie(t.Trie, state, contracts) + case *commontrie.Trie2Adapter: + return getContractProofWithTrie2(t.Trie, state, contracts) + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) + } } -func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) + for i, contract := range contracts { if err := tr.Prove(&contract, contractProof); err != nil { return nil, err @@ -244,19 +273,56 @@ func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Fe } contractLeavesData[i] = &LeafData{ - Nonce: nonce, - ClassHash: classHash, + Nonce: &nonce, + ClassHash: &classHash, StorageRoot: root, } } return &ContractProof{ - Nodes: adaptProofNodes(contractProof), + Nodes: adaptTrie1ProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { + contractProof := trie2.NewProofNodeSet() + contractLeavesData := make([]*LeafData, len(contracts)) + + for i, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + + root := tr.Hash() + + nonce, err := state.ContractNonce(&contract) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + continue + } + return nil, err + } + + classHash, err := state.ContractClassHash(&contract) + if err != nil { + return nil, err + } + + contractLeavesData[i] = &LeafData{ + Nonce: &nonce, + ClassHash: &classHash, + StorageRoot: &root, + } + } + + return &ContractProof{ + Nodes: adaptTrie2ProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } -func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { +func getContractStorageProof(state commonstate.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) for i, storageKey := range storageKeys { contractStorageTrie, err := state.ContractStorageTrie(storageKey.Contract) @@ -264,20 +330,32 @@ func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) return nil, err } - contractStorageProof := trie.NewProofNodeSet() - for _, key := range storageKey.Keys { - if err := contractStorageTrie.Prove(&key, contractStorageProof); err != nil { - return nil, err + switch t := contractStorageTrie.(type) { + case *commontrie.TrieAdapter: + contractStorageProof := trie.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + contractStorageRes[i] = adaptTrie1ProofNodes(contractStorageProof) + case *commontrie.Trie2Adapter: + contractStorageProof := trie2.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + return nil, err + } } + contractStorageRes[i] = adaptTrie2ProofNodes(contractStorageProof) + default: + return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } - - contractStorageRes[i] = adaptProofNodes(contractStorageProof) } return contractStorageRes, nil } -func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -307,6 +385,36 @@ func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } +func adaptTrie2ProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trienode.BinaryNode: + node = &BinaryNode{ + Left: &hash, + Right: &hash, + } + case *trienode.EdgeNode: + path := n.Path.Felt() + node = &EdgeNode{ + Path: path.String(), + Length: int(n.Path.Len()), + Child: &hash, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index d92041db16..2f4373f64a 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -62,7 +62,7 @@ func TestStorageAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -172,11 +172,11 @@ func TestStorageProof(t *testing.T) { headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, func() error { return nil }, nil).AnyTimes() + mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() - mockState.EXPECT().ChainHeight().Return(blockNumber, nil).AnyTimes() mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() @@ -642,16 +642,16 @@ func TestStorageProof_StorageRoots(t *testing.T) { contractTrie, err := reader.ContractTrie() assert.NoError(t, err) - clsRoot, err := classTrie.Root() + clsRoot, err := classTrie.Hash() assert.NoError(t, err) - stgRoot, err := contractTrie.Root() + stgRoot, err := contractTrie.Hash() assert.NoError(t, err) assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) - verifyGlobalStateRoot(t, expectedGlobalRoot, clsRoot, stgRoot) + verifyGlobalStateRoot(t, expectedGlobalRoot, &clsRoot, &stgRoot) }) t.Run("check requested contract and storage slot exists", func(t *testing.T) { diff --git a/rpc/v8/subscriptions_test.go b/rpc/v8/subscriptions_test.go index 55eaf22015..ea6dc91ce5 100644 --- a/rpc/v8/subscriptions_test.go +++ b/rpc/v8/subscriptions_test.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -398,9 +399,12 @@ func (fs *fakeSyncer) HighestBlockHeader() *core.Header { func (fs *fakeSyncer) PendingData() (core.PendingData, error) { return nil, sync.ErrPendingBlockNotFound } -func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { return nil, nil, nil } -func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } +func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, error) { + return nil, nil, nil +} + +func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { return nil, nil, nil } diff --git a/rpc/v8/trace.go b/rpc/v8/trace.go index 750e0c8ca5..bdc4e64294 100644 --- a/rpc/v8/trace.go +++ b/rpc/v8/trace.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -190,7 +191,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block) defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState core.StateReader + headState commonstate.StateReader headStateCloser blockchain.StateCloser ) if isPending { @@ -315,7 +316,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json return nil, rpccore.ErrContractNotFound } - declaredClass, err := state.Class(classHash) + declaredClass, err := state.Class(&classHash) if err != nil { return nil, rpccore.ErrClassHashNotFound } @@ -331,7 +332,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json ContractAddress: &funcCall.ContractAddress, Selector: &funcCall.EntryPointSelector, Calldata: funcCall.Calldata, - ClassHash: classHash, + ClassHash: &classHash, }, &vm.BlockInfo{ Header: header, BlockHashToBeRevealed: blockHashToBeRevealed, diff --git a/rpc/v8/trace_test.go b/rpc/v8/trace_test.go index 9d91d2cebb..6fa1c8e7ee 100644 --- a/rpc/v8/trace_test.go +++ b/rpc/v8/trace_test.go @@ -311,7 +311,7 @@ func TestTraceTransaction(t *testing.T) { mockReader.EXPECT().BlockByHash(header.Hash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -397,7 +397,7 @@ func TestTraceTransaction(t *testing.T) { ).Times(2) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) @@ -644,9 +644,9 @@ func TestTraceBlockTransactions(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) @@ -767,7 +767,7 @@ func TestTraceBlockTransactions(t *testing.T) { mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -1269,7 +1269,7 @@ func TestCall(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v9/class.go b/rpc/v9/class.go index efa4aa0d61..3a57782597 100644 --- a/rpc/v9/class.go +++ b/rpc/v9/class.go @@ -134,7 +134,7 @@ func (h *Handler) ClassHashAt(id *BlockID, address *felt.Felt) (*felt.Felt, *jso return nil, rpccore.ErrContractNotFound } - return classHash, nil + return &classHash, nil } func adaptCairo0EntryPoints(entryPoints []core.EntryPoint) []rpcv6.EntryPoint { diff --git a/rpc/v9/class_test.go b/rpc/v9/class_test.go index 99255e3b73..fda48724cd 100644 --- a/rpc/v9/class_test.go +++ b/rpc/v9/class_test.go @@ -29,7 +29,7 @@ func TestClass(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(gomock.Any()).DoAndReturn(func(classHash *felt.Felt) (*core.DeclaredClass, error) { class, err := integGw.Class(t.Context(), classHash) @@ -81,7 +81,7 @@ func TestClass(t *testing.T) { t.Run("class hash not found error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(mockState, func() error { @@ -104,7 +104,7 @@ func TestClassAt(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) cairo0ContractAddress, _ := new(felt.Felt).SetRandom() cairo0ClassHash := utils.HexToFelt(t, "0x4631b6b3fa31e140524b7d21ba784cea223e618bffe60b5bbdca44a8b45be04") @@ -181,7 +181,7 @@ func TestClassHashAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v9/compiled_casm_test.go b/rpc/v9/compiled_casm_test.go index 1b0ac44249..a265b70400 100644 --- a/rpc/v9/compiled_casm_test.go +++ b/rpc/v9/compiled_casm_test.go @@ -37,7 +37,7 @@ func TestCompiledCasm(t *testing.T) { t.Run("class doesn't exist", func(t *testing.T) { classHash := utils.HexToFelt(t, "0x111") - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(nil, db.ErrKeyNotFound) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -66,7 +66,7 @@ func TestCompiledCasm(t *testing.T) { err = json.Unmarshal(program, &cairo0Definition) require.NoError(t, err) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -108,7 +108,7 @@ func TestCompiledCasm(t *testing.T) { Compiled: compiledClass, } - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: cairo1Class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v9/estimate_fee_test.go b/rpc/v9/estimate_fee_test.go index cc057c444c..0a4b6016f1 100644 --- a/rpc/v9/estimate_fee_test.go +++ b/rpc/v9/estimate_fee_test.go @@ -30,7 +30,7 @@ func TestEstimateFee(t *testing.T) { log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, mockVM, log) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil).AnyTimes() diff --git a/rpc/v9/handlers_test.go b/rpc/v9/handlers_test.go index 18c53c2acc..507b6d7ba1 100644 --- a/rpc/v9/handlers_test.go +++ b/rpc/v9/handlers_test.go @@ -34,7 +34,7 @@ func TestThrottledVMError(t *testing.T) { throttledVM := node.NewThrottledVM(mockVM, 0, 0) handler := rpcv9.New(mockReader, mockSyncReader, throttledVM, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) throttledErr := "VM throughput limit reached" t.Run("call", func(t *testing.T) { @@ -90,9 +90,9 @@ func TestThrottledVMError(t *testing.T) { } mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) - state := mocks.NewMockStateHistoryReader(mockCtrl) + state := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(state, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) diff --git a/rpc/v9/helpers.go b/rpc/v9/helpers.go index fc650f0b22..bb152427e8 100644 --- a/rpc/v9/helpers.go +++ b/rpc/v9/helpers.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -21,7 +22,7 @@ func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } // nil is returned if l1 head doesn't exist - return l1Head, nil + return &l1Head, nil } func isL1Verified(n uint64, l1 *core.L1Head) bool { @@ -47,7 +48,7 @@ func (h *Handler) blockByID(blockID *BlockID) (*core.Block, *jsonrpc.Error) { case hash: block, err = h.bcReader.BlockByHash(blockID.Hash()) case l1Accepted: - var l1Head *core.L1Head + var l1Head core.L1Head l1Head, err = h.bcReader.L1Head() if err != nil { break @@ -86,7 +87,7 @@ func (h *Handler) blockHeaderByID(blockID *BlockID) (*core.Header, *jsonrpc.Erro case number: header, err = h.bcReader.BlockHeaderByNumber(blockID.Number()) case l1Accepted: - var l1Head *core.L1Head + var l1Head core.L1Head l1Head, err = h.bcReader.L1Head() if err != nil { break @@ -151,8 +152,8 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(blockID *BlockID) (core.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader core.StateReader +func (h *Handler) stateByBlockID(blockID *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader commonstate.StateReader var closer blockchain.StateCloser var err error switch blockID.Type() { @@ -165,7 +166,7 @@ func (h *Handler) stateByBlockID(blockID *BlockID) (core.StateReader, blockchain case number: reader, closer, err = h.bcReader.StateAtBlockNumber(blockID.Number()) case l1Accepted: - var l1Head *core.L1Head + var l1Head core.L1Head l1Head, err = h.bcReader.L1Head() if err != nil { break diff --git a/rpc/v9/l1.go b/rpc/v9/l1.go index a3d9298a02..a128793ca0 100644 --- a/rpc/v9/l1.go +++ b/rpc/v9/l1.go @@ -68,7 +68,7 @@ func (h *Handler) GetMessageStatus(ctx context.Context, l1TxnHash *common.Hash) if err != nil { return nil, jsonrpc.Err(jsonrpc.InternalError, fmt.Errorf("failed to retrieve L1 handler txn %v", err)) } - status, rpcErr := h.TransactionStatus(ctx, *hash) + status, rpcErr := h.TransactionStatus(ctx, hash) if rpcErr != nil { return nil, rpcErr } @@ -80,7 +80,7 @@ func (h *Handler) GetMessageStatus(ctx context.Context, l1TxnHash *common.Hash) } results[i] = MsgStatus{ - L1HandlerHash: hash, + L1HandlerHash: &hash, FinalityStatus: status.Finality, FailureReason: status.FailureReason, ExecutionStatus: status.Execution, diff --git a/rpc/v9/nonce.go b/rpc/v9/nonce.go index 0d8fd55570..5bfa8954ed 100644 --- a/rpc/v9/nonce.go +++ b/rpc/v9/nonce.go @@ -26,5 +26,5 @@ func (h *Handler) Nonce(id *BlockID, address *felt.Felt) (*felt.Felt, *jsonrpc.E return nil, rpccore.ErrContractNotFound } - return nonce, nil + return &nonce, nil } diff --git a/rpc/v9/nonce_test.go b/rpc/v9/nonce_test.go index 8a188d180c..d1ad43dfb0 100644 --- a/rpc/v9/nonce_test.go +++ b/rpc/v9/nonce_test.go @@ -51,7 +51,7 @@ func TestNonce(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v9/pending_data_wrapper.go b/rpc/v9/pending_data_wrapper.go index ab6b538aa2..4a5bb134c0 100644 --- a/rpc/v9/pending_data_wrapper.go +++ b/rpc/v9/pending_data_wrapper.go @@ -6,6 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/sync" ) @@ -46,7 +47,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (core.StateReader, func() error, error) { +func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { state, closer, err := h.syncReader.PendingState() if err != nil { if errors.Is(err, sync.ErrPendingBlockNotFound) { diff --git a/rpc/v9/pending_data_wrapper_test.go b/rpc/v9/pending_data_wrapper_test.go index 2cfae14f99..271fe98812 100644 --- a/rpc/v9/pending_data_wrapper_test.go +++ b/rpc/v9/pending_data_wrapper_test.go @@ -95,7 +95,7 @@ func TestPendingDataWrapper_PendingState(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) handler := rpc.New(mockReader, mockSyncReader, nil, nil) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("Returns pending state", func(t *testing.T) { mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) pendingState, closer, err := handler.PendingState() diff --git a/rpc/v9/simulation_test.go b/rpc/v9/simulation_test.go index b7161ff1b4..9ba9713892 100644 --- a/rpc/v9/simulation_test.go +++ b/rpc/v9/simulation_test.go @@ -36,7 +36,7 @@ func TestSimulateTransactions(t *testing.T) { PriceInFri: &felt.Zero, }, } - defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateReader) { mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) @@ -45,14 +45,14 @@ func TestSimulateTransactions(t *testing.T) { name string stepsUsed uint64 err *jsonrpc.Error - mockBehavior func(*mocks.MockReader, *mocks.MockVM, *mocks.MockStateHistoryReader) + mockBehavior func(*mocks.MockReader, *mocks.MockVM, *mocks.MockStateReader) simulationFlags []rpcv6.SimulationFlag simulatedTxs []rpc.SimulatedTransaction }{ { //nolint:dupl name: "ok with zero values, skip fee", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -71,7 +71,7 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip validate", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -89,7 +89,7 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "transaction execution error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -107,7 +107,7 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "inconsistent lengths error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateHistoryReader) { + mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -135,7 +135,7 @@ func TestSimulateTransactions(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) test.mockBehavior(mockReader, mockVM, mockState) handler := rpc.New(mockReader, nil, mockVM, utils.NewNopZapLogger()) @@ -259,7 +259,7 @@ func TestSimulateTransactionsShouldErrorWithoutSenderAddressOrResourceBounds(t * mockReader := mocks.NewMockReader(mockCtrl) mockVM := mocks.NewMockVM(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index eb9a23e507..76a8a4d339 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -4,9 +4,12 @@ import ( "errors" "fmt" - "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -49,7 +52,7 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * return nil, rpccore.ErrInternal } - return value, nil + return &value, nil } type StorageProofResult struct { @@ -67,7 +70,7 @@ func (h *Handler) StorageProof( return nil, rpccore.ErrInternal.CloneWithData(err) } - chainHeight, err := state.ChainHeight() + chainHeight, err := h.bcReader.Height() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } @@ -122,12 +125,11 @@ func (h *Handler) StorageProof( return nil, rpccore.ErrInternal.CloneWithData(err) } - contractTreeRoot, err := contractTrie.Root() + contractTreeRoot, err := contractTrie.Hash() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } - - classTreeRoot, err := classTrie.Root() + classTreeRoot, err := classTrie.Hash() if err != nil { return nil, rpccore.ErrInternal.CloneWithData(err) } @@ -137,8 +139,8 @@ func (h *Handler) StorageProof( ContractsProof: contractProof, ContractsStorageProofs: contractStorageProof, GlobalRoots: &GlobalRoots{ - ContractsTreeRoot: contractTreeRoot, - ClassesTreeRoot: classTreeRoot, + ContractsTreeRoot: &contractTreeRoot, + ClassesTreeRoot: &classTreeRoot, BlockHash: head.Hash, }, }, nil @@ -207,20 +209,44 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp return nil } -func getClassProof(tr *trie.Trie, classes []felt.Felt) ([]*HashToNode, error) { - classProof := trie.NewProofNodeSet() - for _, class := range classes { - if err := tr.Prove(&class, classProof); err != nil { - return nil, err +func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { + switch t := tr.(type) { + case *commontrie.TrieAdapter: + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := t.Trie.Prove(&class, classProof); err != nil { + return nil, err + } } + return adaptTrie1ProofNodes(classProof), nil + case *commontrie.Trie2Adapter: + classProof := trie2.NewProofNodeSet() + for _, class := range classes { + if err := t.Trie.Prove(&class, classProof); err != nil { + return nil, err + } + } + return adaptTrie2ProofNodes(classProof), nil + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) } +} - return adaptProofNodes(classProof), nil +func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { + switch t := tr.(type) { + case *commontrie.TrieAdapter: + return getContractProofWithTrie1(t.Trie, state, contracts) + case *commontrie.Trie2Adapter: + return getContractProofWithTrie2(t.Trie, state, contracts) + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) + } } -func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie1(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) + for i, contract := range contracts { if err := tr.Prove(&contract, contractProof); err != nil { return nil, err @@ -245,19 +271,56 @@ func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Fe } contractLeavesData[i] = &LeafData{ - Nonce: nonce, - ClassHash: classHash, + Nonce: &nonce, + ClassHash: &classHash, StorageRoot: root, } } return &ContractProof{ - Nodes: adaptProofNodes(contractProof), + Nodes: adaptTrie1ProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } -func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { +func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { + contractProof := trie2.NewProofNodeSet() + contractLeavesData := make([]*LeafData, len(contracts)) + + for i, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + + root := tr.Hash() + + nonce, err := state.ContractNonce(&contract) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + continue + } + return nil, err + } + + classHash, err := state.ContractClassHash(&contract) + if err != nil { + return nil, err + } + + contractLeavesData[i] = &LeafData{ + Nonce: &nonce, + ClassHash: &classHash, + StorageRoot: &root, + } + } + + return &ContractProof{ + Nodes: adaptTrie2ProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractStorageProof(state commonstate.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) for i, storageKey := range storageKeys { contractStorageTrie, err := state.ContractStorageTrie(storageKey.Contract) @@ -265,20 +328,32 @@ func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) return nil, err } - contractStorageProof := trie.NewProofNodeSet() - for _, key := range storageKey.Keys { - if err := contractStorageTrie.Prove(&key, contractStorageProof); err != nil { - return nil, err + switch t := contractStorageTrie.(type) { + case *commontrie.TrieAdapter: + contractStorageProof := trie.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + contractStorageRes[i] = adaptTrie1ProofNodes(contractStorageProof) + case *commontrie.Trie2Adapter: + contractStorageProof := trie2.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + return nil, err + } } + contractStorageRes[i] = adaptTrie2ProofNodes(contractStorageProof) + default: + return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } - - contractStorageRes[i] = adaptProofNodes(contractStorageProof) } return contractStorageRes, nil } -func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -308,6 +383,36 @@ func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } +func adaptTrie2ProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trienode.BinaryNode: + node = &BinaryNode{ + Left: &hash, + Right: &hash, + } + case *trienode.EdgeNode: + path := n.Path.Felt() + node = &EdgeNode{ + Path: path.String(), + Length: int(n.Path.Len()), + Child: &hash, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index f5d529931d..3663e554d5 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -62,7 +62,7 @@ func TestStorageAt(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) @@ -190,11 +190,11 @@ func TestStorageProof(t *testing.T) { headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} mockReader := mocks.NewMockReader(mockCtrl) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) mockReader.EXPECT().HeadState().Return(mockState, func() error { return nil }, nil).AnyTimes() + mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() - mockState.EXPECT().ChainHeight().Return(blockNumber, nil).AnyTimes() mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() @@ -660,16 +660,16 @@ func TestStorageProof_StorageRoots(t *testing.T) { contractTrie, err := reader.ContractTrie() assert.NoError(t, err) - clsRoot, err := classTrie.Root() + clsRoot, err := classTrie.Hash() assert.NoError(t, err) - stgRoot, err := contractTrie.Root() + stgRoot, err := contractTrie.Hash() assert.NoError(t, err) assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) - verifyGlobalStateRoot(t, expectedGlobalRoot, clsRoot, stgRoot) + verifyGlobalStateRoot(t, expectedGlobalRoot, &clsRoot, &stgRoot) }) t.Run("check requested contract and storage slot exists", func(t *testing.T) { diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index dd8fd506d2..1046da0603 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -450,9 +451,12 @@ func (fs *fakeSyncer) HighestBlockHeader() *core.Header { func (fs *fakeSyncer) PendingData() (core.PendingData, error) { return nil, sync.ErrPendingBlockNotFound } -func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { return nil, nil, nil } -func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } +func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, error) { + return nil, nil, nil +} + +func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { return nil, nil, nil } diff --git a/rpc/v9/trace.go b/rpc/v9/trace.go index e90bb3ad7d..d69959d1f2 100644 --- a/rpc/v9/trace.go +++ b/rpc/v9/trace.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -285,7 +286,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block) defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState core.StateReader + headState commonstate.StateReader headStateCloser blockchain.StateCloser ) if isPending { @@ -410,7 +411,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json return nil, rpccore.ErrContractNotFound } - declaredClass, err := state.Class(classHash) + declaredClass, err := state.Class(&classHash) if err != nil { return nil, rpccore.ErrClassHashNotFound } @@ -426,7 +427,7 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json ContractAddress: &funcCall.ContractAddress, Selector: &funcCall.EntryPointSelector, Calldata: funcCall.Calldata, - ClassHash: classHash, + ClassHash: &classHash, }, &vm.BlockInfo{ Header: header, BlockHashToBeRevealed: blockHashToBeRevealed, diff --git a/rpc/v9/trace_test.go b/rpc/v9/trace_test.go index 657af76a94..1aa8e50da8 100644 --- a/rpc/v9/trace_test.go +++ b/rpc/v9/trace_test.go @@ -326,7 +326,7 @@ func TestTraceTransaction(t *testing.T) { mockReader.EXPECT().BlockByHash(header.Hash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -413,7 +413,7 @@ func TestTraceTransaction(t *testing.T) { ) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockSyncReader.EXPECT().PendingState().Return(headState, nopCloser, nil) @@ -500,7 +500,7 @@ func TestTraceTransaction(t *testing.T) { &preConfirmed, nil, ) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) mockSyncReader.EXPECT().PendingStateBeforeIndex(0).Return(headState, nopCloser, nil) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) @@ -749,7 +749,7 @@ func TestTraceBlockTransactions(t *testing.T) { mockReader.EXPECT().BlockByHash(blockHash).Return(block, nil) mockReader.EXPECT().StateAtBlockHash(header.ParentHash).Return(nil, nopCloser, nil) - headState := mocks.NewMockStateHistoryReader(mockCtrl) + headState := mocks.NewMockStateReader(mockCtrl) headState.EXPECT().Class(tx.ClassHash).Return(declaredClass, nil) mockReader.EXPECT().HeadState().Return(headState, nopCloser, nil) @@ -1230,7 +1230,7 @@ func TestCall(t *testing.T) { assert.Equal(t, rpccore.ErrBlockNotFound, rpcErr) }) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) diff --git a/sequencer/sequencer.go b/sequencer/sequencer.go index f5821fb9cc..499c547ffb 100644 --- a/sequencer/sequencer.go +++ b/sequencer/sequencer.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/feed" "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/plugin" @@ -210,11 +211,11 @@ func (s *Sequencer) PendingBlock() *core.Block { return s.buildState.PendingBlock() } -func (s *Sequencer) PendingState() (core.StateReader, func() error, error) { +func (s *Sequencer) PendingState() (commonstate.StateReader, func() error, error) { return s.builder.PendingState(s.buildState) } -func (s *Sequencer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (s *Sequencer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { return nil, nil, errors.ErrUnsupported } diff --git a/sync/pending.go b/sync/pending.go index ae05cc55d3..c5405fd1d0 100644 --- a/sync/pending.go +++ b/sync/pending.go @@ -6,7 +6,9 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/db" ) @@ -61,10 +63,10 @@ func (p *Pending) Variant() core.PendingDataVariant { type PendingState struct { stateDiff *core.StateDiff newClasses map[felt.Felt]core.Class - head core.StateReader + head commonstate.StateReader } -func NewPendingState(stateDiff *core.StateDiff, newClasses map[felt.Felt]core.Class, head core.StateReader) *PendingState { +func NewPendingState(stateDiff *core.StateDiff, newClasses map[felt.Felt]core.Class, head commonstate.StateReader) *PendingState { return &PendingState{ stateDiff: stateDiff, newClasses: newClasses, @@ -72,42 +74,41 @@ func NewPendingState(stateDiff *core.StateDiff, newClasses map[felt.Felt]core.Cl } } -func (p *PendingState) ChainHeight() (uint64, error) { - return p.head.ChainHeight() -} - func (p *PendingState) StateDiff() *core.StateDiff { return p.stateDiff } -func (p *PendingState) ContractClassHash(addr *felt.Felt) (*felt.Felt, error) { +func (p *PendingState) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { if classHash, ok := p.stateDiff.ReplacedClasses[*addr]; ok { - return classHash, nil + return *classHash, nil } else if classHash, ok = p.stateDiff.DeployedContracts[*addr]; ok { - return classHash, nil + return *classHash, nil } - return p.head.ContractClassHash(addr) + classHash, err := p.head.ContractClassHash(addr) + return classHash, err } -func (p *PendingState) ContractNonce(addr *felt.Felt) (*felt.Felt, error) { +func (p *PendingState) ContractNonce(addr *felt.Felt) (felt.Felt, error) { if nonce, found := p.stateDiff.Nonces[*addr]; found { - return nonce, nil + return *nonce, nil } else if _, found = p.stateDiff.DeployedContracts[*addr]; found { - return &felt.Felt{}, nil + return felt.Zero, nil } - return p.head.ContractNonce(addr) + nonce, err := p.head.ContractNonce(addr) + return nonce, err } -func (p *PendingState) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) { +func (p *PendingState) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) { if diffs, found := p.stateDiff.StorageDiffs[*addr]; found { if value, found := diffs[*key]; found { - return value, nil + return *value, nil } } if _, found := p.stateDiff.DeployedContracts[*addr]; found { - return &felt.Felt{}, nil + return felt.Zero, nil } - return p.head.ContractStorage(addr, key) + value, err := p.head.ContractStorage(addr, key) + return value, err } func (p *PendingState) Class(classHash *felt.Felt) (*core.DeclaredClass, error) { @@ -121,23 +122,27 @@ func (p *PendingState) Class(classHash *felt.Felt) (*core.DeclaredClass, error) return p.head.Class(classHash) } -func (p *PendingState) ClassTrie() (*trie.Trie, error) { - return nil, core.ErrHistoricalTrieNotSupported +func (p *PendingState) ClassTrie() (commontrie.Trie, error) { + return nil, state.ErrHistoricalTrieNotSupported } -func (p *PendingState) ContractTrie() (*trie.Trie, error) { - return nil, core.ErrHistoricalTrieNotSupported +func (p *PendingState) ContractTrie() (commontrie.Trie, error) { + return nil, state.ErrHistoricalTrieNotSupported } -func (p *PendingState) ContractStorageTrie(addr *felt.Felt) (*trie.Trie, error) { - return nil, core.ErrHistoricalTrieNotSupported +func (p *PendingState) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { + return nil, state.ErrHistoricalTrieNotSupported } type PendingStateWriter struct { *PendingState } -func NewPendingStateWriter(stateDiff *core.StateDiff, newClasses map[felt.Felt]core.Class, head core.StateReader) PendingStateWriter { +func NewPendingStateWriter( + stateDiff *core.StateDiff, + newClasses map[felt.Felt]core.Class, + head commonstate.StateReader, +) PendingStateWriter { return PendingStateWriter{ PendingState: &PendingState{ stateDiff: stateDiff, @@ -160,7 +165,7 @@ func (p *PendingStateWriter) IncrementNonce(contractAddress *felt.Felt) error { if err != nil { return fmt.Errorf("get contract nonce: %v", err) } - p.stateDiff.Nonces[*contractAddress] = currentNonce.Add(currentNonce, feltOne) + p.stateDiff.Nonces[*contractAddress] = currentNonce.Add(¤tNonce, feltOne) return nil } diff --git a/sync/pending_test.go b/sync/pending_test.go index d2cc63e73f..172d8e53a7 100644 --- a/sync/pending_test.go +++ b/sync/pending_test.go @@ -16,7 +16,7 @@ func TestPendingState(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) - mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockState := mocks.NewMockStateReader(mockCtrl) deployedAddr, err := new(felt.Felt).SetRandom() require.NoError(t, err) diff --git a/sync/sync.go b/sync/sync.go index fd23d808c5..2a4f833ef9 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" junoplugin "github.com/NethermindEth/juno/plugin" @@ -76,8 +77,8 @@ type Reader interface { PendingData() (core.PendingData, error) PendingBlock() *core.Block - PendingState() (core.StateReader, func() error, error) - PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) + PendingState() (commonstate.StateReader, func() error, error) + PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) } // This is temporary and will be removed once the p2p synchronizer implements this interface. @@ -111,11 +112,11 @@ func (n *NoopSynchronizer) PendingData() (core.PendingData, error) { return nil, errors.New("PendingData() is not implemented") } -func (n *NoopSynchronizer) PendingState() (core.StateReader, func() error, error) { +func (n *NoopSynchronizer) PendingState() (commonstate.StateReader, func() error, error) { return nil, nil, errors.New("PendingState() not implemented") } -func (n *NoopSynchronizer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (n *NoopSynchronizer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { return nil, nil, errors.New("PendingStateBeforeIndex() not implemented") } @@ -245,7 +246,8 @@ func (s *Synchronizer) handlePluginRevertBlock() { err = s.plugin.RevertBlock( &junoplugin.BlockAndStateUpdate{Block: fromBlock, StateUpdate: fromSU}, toBlockAndStateUpdate, - reverseStateDiff) + &reverseStateDiff, + ) if err != nil { s.log.Errorw("Plugin RevertBlock failure:", "err", err) } @@ -745,7 +747,7 @@ func (s *Synchronizer) PendingBlock() *core.Block { var noop = func() error { return nil } // PendingState returns the state resulting from execution of the pending block -func (s *Synchronizer) PendingState() (core.StateReader, func() error, error) { +func (s *Synchronizer) PendingState() (commonstate.StateReader, func() error, error) { txn := s.db.NewIndexedBatch() pending, err := s.PendingData() @@ -753,12 +755,18 @@ func (s *Synchronizer) PendingState() (core.StateReader, func() error, error) { return nil, nil, err } - return NewPendingState(pending.GetStateUpdate().StateDiff, pending.GetNewClasses(), core.NewState(txn)), noop, nil + pendingStateUpdate := pending.GetStateUpdate() + state, err := s.blockchain.StateFactory.NewState(pendingStateUpdate.OldRoot, txn) + if err != nil { + return nil, nil, err + } + + return NewPendingState(pendingStateUpdate.StateDiff, pending.GetNewClasses(), state), noop, nil } // PendingStateAfterIndex returns the state obtained by applying all transaction state diffs // up to given index in the pre-confirmed block. -func (s *Synchronizer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (s *Synchronizer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { txn := s.db.NewIndexedBatch() pending, err := s.PendingData() @@ -777,7 +785,12 @@ func (s *Synchronizer) PendingStateBeforeIndex(index int) (core.StateReader, fun stateDiff.Merge(txStateDiff) } - return NewPendingState(&stateDiff, pending.GetNewClasses(), core.NewState(txn)), noop, nil + state, err := s.blockchain.StateFactory.NewState(pending.GetStateUpdate().OldRoot, txn) + if err != nil { + return nil, nil, err + } + + return NewPendingState(&stateDiff, pending.GetNewClasses(), state), noop, nil } func (s *Synchronizer) storeEmptyPendingData(lastHeader *core.Header) { diff --git a/sync/sync_test.go b/sync/sync_test.go index 130fe891b2..53f01f005e 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -157,6 +157,8 @@ func TestReorg(t *testing.T) { require.NoError(t, synchronizer.Run(ctx)) cancel() + require.NoError(t, bc.Stop()) + t.Run("resync to mainnet with the same db", func(t *testing.T) { bc := blockchain.New(testDB, &utils.Mainnet) diff --git a/vm/state.go b/vm/state.go index 4ce85663ea..1020893cea 100644 --- a/vm/state.go +++ b/vm/state.go @@ -9,6 +9,7 @@ import ( "unsafe" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/db" ) @@ -25,14 +26,15 @@ func JunoStateGetStorageAt(readerHandle C.uintptr_t, contractAddress, storageLoc storageLocationFelt := makeFeltFromPtr(storageLocation) val, err := context.state.ContractStorage(contractAddressFelt, storageLocationFelt) if err != nil { - if !errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalicki): handle errors of both states + if !errors.Is(err, state.ErrContractNotDeployed) && !errors.Is(err, db.ErrKeyNotFound) { context.log.Errorw("JunoStateGetStorageAt failed to read contract storage", "err", err) return 0 } - val = &felt.Zero + val = felt.Zero } - return fillBufferWithFelt(val, buffer) + return fillBufferWithFelt(&val, buffer) } //export JunoStateGetNonceAt @@ -42,14 +44,15 @@ func JunoStateGetNonceAt(readerHandle C.uintptr_t, contractAddress, buffer unsaf contractAddressFelt := makeFeltFromPtr(contractAddress) val, err := context.state.ContractNonce(contractAddressFelt) if err != nil { - if !errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalicki): handle errors of both states + if !errors.Is(err, db.ErrKeyNotFound) && !errors.Is(err, state.ErrContractNotDeployed) { context.log.Errorw("JunoStateGetNonceAt failed to read contract nonce", "err", err) return 0 } - val = &felt.Zero + val = felt.Zero } - return fillBufferWithFelt(val, buffer) + return fillBufferWithFelt(&val, buffer) } //export JunoStateGetClassHashAt @@ -59,14 +62,15 @@ func JunoStateGetClassHashAt(readerHandle C.uintptr_t, contractAddress, buffer u contractAddressFelt := makeFeltFromPtr(contractAddress) val, err := context.state.ContractClassHash(contractAddressFelt) if err != nil { - if !errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalicki): handle errors of both states + if !errors.Is(err, db.ErrKeyNotFound) && !errors.Is(err, state.ErrContractNotDeployed) { context.log.Errorw("JunoStateGetClassHashAt failed to read contract class", "err", err) return 0 } - val = &felt.Zero + val = felt.Zero } - return fillBufferWithFelt(val, buffer) + return fillBufferWithFelt(&val, buffer) } //export JunoStateGetCompiledClass diff --git a/vm/vm.go b/vm/vm.go index 3e31e27708..9348f341e7 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -20,6 +20,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/utils" ) @@ -40,10 +41,10 @@ type CallResult struct { //go:generate mockgen -destination=../mocks/mock_vm.go -package=mocks github.com/NethermindEth/juno/vm VM type VM interface { - Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateReader, network *utils.Network, + Call(callInfo *CallInfo, blockInfo *BlockInfo, state commonstate.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, structuredErrStack, returnStateDiff bool) (CallResult, error) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, - state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, + state commonstate.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, ) (ExecutionResults, error) } @@ -62,7 +63,7 @@ func New(concurrencyMode bool, log utils.SimpleLogger) VM { // callContext manages the context that a Call instance executes on type callContext struct { // state that the call is running on - state core.StateReader + state commonstate.StateReader log utils.SimpleLogger // err field to be possibly populated in case of an error in execution err string @@ -240,7 +241,7 @@ func makeByteFromBool(b bool) byte { return boolByte } -func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateReader, +func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state commonstate.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, structuredErrStack, returnStateDiff bool, ) (CallResult, error) { context := &callContext{ @@ -291,7 +292,7 @@ func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateRead // Execute executes a given transaction set and returns the gas spent per transaction func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, - blockInfo *BlockInfo, state core.StateReader, network *utils.Network, + blockInfo *BlockInfo, state commonstate.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, errorStack, allowBinarySearch bool, ) (ExecutionResults, error) { context := &callContext{ diff --git a/vm/vm_test.go b/vm/vm_test.go index 66467b3e12..7f586bba48 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -7,6 +7,9 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/rpc/rpccore" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -27,48 +30,57 @@ func TestCallDeprecatedCairo(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := core.NewState(txn) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, - }, - }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) - - entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") - - ret, err := New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) + triedb, err := triedb.New(testDB, nil) require.NoError(t, err) - assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) - - require.NoError(t, testState.Update(1, &core.StateUpdate{ - OldRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - NewRoot: utils.HexToFelt(t, "0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), - StateDiff: &core.StateDiff{ - StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ - *contractAddr: { - *utils.HexToFelt(t, "0x206f38f7e4f15e87567361213c28f235cccdaa1d7fd34c9db1dfe9489c6a091"): new(felt.Felt).SetUint64(1337), + stateDB := state.NewStateDB(testDB, triedb) + for _, stateVersion := range []bool{true, false} { + stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, }, }, - }, - }, nil, false)) - - ret, err = New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret.Result) + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) + + entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") + + ret, err := New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + + require.NoError(t, testState.Update(1, &core.StateUpdate{ + OldRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + NewRoot: utils.HexToFelt(t, "0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), + StateDiff: &core.StateDiff{ + StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ + *contractAddr: { + *utils.HexToFelt(t, "0x206f38f7e4f15e87567361213c28f235cccdaa1d7fd34c9db1dfe9489c6a091"): new(felt.Felt).SetUint64(1337), + }, + }, + }, + }, nil, false)) + + ret, err = New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret.Result) + } } func TestCallDeprecatedCairoMaxSteps(t *testing.T) { @@ -83,27 +95,36 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := core.NewState(txn) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, + triedb, err := triedb.New(testDB, nil) + require.NoError(t, err) + stateDB := state.NewStateDB(testDB, triedb) + for _, stateVersion := range []bool{true, false} { + stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, + }, }, - }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) - - entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") - - _, err = New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 0, simpleClass.SierraVersion(), false, false) - assert.ErrorContains(t, err, "RunResources has no remaining steps") + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) + + entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") + + _, err = New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 0, simpleClass.SierraVersion(), false, false) + assert.ErrorContains(t, err, "RunResources has no remaining steps") + } } func TestCallCairo(t *testing.T) { @@ -118,57 +139,65 @@ func TestCallCairo(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := core.NewState(txn) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, + triedb, err := triedb.New(testDB, nil) + require.NoError(t, err) + stateDB := state.NewStateDB(testDB, triedb) + for _, stateVersion := range []bool{true, false} { + stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, + }, }, - }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) - logLevel := utils.NewLogLevel(utils.ERROR) - log, err := utils.NewZapLogger(logLevel, false) - require.NoError(t, err) + logLevel := utils.NewLogLevel(utils.ERROR) + log, err := utils.NewZapLogger(logLevel, false) + require.NoError(t, err) - // test_storage_read - entryPoint := utils.HexToFelt(t, "0x5df99ae77df976b4f0e5cf28c7dcfe09bd6e81aab787b19ac0c08e03d928cf") - storageLocation := utils.HexToFelt(t, "0x44") - ret, err := New(false, log).Call(&CallInfo{ - ContractAddress: contractAddr, - Selector: entryPoint, - Calldata: []felt.Felt{ - *storageLocation, - }, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) - - require.NoError(t, testState.Update(1, &core.StateUpdate{ - OldRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), - NewRoot: utils.HexToFelt(t, "0x7a9da0a7471a8d5118d3eefb8c26a6acbe204eb1eaa934606f4757a595fe552"), - StateDiff: &core.StateDiff{ - StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ - *contractAddr: { - *storageLocation: new(felt.Felt).SetUint64(37), + // test_storage_read + entryPoint := utils.HexToFelt(t, "0x5df99ae77df976b4f0e5cf28c7dcfe09bd6e81aab787b19ac0c08e03d928cf") + storageLocation := utils.HexToFelt(t, "0x44") + ret, err := New(false, log).Call(&CallInfo{ + ContractAddress: contractAddr, + Selector: entryPoint, + Calldata: []felt.Felt{ + *storageLocation, + }, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + + require.NoError(t, testState.Update(1, &core.StateUpdate{ + OldRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), + NewRoot: utils.HexToFelt(t, "0x7a9da0a7471a8d5118d3eefb8c26a6acbe204eb1eaa934606f4757a595fe552"), + StateDiff: &core.StateDiff{ + StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ + *contractAddr: { + *storageLocation: new(felt.Felt).SetUint64(37), + }, }, }, - }, - }, nil, false)) - - ret, err = New(false, log).Call(&CallInfo{ - ContractAddress: contractAddr, - Selector: entryPoint, - Calldata: []felt.Felt{ - *storageLocation, - }, - }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(37)}, ret.Result) + }, nil, false)) + + ret, err = New(false, log).Call(&CallInfo{ + ContractAddress: contractAddr, + Selector: entryPoint, + Calldata: []felt.Felt{ + *storageLocation, + }, + }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(37)}, ret.Result) + } } func TestCallInfoErrorHandling(t *testing.T) { @@ -182,45 +211,53 @@ func TestCallInfoErrorHandling(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := core.NewState(txn) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0xa6258de574e5540253c4a52742137d58b9e8ad8f584115bee46d9d18255c42"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, + triedb, err := triedb.New(testDB, nil) + require.NoError(t, err) + stateDB := state.NewStateDB(testDB, triedb) + for _, stateVersion := range []bool{true, false} { + stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0xa6258de574e5540253c4a52742137d58b9e8ad8f584115bee46d9d18255c42"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, + }, }, - }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) - logLevel := utils.NewLogLevel(utils.ERROR) - log, err := utils.NewZapLogger(logLevel, false) - require.NoError(t, err) + logLevel := utils.NewLogLevel(utils.ERROR) + log, err := utils.NewZapLogger(logLevel, false) + require.NoError(t, err) - callInfo := &CallInfo{ - ClassHash: classHash, - ContractAddress: contractAddr, - Selector: utils.HexToFelt(t, "0x123"), // doesn't exist - Calldata: []felt.Felt{}, + callInfo := &CallInfo{ + ClassHash: classHash, + ContractAddress: contractAddr, + Selector: utils.HexToFelt(t, "0x123"), // doesn't exist + Calldata: []felt.Felt{}, + } + + // Starknet version <0.13.4 should return an error + ret, err := New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ + ProtocolVersion: "0.13.0", + }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) + require.Equal(t, CallResult{}, ret) + require.ErrorContains(t, err, "not found in contract") + + // Starknet version 0.13.4 should return an "error" in the CallInfo + ret, err = New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ + ProtocolVersion: "0.13.4", + }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) + require.True(t, ret.ExecutionFailed) + require.Equal(t, len(ret.Result), 1) + require.Equal(t, ret.Result[0].String(), rpccore.EntrypointNotFoundFelt) + require.Nil(t, err) } - - // Starknet version <0.13.4 should return an error - ret, err := New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ - ProtocolVersion: "0.13.0", - }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) - require.Equal(t, CallResult{}, ret) - require.ErrorContains(t, err, "not found in contract") - - // Starknet version 0.13.4 should return an "error" in the CallInfo - ret, err = New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ - ProtocolVersion: "0.13.4", - }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) - require.True(t, ret.ExecutionFailed) - require.Equal(t, len(ret.Result), 1) - require.Equal(t, ret.Result[0].String(), rpccore.EntrypointNotFoundFelt) - require.Nil(t, err) } func TestExecute(t *testing.T) { @@ -229,30 +266,38 @@ func TestExecute(t *testing.T) { testDB := memory.New() txn := testDB.NewIndexedBatch() - state := core.NewState(txn) - - t.Run("empty transaction list", func(t *testing.T) { - _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ - Header: &core.Header{ - Timestamp: 1666877926, - SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), - L1GasPriceETH: &felt.Zero, - L1GasPriceSTRK: &felt.Zero, - }, - }, state, - &network, false, false, false, false, false) + triedb, err := triedb.New(testDB, nil) + require.NoError(t, err) + stateDB := state.NewStateDB(testDB, triedb) + for _, stateVersion := range []bool{true, false} { + stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) require.NoError(t, err) - }) - t.Run("zero data", func(t *testing.T) { - _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ - Header: &core.Header{ - SequencerAddress: &felt.Zero, - L1GasPriceETH: &felt.Zero, - L1GasPriceSTRK: &felt.Zero, - }, - }, state, &network, false, false, false, false, false) + state, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) - }) + + t.Run("empty transaction list", func(t *testing.T) { + _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ + Header: &core.Header{ + Timestamp: 1666877926, + SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), + L1GasPriceETH: &felt.Zero, + L1GasPriceSTRK: &felt.Zero, + }, + }, state, + &network, false, false, false, false, false) + require.NoError(t, err) + }) + t.Run("zero data", func(t *testing.T) { + _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ + Header: &core.Header{ + SequencerAddress: &felt.Zero, + L1GasPriceETH: &felt.Zero, + L1GasPriceSTRK: &felt.Zero, + }, + }, state, &network, false, false, false, false, false) + require.NoError(t, err) + }) + } } func TestSetVersionedConstants(t *testing.T) { From d1bc57104e675c2e95295a2b0eb5e9d928fb480c Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 1 Aug 2025 18:14:51 +0200 Subject: [PATCH 02/47] add flags for picking the state version --- cmd/juno/dbcmd.go | 19 ++++++++++++++++--- cmd/juno/juno.go | 5 +++-- node/node.go | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index 02e0f9b486..3ff88c4695 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -45,12 +45,14 @@ func DBCmd(defaultDBPath string) *cobra.Command { } func DBInfoCmd() *cobra.Command { - return &cobra.Command{ + cmd := &cobra.Command{ Use: "info", Short: "Retrieve database information", Long: `This subcommand retrieves and displays blockchain information stored in the database.`, RunE: dbInfo, } + cmd.Flags().Bool(newStateF, defaultNewState, newStateUsage) + return cmd } func DBSizeCmd() *cobra.Command { @@ -70,6 +72,7 @@ func DBRevertCmd() *cobra.Command { RunE: dbRevert, } cmd.Flags().Uint64(dbRevertToBlockF, 0, "New head (this block won't be reverted)") + cmd.Flags().Bool(newStateF, defaultNewState, newStateUsage) return cmd } @@ -80,13 +83,18 @@ func dbInfo(cmd *cobra.Command, args []string) error { return err } + stateVersion, err := cmd.Flags().GetBool(newStateF) + if err != nil { + return err + } + database, err := openDB(dbPath) if err != nil { return err } defer database.Close() - chain := blockchain.New(database, nil) + chain := blockchain.New(database, nil, stateVersion) var info DBInfo // Get the latest block information @@ -146,6 +154,11 @@ func dbRevert(cmd *cobra.Command, args []string) error { return fmt.Errorf("--%v cannot be 0", dbRevertToBlockF) } + stateVersion, err := cmd.Flags().GetBool(newStateF) + if err != nil { + return err + } + database, err := openDB(dbPath) if err != nil { return err @@ -153,7 +166,7 @@ func dbRevert(cmd *cobra.Command, args []string) error { defer database.Close() for { - chain := blockchain.New(database, nil) + chain := blockchain.New(database, nil, stateVersion) head, err := chain.Head() if err != nil { return fmt.Errorf("failed to get the latest block information: %v", err) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index c9bd8d36da..5211c0ac83 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -144,6 +144,7 @@ const ( defaultHTTPUpdatePort = 0 defaultSubmittedTransactionsCacheSize = 10_000 defaultSubmittedTransactionsCacheEntryTTL = 5 * time.Minute + defaultNewState = false configFlagUsage = "The YAML configuration file." logLevelFlagUsage = "Options: trace, debug, info, warn, error." @@ -210,7 +211,7 @@ const ( httpUpdatePortUsage = "The port on which the log level and gateway timeouts HTTP server will listen for requests." submittedTransactionsCacheSize = "Maximum number of entries in the submitted transactions cache" submittedTransactionsCacheEntryTTL = "Time-to-live for each entry in the submitted transactions cache" - newStateUsage = "Use the new state package implementation instead of the legacy one" + newStateUsage = "EXPERIMENTAL:Use the new state package implementation instead of the legacy one" ) var Version string @@ -418,7 +419,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr defaultSubmittedTransactionsCacheEntryTTL, submittedTransactionsCacheEntryTTL, ) - junoCmd.Flags().Bool(newStateF, false, newStateUsage) + junoCmd.Flags().Bool(newStateF, defaultNewState, newStateUsage) junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath)) return junoCmd diff --git a/node/node.go b/node/node.go index 2ae9729e4c..b6c113482b 100644 --- a/node/node.go +++ b/node/node.go @@ -155,7 +155,7 @@ func New(cfg *Config, version string, logLevel *utils.LogLevel) (*Node, error) { services := make([]service.Service, 0) earlyServices := make([]service.Service, 0) - chain := blockchain.New(database, &cfg.Network) + chain := blockchain.New(database, &cfg.Network, cfg.NewState) // Verify that cfg.Network is compatible with the database. head, err := chain.Head() From b210933f0f48e2e2b48c9d384eeb85c8fb56a917 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 4 Aug 2025 01:13:05 +0200 Subject: [PATCH 03/47] unittests fixes --- blockchain/blockchain.go | 2 +- blockchain/blockchain_test.go | 45 ++- cmd/juno/dbcmd_test.go | 11 +- cmd/juno/juno.go | 2 +- consensus/integtest/integ_test.go | 9 +- .../p2p/validator/empty_fixtures_test.go | 9 +- consensus/p2p/validator/fixtures_test.go | 3 +- .../validator/proposal_stream_demux_test.go | 3 +- consensus/p2p/validator/transition_test.go | 3 +- consensus/proposer/proposer_test.go | 9 +- core/running_event_filter_test.go | 9 +- core/state/commonstate/state.go | 4 +- core/state/state_test_utils/new_state_flag.go | 44 +++ l1/l1_pkg_test.go | 15 +- l1/l1_test.go | 7 +- mempool/mempool_test.go | 16 +- migration/migration_pkg_test.go | 13 +- node/node_test.go | 9 +- plugin/plugin_test.go | 11 +- rpc/v6/block_test.go | 47 ++- rpc/v6/class_test.go | 14 +- rpc/v6/contract_test.go | 38 +- rpc/v6/events_test.go | 3 +- rpc/v6/handlers_test.go | 2 +- rpc/v6/state_update_test.go | 3 +- rpc/v6/trace_test.go | 19 +- rpc/v6/transaction_test.go | 12 +- rpc/v7/block_test.go | 47 ++- rpc/v7/handlers_test.go | 2 +- rpc/v7/storage_test.go | 26 +- rpc/v7/trace_test.go | 19 +- rpc/v7/transaction_test.go | 12 +- rpc/v8/block_test.go | 41 +- rpc/v8/handlers_test.go | 16 +- rpc/v8/l1_test.go | 4 +- rpc/v8/simulation_pkg_test.go | 7 + rpc/v8/storage_test.go | 37 +- rpc/v8/subscriptions_test.go | 9 +- rpc/v8/trace_test.go | 17 +- rpc/v8/transaction_test.go | 16 +- rpc/v9/block_test.go | 41 +- rpc/v9/class_test.go | 30 +- rpc/v9/events_test.go | 3 +- rpc/v9/handlers_test.go | 2 +- rpc/v9/l1_test.go | 4 +- rpc/v9/nonce_test.go | 14 +- rpc/v9/simulation_pkg_test.go | 7 + rpc/v9/state_update_test.go | 13 +- rpc/v9/storage_test.go | 43 +- rpc/v9/subscriptions_test.go | 9 +- rpc/v9/trace_test.go | 17 +- rpc/v9/transaction_test.go | 18 +- sequencer/sequencer_test.go | 11 +- sync/pending_test.go | 28 +- sync/sync_test.go | 31 +- vm/vm_test.go | 371 +++++++++--------- 56 files changed, 716 insertions(+), 541 deletions(-) create mode 100644 core/state/state_test_utils/new_state_flag.go diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index eab92860a7..977c632530 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -933,7 +933,7 @@ func (b *Blockchain) WriteRunningEventFilter() error { } func (b *Blockchain) Stop() error { - if b.trieDB.Scheme() == triedb.PathScheme { + if b.trieDB.Scheme() == triedb.PathScheme && b.StateFactory.UseNewState { head, err := b.HeadsHeader() if err != nil { return err diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index d8d0a68608..d5deb5bd72 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -2,12 +2,14 @@ package blockchain_test import ( "fmt" + "os" "testing" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -17,13 +19,18 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + var emptyCommitments = core.BlockCommitments{} func TestNew(t *testing.T) { client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) t.Run("empty blockchain's head is nil", func(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) assert.Equal(t, &utils.Mainnet, chain.Network()) b, err := chain.Head() assert.Nil(t, b) @@ -37,10 +44,10 @@ func TestNew(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &utils.Mainnet) + chain = blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) b, err := chain.Head() require.NoError(t, err) assert.Equal(t, block0, b) @@ -51,7 +58,7 @@ func TestHeight(t *testing.T) { client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) t.Run("return nil if blockchain is empty", func(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Sepolia) + chain := blockchain.New(memory.New(), &utils.Sepolia, statetestutils.UseNewState()) _, err := chain.Height() assert.Error(t, err) }) @@ -63,10 +70,10 @@ func TestHeight(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &utils.Mainnet) + chain = blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) height, err := chain.Height() require.NoError(t, err) assert.Equal(t, block0.Number, height) @@ -74,7 +81,7 @@ func TestHeight(t *testing.T) { } func TestBlockByNumberAndHash(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Sepolia) + chain := blockchain.New(memory.New(), &utils.Sepolia, statetestutils.UseNewState()) t.Run("same block is returned for both core.GetBlockByNumber and GetBlockByHash", func(t *testing.T) { client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -110,7 +117,7 @@ func TestVerifyBlock(t *testing.T) { h1, err := new(felt.Felt).SetRandom() require.NoError(t, err) - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) t.Run("error if chain is empty and incoming block number is not 0", func(t *testing.T) { block := &core.Block{Header: &core.Header{Number: 10}} @@ -187,7 +194,7 @@ func TestSanityCheckNewHeight(t *testing.T) { h1, err := new(felt.Felt).SetRandom() require.NoError(t, err) - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) @@ -232,7 +239,7 @@ func TestStore(t *testing.T) { require.NoError(t, err) t.Run("add block to empty blockchain", func(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) require.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) headBlock, err := chain.Head() @@ -255,7 +262,7 @@ func TestStore(t *testing.T) { stateUpdate1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) require.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) require.NoError(t, chain.Store(block1, &emptyCommitments, stateUpdate1, nil)) @@ -276,7 +283,7 @@ func TestStore(t *testing.T) { func TestStoreL1HandlerTxnHash(t *testing.T) { client := feeder.NewTestClient(t, &utils.Sepolia) gw := adaptfeeder.New(client) - chain := blockchain.New(memory.New(), &utils.Sepolia) + chain := blockchain.New(memory.New(), &utils.Sepolia, statetestutils.UseNewState()) var stateUpdate *core.StateUpdate for i := range uint64(7) { block, err := gw.BlockByNumber(t.Context(), i) @@ -293,7 +300,7 @@ func TestStoreL1HandlerTxnHash(t *testing.T) { } func TestBlockCommitments(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -318,7 +325,7 @@ func TestBlockCommitments(t *testing.T) { } func TestTransactionAndReceipt(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -406,7 +413,7 @@ func TestTransactionAndReceipt(t *testing.T) { func TestState(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -474,7 +481,7 @@ func TestEvents(t *testing.T) { } testDB := memory.New() - chain := blockchain.New(testDB, &utils.Goerli2) + chain := blockchain.New(testDB, &utils.Goerli2, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Goerli2) gw := adaptfeeder.New(client) @@ -590,7 +597,7 @@ func TestEvents(t *testing.T) { func TestRevert(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -670,7 +677,7 @@ func TestL1Update(t *testing.T) { for _, head := range heads { t.Run(fmt.Sprintf("update L1 head to block %d", head.BlockNumber), func(t *testing.T) { - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) require.NoError(t, chain.SetL1Head(head)) got, err := chain.L1Head() require.NoError(t, err) @@ -685,7 +692,7 @@ func TestSubscribeL1Head(t *testing.T) { StateRoot: new(felt.Felt).SetUint64(2), } - chain := blockchain.New(memory.New(), &utils.Mainnet) + chain := blockchain.New(memory.New(), &utils.Mainnet, statetestutils.UseNewState()) sub := chain.SubscribeL1Head() t.Cleanup(sub.Unsubscribe) diff --git a/cmd/juno/dbcmd_test.go b/cmd/juno/dbcmd_test.go index 7ab12059f6..cdb344e145 100644 --- a/cmd/juno/dbcmd_test.go +++ b/cmd/juno/dbcmd_test.go @@ -1,6 +1,7 @@ package main_test import ( + "os" "strconv" "testing" @@ -8,6 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" juno "github.com/NethermindEth/juno/cmd/juno" "github.com/NethermindEth/juno/core" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/pebble" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" @@ -16,6 +18,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + var emptyCommitments = core.BlockCommitments{} func TestDBCmd(t *testing.T) { @@ -55,7 +62,7 @@ func TestDBCmd(t *testing.T) { require.NoError(t, db.Close()) }) - chain := blockchain.New(db, &network) + chain := blockchain.New(db, &network, statetestutils.UseNewState()) block, err := chain.Head() require.NoError(t, err) assert.Equal(t, revertToBlock, block.Number) @@ -79,7 +86,7 @@ func prepareDB(t *testing.T, network *utils.Network, syncToBlock uint64) string testDB, err := pebble.New(dbPath) require.NoError(t, err) - chain := blockchain.New(testDB, network) + chain := blockchain.New(testDB, network, statetestutils.UseNewState()) for blockNumber := uint64(0); blockNumber <= syncToBlock; blockNumber++ { block, err := gw.BlockByNumber(t.Context(), blockNumber) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index 5211c0ac83..9d8aacd645 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -211,7 +211,7 @@ const ( httpUpdatePortUsage = "The port on which the log level and gateway timeouts HTTP server will listen for requests." submittedTransactionsCacheSize = "Maximum number of entries in the submitted transactions cache" submittedTransactionsCacheEntryTTL = "Time-to-live for each entry in the submitted transactions cache" - newStateUsage = "EXPERIMENTAL:Use the new state package implementation instead of the legacy one" + newStateUsage = "EXPERIMENTAL: Use the new state package implementation instead of the legacy one" ) var Version string diff --git a/consensus/integtest/integ_test.go b/consensus/integtest/integ_test.go index 092aad283d..762f10c84a 100644 --- a/consensus/integtest/integ_test.go +++ b/consensus/integtest/integ_test.go @@ -2,6 +2,7 @@ package integtest import ( "fmt" + "os" "testing" "time" @@ -18,6 +19,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/utils" @@ -74,6 +76,11 @@ func getTimeoutFn(cfg *testConfig) func(types.Step, types.Round) time.Duration { } } +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func newDB(t *testing.T) *mocks.MockTendermintDB[starknet.Value, starknet.Hash, starknet.Address] { t.Helper() ctrl := gomock.NewController(t) @@ -92,7 +99,7 @@ func getBuilder(t *testing.T, genesisDiff core.StateDiff, genesisClasses map[fel network := &utils.Mainnet log := utils.NewNopZapLogger() - bc := bc.New(testDB, network) + bc := bc.New(testDB, network, statetestutils.UseNewState()) require.NoError(t, bc.StoreGenesis(&genesisDiff, genesisClasses)) executor := builder.NewExecutor(bc, vm.New(false, log), log, false, true) diff --git a/consensus/p2p/validator/empty_fixtures_test.go b/consensus/p2p/validator/empty_fixtures_test.go index d24ceaa341..bec88e41f9 100644 --- a/consensus/p2p/validator/empty_fixtures_test.go +++ b/consensus/p2p/validator/empty_fixtures_test.go @@ -2,6 +2,7 @@ package validator import ( "math/rand/v2" + "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -9,6 +10,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" "github.com/starknet-io/starknet-p2pspecs/p2p/proto/common" @@ -25,6 +27,11 @@ type EmptyTestFixture struct { Proposal *starknet.Proposal } +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func NewEmptyTestFixture( t *testing.T, executor *mockExecutor, @@ -45,7 +52,7 @@ func NewEmptyTestFixture( executor.RegisterBuildResult(&buildResult) - b := builder.New(blockchain.New(database, testCase.Network), executor) + b := builder.New(blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor) proposalCommitment := EmptyProposalCommitment(headBlock, proposer, timestamp) diff --git a/consensus/p2p/validator/fixtures_test.go b/consensus/p2p/validator/fixtures_test.go index 9bf4be39ce..f65baebdc0 100644 --- a/consensus/p2p/validator/fixtures_test.go +++ b/consensus/p2p/validator/fixtures_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/starknet" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -108,7 +109,7 @@ func BuildTestFixture( executor.RegisterBuildResult(&buildResult) - builder := builder.New(blockchain.New(database, testCase.Network), executor) + builder := builder.New(blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor) return TestFixture{ ProposalInit: &proposalInit, diff --git a/consensus/p2p/validator/proposal_stream_demux_test.go b/consensus/p2p/validator/proposal_stream_demux_test.go index e501ea9049..bbb00afed6 100644 --- a/consensus/p2p/validator/proposal_stream_demux_test.go +++ b/consensus/p2p/validator/proposal_stream_demux_test.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/utils" @@ -63,7 +64,7 @@ func TestProposalStreamDemux(t *testing.T) { network := &utils.SepoliaIntegration executor := NewMockExecutor(t, network) database := memory.New() - bc := blockchain.New(database, network) + bc := blockchain.New(database, network, statetestutils.UseNewState()) builder := builder.New(bc, executor) transition := NewTransition(&builder) proposalStore := proposal.ProposalStore[starknet.Hash]{} diff --git a/consensus/p2p/validator/transition_test.go b/consensus/p2p/validator/transition_test.go index a4ddd5d721..4b7b46c91c 100644 --- a/consensus/p2p/validator/transition_test.go +++ b/consensus/p2p/validator/transition_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -29,7 +30,7 @@ func getBuilder(t *testing.T, seqAddr *felt.Felt) (*builder.Builder, *core.Heade t.Helper() testDB := memory.New() network := &utils.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New(testDB, network, statetestutils.UseNewState()) log := utils.NewNopZapLogger() privKey, err := ecdsa.GenerateKey(rand.Reader) diff --git a/consensus/proposer/proposer_test.go b/consensus/proposer/proposer_test.go index 4a13507010..8de6f89fae 100644 --- a/consensus/proposer/proposer_test.go +++ b/consensus/proposer/proposer_test.go @@ -3,6 +3,7 @@ package proposer_test import ( "context" "fmt" + "os" "slices" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -37,6 +39,11 @@ const ( var allBatchSizes = []int{1, 0, 3, 2, 4, 0, 1} +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestProposer(t *testing.T) { logger, err := utils.NewZapLogger(utils.NewLogLevel(logLevel), true) require.NoError(t, err) @@ -173,7 +180,7 @@ func getBlockchain(t *testing.T) *blockchain.Blockchain { t.Helper() testDB := memory.New() network := &utils.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New(testDB, network, statetestutils.UseNewState()) return bc } diff --git a/core/running_event_filter_test.go b/core/running_event_filter_test.go index 2a02f8ab41..bd42ccfa2c 100644 --- a/core/running_event_filter_test.go +++ b/core/running_event_filter_test.go @@ -1,6 +1,7 @@ package core_test import ( + "flag" "testing" "github.com/NethermindEth/juno/blockchain" @@ -17,6 +18,12 @@ import ( "github.com/stretchr/testify/require" ) +var newState bool + +func init() { + flag.BoolVar(&newState, "use-new-state", false, "...") +} + func testBloomWithRandomKeys(t *testing.T, numKeys uint) *bloom.BloomFilter { t.Helper() filter := bloom.New(core.EventsBloomLength, core.EventsBloomHashFuncs) @@ -67,7 +74,7 @@ func TestRunningEventFilter_LazyInitialization_EmptyDB(t *testing.T) { func TestRunningEventFilter_LazyInitialization_Preload(t *testing.T) { testDB := memory.New() n := &utils.Sepolia - chain := blockchain.New(testDB, n) + chain := blockchain.New(testDB, n, newState) client := feeder.NewTestClient(t, n) gw := adaptfeeder.New(client) diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 288fedcdc7..6b6f7f2f38 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -276,8 +276,8 @@ type StateFactory struct { stateDB *state.StateDB } -func NewStateFactory(useNewState bool, triedb *triedb.Database, stateDB *state.StateDB) (*StateFactory, error) { - if !useNewState { +func NewStateFactory(newState bool, triedb *triedb.Database, stateDB *state.StateDB) (*StateFactory, error) { + if !newState { return &StateFactory{UseNewState: false}, nil } diff --git a/core/state/state_test_utils/new_state_flag.go b/core/state/state_test_utils/new_state_flag.go new file mode 100644 index 0000000000..606dac39d3 --- /dev/null +++ b/core/state/state_test_utils/new_state_flag.go @@ -0,0 +1,44 @@ +package statetestutils + +import ( + "flag" + "fmt" + "os" + "strings" + "sync" +) + +var ( + once sync.Once + parsed bool + useNewState bool +) + +func parseFlags() { + flag.BoolVar(&useNewState, "use-new-state", false, "use new state implementation") + fmt.Println("use-new-state", useNewState) + + cleanArgs := []string{os.Args[0]} + for i := 1; i < len(os.Args); i++ { + arg := os.Args[i] + if arg == "-use-new-state" || strings.HasPrefix(arg, "-use-new-state=") { + continue + } + cleanArgs = append(cleanArgs, arg) + } + os.Args = cleanArgs + + flag.Parse() + parsed = true +} + +func Parse() { + once.Do(parseFlags) +} + +func UseNewState() bool { + if !parsed { + Parse() + } + return useNewState +} diff --git a/l1/l1_pkg_test.go b/l1/l1_pkg_test.go index 0e94d02e01..dbceb69c50 100644 --- a/l1/l1_pkg_test.go +++ b/l1/l1_pkg_test.go @@ -4,12 +4,14 @@ import ( "context" "errors" "math/big" + "os" "testing" "time" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1/contract" "github.com/NethermindEth/juno/mocks" @@ -113,6 +115,11 @@ var longSequenceOfBlocks = []*l1Block{ }, } +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestClient(t *testing.T) { t.Parallel() @@ -336,7 +343,7 @@ func TestClient(t *testing.T) { ctrl := gomock.NewController(t) nopLog := utils.NewNopZapLogger() network := utils.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New(memory.New(), &network, statetestutils.UseNewState()) client := NewClient(nil, chain, nopLog).WithResubscribeDelay(0).WithPollFinalisedInterval(time.Nanosecond) @@ -384,7 +391,7 @@ func TestClient(t *testing.T) { BlockHash: block.expectedL2BlockHash, StateRoot: block.expectedL2BlockHash, } - assert.Equal(t, want, got) + assert.Equal(t, want, &got) } } }) @@ -397,7 +404,7 @@ func TestUnreliableSubscription(t *testing.T) { ctrl := gomock.NewController(t) nopLog := utils.NewNopZapLogger() network := utils.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New(memory.New(), &network, statetestutils.UseNewState()) client := NewClient(nil, chain, nopLog).WithResubscribeDelay(0).WithPollFinalisedInterval(time.Nanosecond) err := errors.New("test err") @@ -462,7 +469,7 @@ func TestUnreliableSubscription(t *testing.T) { BlockHash: block.expectedL2BlockHash, StateRoot: block.expectedL2BlockHash, } - assert.Equal(t, want, got) + assert.Equal(t, want, &got) } } } diff --git a/l1/l1_test.go b/l1/l1_test.go index 0c4c94d64d..807ae040ba 100644 --- a/l1/l1_test.go +++ b/l1/l1_test.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1" "github.com/NethermindEth/juno/l1/contract" @@ -54,7 +55,7 @@ func TestFailToCreateSubscription(t *testing.T) { network := utils.Mainnet ctrl := gomock.NewController(t) nopLog := utils.NewNopZapLogger() - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New(memory.New(), &network, statetestutils.UseNewState()) subscriber := mocks.NewMockSubscriber(ctrl) @@ -85,7 +86,7 @@ func TestMismatchedChainID(t *testing.T) { network := utils.Mainnet ctrl := gomock.NewController(t) nopLog := utils.NewNopZapLogger() - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New(memory.New(), &network, statetestutils.UseNewState()) subscriber := mocks.NewMockSubscriber(ctrl) @@ -110,7 +111,7 @@ func TestEventListener(t *testing.T) { ctrl := gomock.NewController(t) nopLog := utils.NewNopZapLogger() network := utils.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New(memory.New(), &network, statetestutils.UseNewState()) subscriber := mocks.NewMockSubscriber(ctrl) subscriber. diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 2c08048aa6..1ac6885274 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" _ "github.com/NethermindEth/juno/encoder/registry" @@ -40,6 +41,11 @@ func setupDatabase(dbPath string, dltExisting bool) (db.KeyValueStore, func(), e return persistentPool, closer, nil } +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestMempool(t *testing.T) { testDB, dbCloser, err := setupDatabase("testmempool", true) log := utils.NewNopZapLogger() @@ -62,7 +68,7 @@ func TestMempool(t *testing.T) { for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) chain.EXPECT().HeadState().Return(state, func() error { return nil }, nil) - state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) + state.EXPECT().ContractNonce(senderAddress).Return(felt.Zero, nil) require.NoError(t, pool.Push(t.Context(), &mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -85,7 +91,7 @@ func TestMempool(t *testing.T) { for i := uint64(4); i < 6; i++ { senderAddress := new(felt.Felt).SetUint64(i) chain.EXPECT().HeadState().Return(state, func() error { return nil }, nil) - state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) + state.EXPECT().ContractNonce(senderAddress).Return(felt.Zero, nil) require.NoError(t, pool.Push(t.Context(), &mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -139,7 +145,7 @@ func TestRestoreMempool(t *testing.T) { for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) chain.EXPECT().HeadState().Return(state, func() error { return nil }, nil) - state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + state.EXPECT().ContractNonce(senderAddress).Return(felt.Zero, nil) tx := mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -194,7 +200,7 @@ func TestWait(t *testing.T) { defer dbCloser() mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) - bc := blockchain.New(testDB, &utils.Sepolia) + bc := blockchain.New(testDB, &utils.Sepolia, statetestutils.UseNewState()) block0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) stateUpdate0, err := gw.StateUpdate(t.Context(), 0) @@ -262,7 +268,7 @@ func TestPopBatch(t *testing.T) { for i := start; i <= end; i++ { senderAddress := new(felt.Felt).SetUint64(i) chain.EXPECT().HeadState().Return(state, func() error { return nil }, nil) - state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) + state.EXPECT().ContractNonce(senderAddress).Return(felt.Zero, nil) require.NoError(t, pool.Push(t.Context(), &mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index d47c62f3d6..9fcea5dda0 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -7,12 +7,14 @@ import ( "encoding/json" "errors" "math/rand" + "os" "testing" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" @@ -25,6 +27,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestMigration0000(t *testing.T) { testDB := memory.New() @@ -85,7 +92,7 @@ func TestRelocateContractStorageRootKeys(t *testing.T) { func TestRecalculateBloomFilters(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -187,7 +194,7 @@ func TestChangeTrieNodeEncoding(t *testing.T) { func TestCalculateBlockCommitments(t *testing.T) { testdb := memory.New() - chain := blockchain.New(testdb, &utils.Mainnet) + chain := blockchain.New(testdb, &utils.Mainnet, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) @@ -211,7 +218,7 @@ func TestCalculateBlockCommitments(t *testing.T) { func TestL1HandlerTxns(t *testing.T) { testdb := memory.New() - chain := blockchain.New(testdb, &utils.Sepolia) + chain := blockchain.New(testdb, &utils.Sepolia, statetestutils.UseNewState()) client := feeder.NewTestClient(t, &utils.Sepolia) gw := adaptfeeder.New(client) diff --git a/node/node_test.go b/node/node_test.go index 58da3dc2b3..eeab1b1e5f 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -2,11 +2,13 @@ package node_test import ( "context" + "os" "testing" "time" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/node" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -15,6 +17,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + // Create a new node with all services enabled. func TestNewNode(t *testing.T) { config := &node.Config{ @@ -72,7 +79,7 @@ func TestNetworkVerificationOnNonEmptyDB(t *testing.T) { log := utils.NewNopZapLogger() database, err := pebble.New(dbPath) require.NoError(t, err) - chain := blockchain.New(database, &network) + chain := blockchain.New(database, &network, statetestutils.UseNewState()) ctx, cancel := context.WithCancel(t.Context()) dataSource := sync.NewFeederGatewayDataSource(chain, adaptfeeder.New(feeder.NewTestClient(t, &network))) syncer := sync.New(chain, dataSource, log, 0, 0, false, database). diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index 9509a73717..8d5bfc89f7 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -2,11 +2,13 @@ package plugin_test import ( "context" + "os" "testing" "time" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" junoplugin "github.com/NethermindEth/juno/plugin" @@ -17,6 +19,11 @@ import ( "go.uber.org/mock/gomock" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestPlugin(t *testing.T) { timeout := time.Second mockCtrl := gomock.NewController(t) @@ -38,7 +45,7 @@ func TestPlugin(t *testing.T) { require.NoError(t, err) plugin.EXPECT().NewBlock(block, su, gomock.Any()) } - bc := blockchain.New(testDB, &utils.Integration) + bc := blockchain.New(testDB, &utils.Integration, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, integGw) synchronizer := sync.New(bc, dataSource, utils.NewNopZapLogger(), 0, 0, false, nil).WithPlugin(plugin) @@ -48,7 +55,7 @@ func TestPlugin(t *testing.T) { require.NoError(t, bc.Stop()) t.Run("resync to mainnet with the same db", func(t *testing.T) { - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) // Ensure current head is Integration head head, err := bc.HeadsHeader() diff --git a/rpc/v6/block_test.go b/rpc/v6/block_test.go index ccbf5f3a40..4e00a8354f 100644 --- a/rpc/v6/block_test.go +++ b/rpc/v6/block_test.go @@ -2,12 +2,14 @@ package rpcv6_test import ( "errors" + "os" "testing" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -21,6 +23,11 @@ import ( "go.uber.org/mock/gomock" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestBlockId(t *testing.T) { t.Parallel() tests := map[string]struct { @@ -273,7 +280,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { //nolint:goconst mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PendingData().Return(nil, sync.ErrPendingBlockNotFound) @@ -323,7 +330,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -333,7 +340,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Hash: latestBlockHash}) require.Nil(t, rpcErr) @@ -343,7 +350,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Number: latestBlockNumber}) require.Nil(t, rpcErr) @@ -353,7 +360,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -374,7 +381,7 @@ func TestBlockWithTxHashes(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Pending: true}) require.Nil(t, rpcErr) checkLatestBlock(t, block) @@ -390,7 +397,7 @@ func TestBlockWithTxHashes(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) blockWTxHashes, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -429,7 +436,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PendingData().Return(nil, sync.ErrPendingBlockNotFound) @@ -480,7 +487,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -493,7 +500,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Hash: latestBlockHash}) require.Nil(t, rpcErr) @@ -506,7 +513,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Number: latestBlockNumber}) require.Nil(t, rpcErr) @@ -522,7 +529,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -548,7 +555,7 @@ func TestBlockWithTxs(t *testing.T) { &pending, nil, ).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpc.BlockID{Pending: true}) require.Nil(t, rpcErr) blockWithTxs, rpcErr := handler.BlockWithTxs(rpc.BlockID{Pending: true}) @@ -567,7 +574,7 @@ func TestBlockWithTxs(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) blockWithTxs, rpcErr := handler.BlockWithTxs(rpc.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -603,7 +610,7 @@ func TestBlockWithTxHashesV013(t *testing.T) { require.True(t, ok) mockReader.EXPECT().BlockByNumber(gomock.Any()).Return(coreBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) got, rpcErr := handler.BlockWithTxs(rpc.BlockID{Number: blockNumber}) require.Nil(t, rpcErr) got.Transactions = got.Transactions[:1] @@ -681,7 +688,7 @@ func TestBlockWithReceipts(t *testing.T) { err := errors.New("l1 failure") mockReader.EXPECT().BlockByNumber(blockID.Number).Return(block, nil) - mockReader.EXPECT().L1Head().Return(nil, err) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, err) resp, rpcErr := handler.BlockWithReceipts(blockID) assert.Nil(t, resp) @@ -714,7 +721,7 @@ func TestBlockWithReceipts(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) resp, rpcErr := handler.BlockWithReceipts(blockID) header := resp.BlockHeader @@ -743,7 +750,7 @@ func TestBlockWithReceipts(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(block0.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) resp, rpcErr := handler.BlockWithReceipts(blockID) header := resp.BlockHeader @@ -769,7 +776,7 @@ func TestBlockWithReceipts(t *testing.T) { blockID := rpc.BlockID{Number: block1.Number} mockReader.EXPECT().BlockByNumber(blockID.Number).Return(block1, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block1.Number + 1, }, nil) @@ -822,7 +829,7 @@ func TestRpcBlockAdaptation(t *testing.T) { require.NoError(t, err) latestBlock.Header.SequencerAddress = nil mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) block, rpcErr := handler.BlockWithTxs(rpc.BlockID{Latest: true}) require.NoError(t, err, rpcErr) diff --git a/rpc/v6/class_test.go b/rpc/v6/class_test.go index 9e4becdd3c..d47ecd1646 100644 --- a/rpc/v6/class_test.go +++ b/rpc/v6/class_test.go @@ -107,11 +107,11 @@ func TestClassAt(t *testing.T) { cairo0ContractAddress, _ := new(felt.Felt).SetRandom() cairo0ClassHash := utils.HexToFelt(t, "0x4631b6b3fa31e140524b7d21ba784cea223e618bffe60b5bbdca44a8b45be04") - mockState.EXPECT().ContractClassHash(cairo0ContractAddress).Return(cairo0ClassHash, nil) + mockState.EXPECT().ContractClassHash(cairo0ContractAddress).Return(*cairo0ClassHash, nil) cairo1ContractAddress, _ := new(felt.Felt).SetRandom() cairo1ClassHash := utils.HexToFelt(t, "0x1cd2edfb485241c4403254d550de0a097fa76743cd30696f714a491a454bad5") - mockState.EXPECT().ContractClassHash(cairo1ContractAddress).Return(cairo1ClassHash, nil) + mockState.EXPECT().ContractClassHash(cairo1ContractAddress).Return(*cairo1ClassHash, nil) mockState.EXPECT().Class(gomock.Any()).DoAndReturn(func(classHash *felt.Felt) (*core.DeclaredClass, error) { class, err := integGw.Class(t.Context(), classHash) @@ -185,7 +185,7 @@ func TestClassHashAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, errors.New("non-existent contract")) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, errors.New("non-existent contract")) classHash, rpcErr := handler.ClassHashAt(rpc.BlockID{Latest: true}, felt.Zero) require.Nil(t, classHash) @@ -196,7 +196,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) classHash, rpcErr := handler.ClassHashAt(rpc.BlockID{Latest: true}, felt.Zero) require.Nil(t, rpcErr) @@ -205,7 +205,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) classHash, rpcErr := handler.ClassHashAt(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) require.Nil(t, rpcErr) @@ -214,7 +214,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) classHash, rpcErr := handler.ClassHashAt(rpc.BlockID{Number: 0}, felt.Zero) require.Nil(t, rpcErr) @@ -225,7 +225,7 @@ func TestClassHashAt(t *testing.T) { pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) classHash, rpcErr := handler.ClassHashAt(rpc.BlockID{Pending: true}, felt.Zero) require.Nil(t, rpcErr) diff --git a/rpc/v6/contract_test.go b/rpc/v6/contract_test.go index 07ae7a2937..4a40d14114 100644 --- a/rpc/v6/contract_test.go +++ b/rpc/v6/contract_test.go @@ -54,7 +54,7 @@ func TestNonce(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(nil, errors.New("non-existent contract")) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(felt.Zero, errors.New("non-existent contract")) nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) require.Nil(t, nonce) @@ -65,7 +65,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) require.Nil(t, rpcErr) @@ -74,7 +74,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) nonce, rpcErr := handler.Nonce(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) require.Nil(t, rpcErr) @@ -83,7 +83,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) nonce, rpcErr := handler.Nonce(rpc.BlockID{Number: 0}, felt.Zero) require.Nil(t, rpcErr) @@ -94,7 +94,7 @@ func TestNonce(t *testing.T) { pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) nonce, rpcErr := handler.Nonce(rpc.BlockID{Pending: true}, felt.Zero) require.Nil(t, rpcErr) @@ -140,7 +140,7 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, db.ErrKeyNotFound) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) require.Nil(t, storageValue) @@ -149,7 +149,7 @@ func TestStorageAt(t *testing.T) { t.Run("internal error while retrieving contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, errors.New("some internal error")) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, errors.New("some internal error")) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) require.Nil(t, storageValue) @@ -158,8 +158,8 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) require.Equal(t, storageValue, &felt.Zero) @@ -168,8 +168,8 @@ func TestStorageAt(t *testing.T) { t.Run("internal error while retrieving key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, errors.New("some internal error")) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, errors.New("some internal error")) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) require.Nil(t, storageValue) @@ -180,8 +180,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -190,8 +190,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) require.Nil(t, rpcErr) @@ -200,8 +200,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) require.Nil(t, rpcErr) @@ -212,8 +212,8 @@ func TestStorageAt(t *testing.T) { pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Pending: true}) require.Nil(t, rpcErr) diff --git a/rpc/v6/events_test.go b/rpc/v6/events_test.go index f85ef2c5d7..1895bc3a99 100644 --- a/rpc/v6/events_test.go +++ b/rpc/v6/events_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -24,7 +25,7 @@ func TestEvents(t *testing.T) { testDB := memory.New() n := &utils.Sepolia - chain := blockchain.New(testDB, n) + chain := blockchain.New(testDB, n, statetestutils.UseNewState()) mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) diff --git a/rpc/v6/handlers_test.go b/rpc/v6/handlers_test.go index dc964a7e66..5e4054c9e6 100644 --- a/rpc/v6/handlers_test.go +++ b/rpc/v6/handlers_test.go @@ -40,7 +40,7 @@ func TestThrottledVMError(t *testing.T) { t.Run("call", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(new(felt.Felt), nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().Class(new(felt.Felt)).Return(&core.DeclaredClass{Class: &core.Cairo1Class{ Program: []*felt.Felt{ new(felt.Felt), diff --git a/rpc/v6/state_update_test.go b/rpc/v6/state_update_test.go index 174740883e..ebd43257ab 100644 --- a/rpc/v6/state_update_test.go +++ b/rpc/v6/state_update_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -33,7 +34,7 @@ func TestStateUpdate(t *testing.T) { n := &utils.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PendingData().Return(nil, sync.ErrPendingBlockNotFound) diff --git a/rpc/v6/trace_test.go b/rpc/v6/trace_test.go index 89b80b0204..ff46e23d10 100644 --- a/rpc/v6/trace_test.go +++ b/rpc/v6/trace_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -84,7 +85,7 @@ func AssertTracedBlockTransactions(t *testing.T, n *utils.Network, tests map[str mockReader.EXPECT().BlockByNumber(gomock.Any()).DoAndReturn(func(number uint64) (block *core.Block, err error) { return gateway.BlockByNumber(t.Context(), number) }).AnyTimes() - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() for description, test := range tests { t.Run(description, func(t *testing.T) { @@ -125,7 +126,7 @@ func TestTraceBlockTransactionsReturnsError(t *testing.T) { mockReader.EXPECT().BlockByHash(gomock.Any()).DoAndReturn(func(_ *felt.Felt) (block *core.Block, err error) { return mockReader.BlockByNumber(blockNumber) }) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() // No feeder client is set handler := rpc.New(mockReader, nil, nil, n, nil) @@ -392,7 +393,7 @@ func TestTraceTransaction(t *testing.T) { return gateway.BlockByNumber(t.Context(), blockNumber) }).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: 19, // Doesn't really matter for this test }, nil) @@ -507,7 +508,7 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) handler := rpc.New(chain, nil, nil, n, log) update, rpcErr := handler.TraceBlockTransactions(t.Context(), id) @@ -1121,7 +1122,7 @@ func TestCall(t *testing.T) { t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, errors.New("unknown contract")) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, errors.New("unknown contract")) res, rpcErr := handler.Call(&rpc.FunctionCall{}, &rpc.BlockID{Latest: true}) require.Nil(t, res) @@ -1158,7 +1159,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1213,7 +1214,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), false).Return(expectedRes, nil) @@ -1254,7 +1255,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) @@ -1294,7 +1295,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) diff --git a/rpc/v6/transaction_test.go b/rpc/v6/transaction_test.go index 2dc3aff8c8..2d132ba023 100644 --- a/rpc/v6/transaction_test.go +++ b/rpc/v6/transaction_test.go @@ -706,7 +706,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(tx0HashInBlock4850).Return(block.Transactions[0], nil) mockReader.EXPECT().Receipt(tx0HashInBlock4850).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, errors.New("some internal error")) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, errors.New("some internal error")) txReceipt, rpcErr := handler.TransactionReceiptByHash(*tx0HashInBlock4850) @@ -795,7 +795,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(tx0HashInBlock4850).Return(block.Transactions[0], nil) mockReader.EXPECT().Receipt(tx0HashInBlock4850).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{BlockNumber: 4851}, nil) // block number not very important here + mockReader.EXPECT().L1Head().Return(core.L1Head{BlockNumber: 4851}, nil) // block number not very important here txReceipt, rpcErr := handler.TransactionReceiptByHash(*tx0HashInBlock4850) @@ -994,7 +994,7 @@ func TestLegacyTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[i].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[i], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[i], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block0.Number, BlockHash: block0.Hash, StateRoot: block0.GlobalStateRoot, @@ -1029,7 +1029,7 @@ func TestLegacyTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(revertedTxnHash).Return(blockWithRevertedTxn.Transactions[revertedTxnIdx], nil) mockReader.EXPECT().Receipt(revertedTxnHash).Return(blockWithRevertedTxn.Receipts[revertedTxnIdx], blockWithRevertedTxn.Hash, blockWithRevertedTxn.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, revertedTxnHash, expected) }) @@ -1086,7 +1086,7 @@ func TestLegacyTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -1514,7 +1514,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) diff --git a/rpc/v7/block_test.go b/rpc/v7/block_test.go index 6a5fc4088d..ce0dbd5b99 100644 --- a/rpc/v7/block_test.go +++ b/rpc/v7/block_test.go @@ -2,12 +2,14 @@ package rpcv7_test import ( "errors" + "os" "testing" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -22,6 +24,11 @@ import ( "go.uber.org/mock/gomock" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestBlockId(t *testing.T) { t.Parallel() tests := map[string]struct { @@ -107,7 +114,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { //nolint:goconst mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -159,7 +166,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -169,7 +176,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Hash: latestBlockHash}) require.Nil(t, rpcErr) @@ -179,7 +186,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Number: latestBlockNumber}) require.Nil(t, rpcErr) @@ -189,7 +196,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -210,7 +217,7 @@ func TestBlockWithTxHashes(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) block, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -227,7 +234,7 @@ func TestBlockWithTxHashes(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) blockWTxHashes, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -269,7 +276,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -327,7 +334,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -340,7 +347,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Hash: latestBlockHash}) require.Nil(t, rpcErr) @@ -353,7 +360,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Number: latestBlockNumber}) require.Nil(t, rpcErr) @@ -369,7 +376,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -395,7 +402,7 @@ func TestBlockWithTxs(t *testing.T) { &pending, nil, ).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(rpcv7.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -416,7 +423,7 @@ func TestBlockWithTxs(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) blockWithTxs, rpcErr := handler.BlockWithTxs(rpcv7.BlockID{Pending: true}) require.Nil(t, rpcErr) @@ -454,7 +461,7 @@ func TestBlockWithTxHashesV013(t *testing.T) { require.True(t, ok) mockReader.EXPECT().BlockByNumber(gomock.Any()).Return(coreBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) got, rpcErr := handler.BlockWithTxs(rpcv7.BlockID{Number: blockNumber}) require.Nil(t, rpcErr) got.Transactions = got.Transactions[:1] @@ -537,7 +544,7 @@ func TestBlockWithReceipts(t *testing.T) { err := errors.New("l1 failure") mockReader.EXPECT().BlockByNumber(blockID.Number).Return(block, nil) - mockReader.EXPECT().L1Head().Return(nil, err) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, err) resp, rpcErr := handler.BlockWithReceipts(blockID) assert.Nil(t, resp) @@ -556,7 +563,7 @@ func TestBlockWithReceipts(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) resp, rpcErr := handler.BlockWithReceipts(rpcv7.BlockID{Pending: true}) header := resp.BlockHeader @@ -603,7 +610,7 @@ func TestBlockWithReceipts(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(block0.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) resp, rpcErr := handler.BlockWithReceipts(rpcv7.BlockID{Pending: true}) header := resp.BlockHeader @@ -630,7 +637,7 @@ func TestBlockWithReceipts(t *testing.T) { blockID := rpcv7.BlockID{Number: block1.Number} mockReader.EXPECT().BlockByNumber(blockID.Number).Return(block1, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block1.Number + 1, }, nil) @@ -685,7 +692,7 @@ func TestRpcBlockAdaptation(t *testing.T) { require.NoError(t, err) latestBlock.Header.SequencerAddress = nil mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) block, rpcErr := handler.BlockWithTxs(rpcv7.BlockID{Latest: true}) require.NoError(t, err, rpcErr) diff --git a/rpc/v7/handlers_test.go b/rpc/v7/handlers_test.go index 67d309e7e3..b947326631 100644 --- a/rpc/v7/handlers_test.go +++ b/rpc/v7/handlers_test.go @@ -41,7 +41,7 @@ func TestThrottledVMError(t *testing.T) { t.Run("call", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(new(felt.Felt), nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().Class(new(felt.Felt)).Return(&core.DeclaredClass{Class: &core.Cairo1Class{ Program: []*felt.Felt{ new(felt.Felt), diff --git a/rpc/v7/storage_test.go b/rpc/v7/storage_test.go index 9a04a251d3..fe0cf0babc 100644 --- a/rpc/v7/storage_test.go +++ b/rpc/v7/storage_test.go @@ -53,7 +53,7 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, db.ErrKeyNotFound) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Latest: true}) require.Nil(t, storageValue) @@ -62,8 +62,8 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Latest: true}) require.Equal(t, storageValue, &felt.Zero) @@ -72,8 +72,8 @@ func TestStorageAt(t *testing.T) { t.Run("internal error while retrieving key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, errors.New("some internal error")) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, errors.New("some internal error")) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Latest: true}) require.Nil(t, storageValue) @@ -84,8 +84,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Latest: true}) require.Nil(t, rpcErr) @@ -94,8 +94,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Hash: &felt.Zero}) require.Nil(t, rpcErr) @@ -104,8 +104,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Number: 0}) require.Nil(t, rpcErr) @@ -116,8 +116,8 @@ func TestStorageAt(t *testing.T) { pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) storageValue, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpcv7.BlockID{Pending: true}) require.Nil(t, rpcErr) diff --git a/rpc/v7/trace_test.go b/rpc/v7/trace_test.go index 388bb4a50b..035559f527 100644 --- a/rpc/v7/trace_test.go +++ b/rpc/v7/trace_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -95,7 +96,7 @@ func AssertTracedBlockTransactions(t *testing.T, n *utils.Network, tests map[str return }).AnyTimes() - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() for description, test := range tests { t.Run(description, func(t *testing.T) { @@ -138,7 +139,7 @@ func TestTraceBlockTransactionsReturnsError(t *testing.T) { mockReader.EXPECT().BlockByHash(gomock.Any()).DoAndReturn(func(_ *felt.Felt) (block *core.Block, err error) { return mockReader.BlockByNumber(blockNumber) }) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() // No feeder client is set handler := rpcv7.New(mockReader, nil, nil, n, nil) @@ -498,7 +499,7 @@ func TestTraceTransaction(t *testing.T) { return gateway.BlockByNumber(t.Context(), blockNumber) }).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: 19, // Doesn't really matter for this test }, nil) @@ -622,7 +623,7 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) handler := rpcv7.New(chain, nil, nil, n, log) if description == "pending" { @@ -1507,7 +1508,7 @@ func TestCall(t *testing.T) { t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, errors.New("unknown contract")) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, errors.New("unknown contract")) res, rpcErr := handler.Call(rpcv7.FunctionCall{}, rpcv7.BlockID{Latest: true}) require.Nil(t, res) @@ -1546,7 +1547,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1601,7 +1602,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1647,7 +1648,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) @@ -1687,7 +1688,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) diff --git a/rpc/v7/transaction_test.go b/rpc/v7/transaction_test.go index 59122b1dbe..1408a1dcc3 100644 --- a/rpc/v7/transaction_test.go +++ b/rpc/v7/transaction_test.go @@ -337,7 +337,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[test.index].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[test.index], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[test.index], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txHash, test.expected) }) @@ -403,7 +403,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[i].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[i], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[i], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block0.Number, BlockHash: block0.Hash, StateRoot: block0.GlobalStateRoot, @@ -438,7 +438,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(revertedTxnHash).Return(blockWithRevertedTxn.Transactions[revertedTxnIdx], nil) mockReader.EXPECT().Receipt(revertedTxnHash).Return(blockWithRevertedTxn.Receipts[revertedTxnIdx], blockWithRevertedTxn.Hash, blockWithRevertedTxn.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, revertedTxnHash, expected) }) @@ -506,7 +506,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -888,7 +888,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) handler := rpc.New(mockReader, nil, nil, test.network, nil) @@ -904,7 +904,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) diff --git a/rpc/v8/block_test.go b/rpc/v8/block_test.go index 78d05ed33d..e6064222a2 100644 --- a/rpc/v8/block_test.go +++ b/rpc/v8/block_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -108,7 +109,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { //nolint:goconst mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -160,7 +161,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) latest := blockIDLatest(t) block, rpcErr := handler.BlockWithTxHashes(&latest) @@ -171,7 +172,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) hash := blockIDHash(t, latestBlockHash) block, rpcErr := handler.BlockWithTxHashes(&hash) @@ -182,7 +183,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) number := blockIDNumber(t, latestBlockNumber) block, rpcErr := handler.BlockWithTxHashes(&number) @@ -193,7 +194,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -215,7 +216,7 @@ func TestBlockWithTxHashes(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) pendingID := blockIDPending(t) block, rpcErr := handler.BlockWithTxHashes(&pendingID) @@ -233,7 +234,7 @@ func TestBlockWithTxHashes(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) pending := blockIDPending(t) blockWTxHashes, rpcErr := handler.BlockWithTxHashes(&pending) require.Nil(t, rpcErr) @@ -278,7 +279,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pending" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -331,7 +332,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) latest := blockIDLatest(t) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&latest) @@ -345,7 +346,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) hash := blockIDHash(t, latestBlockHash) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&hash) @@ -359,7 +360,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) number := blockIDNumber(t, latestBlockNumber) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&number) @@ -376,7 +377,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -403,7 +404,7 @@ func TestBlockWithTxs(t *testing.T) { &pending, nil, ).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) pendingID := blockIDPending(t) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&pendingID) @@ -425,7 +426,7 @@ func TestBlockWithTxs(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(latestBlock.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) pending := blockIDPending(t) blockWithTxs, rpcErr := handler.BlockWithTxs(&pending) require.Nil(t, rpcErr) @@ -466,7 +467,7 @@ func TestBlockWithTxHashesV013(t *testing.T) { require.True(t, ok) mockReader.EXPECT().BlockByNumber(gomock.Any()).Return(coreBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) blockID := blockIDNumber(t, blockNumber) got, rpcErr := handler.BlockWithTxs(&blockID) @@ -561,7 +562,7 @@ func TestBlockWithReceipts(t *testing.T) { err := errors.New("l1 failure") mockReader.EXPECT().BlockByNumber(blockID.Number()).Return(block, nil) - mockReader.EXPECT().L1Head().Return(nil, err) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, err) resp, rpcErr := handler.BlockWithReceipts(&blockID) assert.Nil(t, resp) @@ -580,7 +581,7 @@ func TestBlockWithReceipts(t *testing.T) { &pending, nil, ) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) blockID := blockIDPending(t) resp, rpcErr := handler.BlockWithReceipts(&blockID) @@ -630,7 +631,7 @@ func TestBlockWithReceipts(t *testing.T) { ) mockReader.EXPECT().HeadsHeader().Return(block0.Header, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) blockID := blockIDPending(t) resp, rpcErr := handler.BlockWithReceipts(&blockID) @@ -662,7 +663,7 @@ func TestBlockWithReceipts(t *testing.T) { blockID := blockIDNumber(t, block1.Number) mockReader.EXPECT().BlockByNumber(blockID.Number()).Return(block1, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block1.Number + 1, }, nil) @@ -721,7 +722,7 @@ func TestRpcBlockAdaptation(t *testing.T) { require.NoError(t, err) latestBlock.Header.SequencerAddress = nil mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockID := blockIDLatest(t) block, rpcErr := handler.BlockWithTxs(&blockID) diff --git a/rpc/v8/handlers_test.go b/rpc/v8/handlers_test.go index 1426aff2cc..6a578246c1 100644 --- a/rpc/v8/handlers_test.go +++ b/rpc/v8/handlers_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/mocks" "github.com/NethermindEth/juno/node" rpcv6 "github.com/NethermindEth/juno/rpc/v6" - rpcv8 "github.com/NethermindEth/juno/rpc/v8" + rpc "github.com/NethermindEth/juno/rpc/v8" "github.com/NethermindEth/juno/sync" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" @@ -19,7 +19,7 @@ import ( func nopCloser() error { return nil } func TestSpecVersion(t *testing.T) { - handler := rpcv8.New(nil, nil, nil, nil) + handler := rpc.New(nil, nil, nil, nil) version, rpcErr := handler.SpecVersion() require.Nil(t, rpcErr) require.Equal(t, "0.8.1", version) @@ -34,14 +34,14 @@ func TestThrottledVMError(t *testing.T) { mockVM := mocks.NewMockVM(mockCtrl) throttledVM := node.NewThrottledVM(mockVM, 0, 0) - handler := rpcv8.New(mockReader, mockSyncReader, throttledVM, nil) + handler := rpc.New(mockReader, mockSyncReader, throttledVM, nil) mockState := mocks.NewMockStateReader(mockCtrl) throttledErr := "VM throughput limit reached" t.Run("call", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(new(felt.Felt), nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().Class(new(felt.Felt)).Return(&core.DeclaredClass{Class: &core.Cairo1Class{ Program: []*felt.Felt{ new(felt.Felt).SetUint64(3), @@ -51,7 +51,7 @@ func TestThrottledVMError(t *testing.T) { }}, nil) blockID := blockIDLatest(t) - _, rpcErr := handler.Call(&rpcv8.FunctionCall{}, &blockID) + _, rpcErr := handler.Call(&rpc.FunctionCall{}, &blockID) assert.Equal(t, throttledErr, rpcErr.Data) }) @@ -60,9 +60,9 @@ func TestThrottledVMError(t *testing.T) { mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil) blockID := blockIDLatest(t) - _, httpHeader, rpcErr := handler.SimulateTransactions(&blockID, []rpcv8.BroadcastedTransaction{}, []rpcv6.SimulationFlag{rpcv6.SkipFeeChargeFlag}) + _, httpHeader, rpcErr := handler.SimulateTransactions(&blockID, []rpc.BroadcastedTransaction{}, []rpcv6.SimulationFlag{rpcv6.SkipFeeChargeFlag}) assert.Equal(t, throttledErr, rpcErr.Data) - assert.NotEmpty(t, httpHeader.Get(rpcv8.ExecutionStepsHeader)) + assert.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) }) t.Run("trace", func(t *testing.T) { @@ -102,6 +102,6 @@ func TestThrottledVMError(t *testing.T) { blockID := blockIDHash(t, blockHash) _, httpHeader, rpcErr := handler.TraceBlockTransactions(t.Context(), &blockID) assert.Equal(t, throttledErr, rpcErr.Data) - assert.NotEmpty(t, httpHeader.Get(rpcv8.ExecutionStepsHeader)) + assert.NotEmpty(t, httpHeader.Get(rpc.ExecutionStepsHeader)) }) } diff --git a/rpc/v8/l1_test.go b/rpc/v8/l1_test.go index 039c83fb84..27941e1bba 100644 --- a/rpc/v8/l1_test.go +++ b/rpc/v8/l1_test.go @@ -85,11 +85,11 @@ func TestGetMessageStatus(t *testing.T) { mockSubscriber.EXPECT().TransactionReceipt(gomock.Any(), gomock.Any()).Return(&test.l1TxnReceipt, nil) for i, msg := range test.msgs { - mockReader.EXPECT().L1HandlerTxnHash(&test.msgHashes[i]).Return(msg.L1HandlerHash, nil) + mockReader.EXPECT().L1HandlerTxnHash(&test.msgHashes[i]).Return(*msg.L1HandlerHash, nil) // Expects for h.TransactionStatus() mockReader.EXPECT().TransactionByHash(msg.L1HandlerHash).Return(l1handlerTxns[i], nil) mockReader.EXPECT().Receipt(msg.L1HandlerHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{BlockNumber: uint64(test.l1HeadBlockNum)}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{BlockNumber: uint64(test.l1HeadBlockNum)}, nil) } msgStatuses, rpcErr := handler.GetMessageStatus(t.Context(), &test.l1TxnHash) require.Nil(t, rpcErr) diff --git a/rpc/v8/simulation_pkg_test.go b/rpc/v8/simulation_pkg_test.go index 5de7b08e81..b39e2efd3b 100644 --- a/rpc/v8/simulation_pkg_test.go +++ b/rpc/v8/simulation_pkg_test.go @@ -3,10 +3,12 @@ package rpcv8 import ( "encoding/json" "errors" + "os" "testing" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -14,6 +16,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + //nolint:dupl func TestCreateSimulatedTransactions(t *testing.T) { executionResults := vm.ExecutionResults{ diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 2f4373f64a..34c4c5e0bf 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" @@ -66,7 +67,7 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, db.ErrKeyNotFound) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -76,8 +77,8 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, nil) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -87,9 +88,9 @@ func TestStorageAt(t *testing.T) { t.Run("internal error while retrieving key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()). - Return(nil, errors.New("some internal error")) + Return(felt.Zero, errors.New("some internal error")) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -101,8 +102,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -112,8 +113,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDHash(t, &felt.Zero) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -123,8 +124,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDNumber(t, 0) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -136,8 +137,8 @@ func TestStorageAt(t *testing.T) { pending := sync.NewPending(nil, nil, nil) mockSyncReader.EXPECT().PendingData().Return(&pending, nil) mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) pendingID := blockIDPending(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &pendingID) require.Nil(t, rpcErr) @@ -604,7 +605,7 @@ func TestStorageProof_StorageRoots(t *testing.T) { log := utils.NewNopZapLogger() testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, log, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), time.Second) @@ -648,8 +649,8 @@ func TestStorageProof_StorageRoots(t *testing.T) { stgRoot, err := contractTrie.Hash() assert.NoError(t, err) - assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) - assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) + assert.Equal(t, *expectedClsRoot, clsRoot, clsRoot.String()) + assert.Equal(t, *expectedStgRoot, stgRoot, stgRoot.String()) verifyGlobalStateRoot(t, expectedGlobalRoot, &clsRoot, &stgRoot) }) @@ -664,11 +665,11 @@ func TestStorageProof_StorageRoots(t *testing.T) { leaf, err := contractTrie.Get(expectedContractAddress) assert.NoError(t, err) - assert.Equal(t, leaf, expectedContractLeaf, leaf.String()) + assert.Equal(t, leaf, *expectedContractLeaf, leaf.String()) clsHash, err := stateReader.ContractClassHash(expectedContractAddress) assert.NoError(t, err) - assert.Equal(t, clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) + assert.Equal(t, &clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) }) t.Run("get contract proof", func(t *testing.T) { diff --git a/rpc/v8/subscriptions_test.go b/rpc/v8/subscriptions_test.go index ea6dc91ce5..80d3827c8f 100644 --- a/rpc/v8/subscriptions_test.go +++ b/rpc/v8/subscriptions_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -345,7 +346,7 @@ func TestSubscribeTxnStatus(t *testing.T) { mockChain.EXPECT().TransactionByHash(txHash).Return(block.Transactions[0], nil) mockChain.EXPECT().Receipt(txHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockChain.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockChain.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) for i := range 3 { handler.pendingData.Send(&sync.Pending{Block: &core.Block{Header: &core.Header{}}}) handler.pendingData.Send(&sync.Pending{Block: &core.Block{Header: &core.Header{}}}) @@ -356,7 +357,7 @@ func TestSubscribeTxnStatus(t *testing.T) { l1Head := &core.L1Head{BlockNumber: block.Number} mockChain.EXPECT().TransactionByHash(txHash).Return(block.Transactions[0], nil) mockChain.EXPECT().Receipt(txHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockChain.EXPECT().L1Head().Return(l1Head, nil) + mockChain.EXPECT().L1Head().Return(*l1Head, nil) handler.l1Heads.Send(l1Head) assertNextTxnStatus(t, conn, id, txHash, TxnStatusAcceptedOnL1, TxnSuccess, "") }) @@ -497,10 +498,10 @@ func TestSubscribeNewHeadsHistorical(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &utils.Mainnet) + chain = blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) syncer := newFakeSyncer() ctx, cancel := context.WithCancel(t.Context()) diff --git a/rpc/v8/trace_test.go b/rpc/v8/trace_test.go index 6fa1c8e7ee..548d7e2b6e 100644 --- a/rpc/v8/trace_test.go +++ b/rpc/v8/trace_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -96,7 +97,7 @@ func AssertTracedBlockTransactions(t *testing.T, n *utils.Network, tests map[str return }).AnyTimes() - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() for description, test := range tests { t.Run(description, func(t *testing.T) { @@ -138,7 +139,7 @@ func TestTraceBlockTransactionsReturnsError(t *testing.T) { mockReader.EXPECT().BlockByHash(gomock.Any()).DoAndReturn(func(_ *felt.Felt) (block *core.Block, err error) { return mockReader.BlockByNumber(blockNumber) }) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() // No feeder client is set handler := rpc.New(mockReader, nil, nil, nil) @@ -471,7 +472,7 @@ func TestTraceTransaction(t *testing.T) { return gateway.BlockByNumber(t.Context(), blockNumber) }).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: 19, // Doesn't really matter for this test }, nil) @@ -586,7 +587,7 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) handler := rpc.New(chain, nil, nil, log) if description == "pending" { @@ -1274,7 +1275,7 @@ func TestCall(t *testing.T) { t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, errors.New("unknown contract")) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, errors.New("unknown contract")) blockID := blockIDLatest(t) res, rpcErr := handler.Call(&rpc.FunctionCall{}, &blockID) @@ -1312,7 +1313,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1363,7 +1364,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1411,7 +1412,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) diff --git a/rpc/v8/transaction_test.go b/rpc/v8/transaction_test.go index 43f3adf2cd..3960dce49c 100644 --- a/rpc/v8/transaction_test.go +++ b/rpc/v8/transaction_test.go @@ -708,7 +708,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[test.index].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[test.index], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[test.index], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txHash, test.expected) }) @@ -781,7 +781,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[i].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[i], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[i], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block0.Number, BlockHash: block0.Hash, StateRoot: block0.GlobalStateRoot, @@ -820,7 +820,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(revertedTxnHash).Return(blockWithRevertedTxn.Transactions[revertedTxnIdx], nil) mockReader.EXPECT().Receipt(revertedTxnHash).Return(blockWithRevertedTxn.Receipts[revertedTxnIdx], blockWithRevertedTxn.Hash, blockWithRevertedTxn.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, revertedTxnHash, expected) }) @@ -884,7 +884,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -934,7 +934,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -1363,7 +1363,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) handler := rpc.New(mockReader, nil, nil, nil) @@ -1379,7 +1379,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) @@ -1397,7 +1397,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) diff --git a/rpc/v9/block_test.go b/rpc/v9/block_test.go index cb41b51530..fb52affd3d 100644 --- a/rpc/v9/block_test.go +++ b/rpc/v9/block_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -176,7 +177,7 @@ func TestBlockTransactionCount(t *testing.T) { t.Run("blockID - l1_accepted", func(t *testing.T) { mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: latestBlock.Number, BlockHash: latestBlock.Hash, StateRoot: latestBlock.GlobalStateRoot, @@ -223,7 +224,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pre_confirmed" { //nolint:goconst mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -275,7 +276,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) latest := blockIDLatest(t) block, rpcErr := handler.BlockWithTxHashes(&latest) @@ -286,7 +287,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) hash := blockIDHash(t, latestBlockHash) block, rpcErr := handler.BlockWithTxHashes(&hash) @@ -297,7 +298,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) number := blockIDNumber(t, latestBlockNumber) block, rpcErr := handler.BlockWithTxHashes(&number) @@ -308,7 +309,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -324,7 +325,7 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run("blockID - l1_accepted", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -346,7 +347,7 @@ func TestBlockWithTxHashes(t *testing.T) { &preConfirmed, nil, ) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) preConfirmedID := blockIDPreConfirmed(t) block, rpcErr := handler.BlockWithTxHashes(&preConfirmedID) @@ -373,7 +374,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -426,7 +427,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) latest := blockIDLatest(t) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&latest) @@ -440,7 +441,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().BlockByHash(latestBlockHash).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) hash := blockIDHash(t, latestBlockHash) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&hash) @@ -454,7 +455,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) number := blockIDNumber(t, latestBlockNumber) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&number) @@ -471,7 +472,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - number accepted on l1", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -492,7 +493,7 @@ func TestBlockWithTxs(t *testing.T) { t.Run("blockID - l1_accepted", func(t *testing.T) { mockReader.EXPECT().BlockByNumber(latestBlockNumber).Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -519,7 +520,7 @@ func TestBlockWithTxs(t *testing.T) { &preConfirmed, nil, ).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) preConfirmedID := blockIDPreConfirmed(t) blockWithTxHashes, rpcErr := handler.BlockWithTxHashes(&preConfirmedID) @@ -547,7 +548,7 @@ func TestBlockWithTxHashesV013(t *testing.T) { require.True(t, ok) mockReader.EXPECT().BlockByNumber(gomock.Any()).Return(coreBlock, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) blockID := blockIDNumber(t, blockNumber) got, rpcErr := handler.BlockWithTxs(&blockID) @@ -640,7 +641,7 @@ func TestBlockWithReceipts(t *testing.T) { err := errors.New("l1 failure") mockReader.EXPECT().BlockByNumber(blockID.Number()).Return(block, nil) - mockReader.EXPECT().L1Head().Return(nil, err) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, err) resp, rpcErr := handler.BlockWithReceipts(&blockID) assert.Nil(t, resp) @@ -663,7 +664,7 @@ func TestBlockWithReceipts(t *testing.T) { &preConfirmed, nil, ).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) blockID := blockIDPreConfirmed(t) resp, rpcErr := handler.BlockWithReceipts(&blockID) @@ -705,7 +706,7 @@ func TestBlockWithReceipts(t *testing.T) { blockID := blockIDNumber(t, block1.Number) mockReader.EXPECT().BlockByNumber(blockID.Number()).Return(block1, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block1.Number + 1, }, nil) @@ -762,7 +763,7 @@ func TestRpcBlockAdaptation(t *testing.T) { require.NoError(t, err) latestBlock.Header.SequencerAddress = nil mockReader.EXPECT().Head().Return(latestBlock, nil).Times(2) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).Times(2) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).Times(2) blockID := blockIDLatest(t) block, rpcErr := handler.BlockWithTxs(&blockID) diff --git a/rpc/v9/class_test.go b/rpc/v9/class_test.go index fda48724cd..3c7ce4a1a1 100644 --- a/rpc/v9/class_test.go +++ b/rpc/v9/class_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" rpcv6 "github.com/NethermindEth/juno/rpc/v6" - rpcv9 "github.com/NethermindEth/juno/rpc/v9" + rpc "github.com/NethermindEth/juno/rpc/v9" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" @@ -39,7 +39,7 @@ func TestClass(t *testing.T) { return nil }, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil).AnyTimes() - handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) latest := blockIDLatest(t) @@ -70,7 +70,7 @@ func TestClass(t *testing.T) { t.Run("state by id error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) - handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) @@ -82,7 +82,7 @@ func TestClass(t *testing.T) { t.Run("class hash not found error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockState := mocks.NewMockStateReader(mockCtrl) - handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(mockState, func() error { return nil @@ -108,11 +108,11 @@ func TestClassAt(t *testing.T) { cairo0ContractAddress, _ := new(felt.Felt).SetRandom() cairo0ClassHash := utils.HexToFelt(t, "0x4631b6b3fa31e140524b7d21ba784cea223e618bffe60b5bbdca44a8b45be04") - mockState.EXPECT().ContractClassHash(cairo0ContractAddress).Return(cairo0ClassHash, nil) + mockState.EXPECT().ContractClassHash(cairo0ContractAddress).Return(*cairo0ClassHash, nil) cairo1ContractAddress, _ := new(felt.Felt).SetRandom() cairo1ClassHash := utils.HexToFelt(t, "0x1cd2edfb485241c4403254d550de0a097fa76743cd30696f714a491a454bad5") - mockState.EXPECT().ContractClassHash(cairo1ContractAddress).Return(cairo1ClassHash, nil) + mockState.EXPECT().ContractClassHash(cairo1ContractAddress).Return(*cairo1ClassHash, nil) mockState.EXPECT().Class(gomock.Any()).DoAndReturn(func(classHash *felt.Felt) (*core.DeclaredClass, error) { class, err := integGw.Class(t.Context(), classHash) @@ -122,7 +122,7 @@ func TestClassAt(t *testing.T) { return nil }, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil).AnyTimes() - handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) latest := blockIDLatest(t) @@ -155,7 +155,7 @@ func TestClassHashAt(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockSyncReader := mocks.NewMockSyncReader(mockCtrl) log := utils.NewNopZapLogger() - handler := rpcv9.New(mockReader, mockSyncReader, nil, log) + handler := rpc.New(mockReader, mockSyncReader, nil, log) t.Run("empty blockchain", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) @@ -185,7 +185,7 @@ func TestClassHashAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, errors.New("non-existent contract")) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, errors.New("non-existent contract")) latest := blockIDLatest(t) classHash, rpcErr := handler.ClassHashAt(&latest, &felt.Zero) require.Nil(t, classHash) @@ -196,7 +196,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) latest := blockIDLatest(t) classHash, rpcErr := handler.ClassHashAt(&latest, &felt.Zero) require.Nil(t, rpcErr) @@ -205,7 +205,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) hash := blockIDHash(t, &felt.Zero) classHash, rpcErr := handler.ClassHashAt(&hash, &felt.Zero) require.Nil(t, rpcErr) @@ -214,7 +214,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) number := blockIDNumber(t, 0) classHash, rpcErr := handler.ClassHashAt(&number, &felt.Zero) require.Nil(t, rpcErr) @@ -223,7 +223,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - pre_confirmed", func(t *testing.T) { mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) preConfirmed := blockIDPreConfirmed(t) classHash, rpcErr := handler.ClassHashAt(&preConfirmed, &felt.Zero) require.Nil(t, rpcErr) @@ -233,7 +233,7 @@ func TestClassHashAt(t *testing.T) { t.Run("blockID - l1_accepted", func(t *testing.T) { l1HeadBlockNumber := uint64(10) mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: l1HeadBlockNumber, BlockHash: &felt.Zero, StateRoot: &felt.Zero, @@ -241,7 +241,7 @@ func TestClassHashAt(t *testing.T) { nil, ) mockReader.EXPECT().StateAtBlockNumber(l1HeadBlockNumber).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) l1AcceptedID := blockIDL1Accepted(t) classHash, rpcErr := handler.ClassHashAt(&l1AcceptedID, &felt.Zero) require.Nil(t, rpcErr) diff --git a/rpc/v9/events_test.go b/rpc/v9/events_test.go index d89ae1318f..a64dd31852 100644 --- a/rpc/v9/events_test.go +++ b/rpc/v9/events_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -24,7 +25,7 @@ func TestEvents(t *testing.T) { testDB := memory.New() n := &utils.Sepolia - chain := blockchain.New(testDB, n) + chain := blockchain.New(testDB, n, statetestutils.UseNewState()) mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) diff --git a/rpc/v9/handlers_test.go b/rpc/v9/handlers_test.go index 507b6d7ba1..1c0ecbc305 100644 --- a/rpc/v9/handlers_test.go +++ b/rpc/v9/handlers_test.go @@ -40,7 +40,7 @@ func TestThrottledVMError(t *testing.T) { t.Run("call", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(new(felt.Felt), nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().Class(new(felt.Felt)).Return(&core.DeclaredClass{Class: &core.Cairo1Class{ Program: []*felt.Felt{ new(felt.Felt).SetUint64(3), diff --git a/rpc/v9/l1_test.go b/rpc/v9/l1_test.go index 2fd779ef76..0de2f2e156 100644 --- a/rpc/v9/l1_test.go +++ b/rpc/v9/l1_test.go @@ -87,11 +87,11 @@ func TestGetMessageStatus(t *testing.T) { mockSubscriber.EXPECT().TransactionReceipt(gomock.Any(), gomock.Any()).Return(&test.l1TxnReceipt, nil) for i, msg := range test.msgs { - mockReader.EXPECT().L1HandlerTxnHash(&test.msgHashes[i]).Return(msg.L1HandlerHash, nil) + mockReader.EXPECT().L1HandlerTxnHash(&test.msgHashes[i]).Return(*msg.L1HandlerHash, nil) // Expects for h.TransactionStatus() mockReader.EXPECT().TransactionByHash(msg.L1HandlerHash).Return(l1handlerTxns[i], nil) mockReader.EXPECT().Receipt(msg.L1HandlerHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{BlockNumber: uint64(test.l1HeadBlockNum)}, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{BlockNumber: uint64(test.l1HeadBlockNum)}, nil) } msgStatuses, rpcErr := handler.GetMessageStatus(t.Context(), &test.l1TxnHash) require.Nil(t, rpcErr) diff --git a/rpc/v9/nonce_test.go b/rpc/v9/nonce_test.go index d1ad43dfb0..fd15a0b5a1 100644 --- a/rpc/v9/nonce_test.go +++ b/rpc/v9/nonce_test.go @@ -55,7 +55,7 @@ func TestNonce(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(nil, errors.New("non-existent contract")) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(felt.Zero, errors.New("non-existent contract")) latest := blockIDLatest(t) nonce, rpcErr := handler.Nonce(&latest, &felt.Zero) @@ -67,7 +67,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) latest := blockIDLatest(t) nonce, rpcErr := handler.Nonce(&latest, &felt.Zero) @@ -77,7 +77,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) hash := blockIDHash(t, &felt.Zero) nonce, rpcErr := handler.Nonce(&hash, &felt.Zero) @@ -87,7 +87,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) number := blockIDNumber(t, 0) nonce, rpcErr := handler.Nonce(&number, &felt.Zero) @@ -97,7 +97,7 @@ func TestNonce(t *testing.T) { t.Run("blockID - pre_confirmed", func(t *testing.T) { mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) preConfirmedBlockID := blockIDPreConfirmed(t) nonce, rpcErr := handler.Nonce(&preConfirmedBlockID, &felt.Zero) @@ -109,7 +109,7 @@ func TestNonce(t *testing.T) { l1AcceptedBlockNumber := uint64(10) mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: l1AcceptedBlockNumber, BlockHash: &felt.One, StateRoot: &felt.One, @@ -117,7 +117,7 @@ func TestNonce(t *testing.T) { nil, ) mockReader.EXPECT().StateAtBlockNumber(l1AcceptedBlockNumber).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(*expectedNonce, nil) l1AcceptedID := blockIDL1Accepted(t) nonce, rpcErr := handler.Nonce(&l1AcceptedID, &felt.Zero) diff --git a/rpc/v9/simulation_pkg_test.go b/rpc/v9/simulation_pkg_test.go index 84affce3ab..c68a64a0e7 100644 --- a/rpc/v9/simulation_pkg_test.go +++ b/rpc/v9/simulation_pkg_test.go @@ -3,10 +3,12 @@ package rpcv9 import ( "encoding/json" "errors" + "os" "testing" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -14,6 +16,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestCreateSimulatedTransactions(t *testing.T) { executionResults := vm.ExecutionResults{ OverallFees: []*felt.Felt{new(felt.Felt).SetUint64(10), new(felt.Felt).SetUint64(20)}, diff --git a/rpc/v9/state_update_test.go b/rpc/v9/state_update_test.go index e7c7fd12fc..95afcc9b46 100644 --- a/rpc/v9/state_update_test.go +++ b/rpc/v9/state_update_test.go @@ -7,11 +7,12 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" rpcv6 "github.com/NethermindEth/juno/rpc/v6" - rpcv9 "github.com/NethermindEth/juno/rpc/v9" + rpc "github.com/NethermindEth/juno/rpc/v9" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/sync" "github.com/NethermindEth/juno/utils" @@ -21,7 +22,7 @@ import ( ) func TestStateUpdate(t *testing.T) { - errTests := map[string]rpcv9.BlockID{ + errTests := map[string]rpc.BlockID{ "latest": blockIDLatest(t), "pre_confirmed": blockIDPreConfirmed(t), "hash": blockIDHash(t, &felt.One), @@ -35,13 +36,13 @@ func TestStateUpdate(t *testing.T) { n := &utils.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PendingData().Return(nil, sync.ErrPendingBlockNotFound) } log := utils.NewNopZapLogger() - handler := rpcv9.New(chain, mockSyncReader, nil, log) + handler := rpc.New(chain, mockSyncReader, nil, log) update, rpcErr := handler.StateUpdate(&id) assert.Empty(t, update) @@ -51,7 +52,7 @@ func TestStateUpdate(t *testing.T) { log := utils.NewNopZapLogger() mockReader := mocks.NewMockReader(mockCtrl) - handler := rpcv9.New(mockReader, mockSyncReader, nil, log) + handler := rpc.New(mockReader, mockSyncReader, nil, log) client := feeder.NewTestClient(t, n) mainnetGw := adaptfeeder.New(client) @@ -148,7 +149,7 @@ func TestStateUpdate(t *testing.T) { t.Run("l1_accepted", func(t *testing.T) { mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: uint64(21656), BlockHash: update21656.BlockHash, StateRoot: update21656.NewRoot, diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 3663e554d5..d3a7fc317b 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" @@ -66,7 +67,7 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(felt.Zero, db.ErrKeyNotFound) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -76,8 +77,8 @@ func TestStorageAt(t *testing.T) { t.Run("non-existent key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(felt.Zero, nil) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -87,9 +88,9 @@ func TestStorageAt(t *testing.T) { t.Run("internal error while retrieving key", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()). - Return(nil, errors.New("some internal error")) + Return(felt.Zero, errors.New("some internal error")) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -101,8 +102,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - latest", func(t *testing.T) { //nolint:dupl mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDLatest(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -112,8 +113,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - hash", func(t *testing.T) { mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDHash(t, &felt.Zero) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -123,8 +124,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - number", func(t *testing.T) { mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDNumber(t, 0) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -134,8 +135,8 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - pre_confirmed", func(t *testing.T) { //nolint:dupl //false alarm block tag differs mockSyncReader.EXPECT().PendingState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) preConfirmedID := blockIDPreConfirmed(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &preConfirmedID) require.Nil(t, rpcErr) @@ -145,7 +146,7 @@ func TestStorageAt(t *testing.T) { t.Run("blockID - l1_accepted", func(t *testing.T) { l1HeadBlockNumber := uint64(10) mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: l1HeadBlockNumber, BlockHash: &felt.Zero, StateRoot: &felt.Zero, @@ -153,8 +154,8 @@ func TestStorageAt(t *testing.T) { nil, ) mockReader.EXPECT().StateAtBlockNumber(l1HeadBlockNumber).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedStorage, nil) blockID := blockIDL1Accepted(t) storageValue, rpcErr := handler.StorageAt(&felt.Zero, &felt.Zero, &blockID) @@ -622,7 +623,7 @@ func TestStorageProof_StorageRoots(t *testing.T) { log := utils.NewNopZapLogger() testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, log, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), time.Second) @@ -666,8 +667,8 @@ func TestStorageProof_StorageRoots(t *testing.T) { stgRoot, err := contractTrie.Hash() assert.NoError(t, err) - assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) - assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) + assert.Equal(t, *expectedClsRoot, clsRoot, clsRoot.String()) + assert.Equal(t, *expectedStgRoot, stgRoot, stgRoot.String()) verifyGlobalStateRoot(t, expectedGlobalRoot, &clsRoot, &stgRoot) }) @@ -682,11 +683,11 @@ func TestStorageProof_StorageRoots(t *testing.T) { leaf, err := contractTrie.Get(expectedContractAddress) assert.NoError(t, err) - assert.Equal(t, leaf, expectedContractLeaf, leaf.String()) + assert.Equal(t, leaf, *expectedContractLeaf, leaf.String()) clsHash, err := stateReader.ContractClassHash(expectedContractAddress) assert.NoError(t, err) - assert.Equal(t, clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) + assert.Equal(t, &clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) }) t.Run("get contract proof", func(t *testing.T) { diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index 1046da0603..e9870d67d5 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -399,7 +400,7 @@ func TestSubscribeTxnStatus(t *testing.T) { // Accepted on l1 Status mockChain.EXPECT().TransactionByHash(txHash).Return(block.Transactions[0], nil) mockChain.EXPECT().Receipt(txHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockChain.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockChain.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) handler.newHeads.Send(&core.Block{Header: &core.Header{Number: block.Number + 1}}) @@ -408,7 +409,7 @@ func TestSubscribeTxnStatus(t *testing.T) { l1Head := &core.L1Head{BlockNumber: block.Number} mockChain.EXPECT().TransactionByHash(txHash).Return(block.Transactions[0], nil) mockChain.EXPECT().Receipt(txHash).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockChain.EXPECT().L1Head().Return(l1Head, nil) + mockChain.EXPECT().L1Head().Return(*l1Head, nil) handler.l1Heads.Send(l1Head) assertNextTxnStatus(t, conn, id, txHash, TxnStatusAcceptedOnL1, TxnSuccess, "") }) @@ -549,10 +550,10 @@ func TestSubscribeNewHeadsHistorical(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &utils.Mainnet) + chain = blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) syncer := newFakeSyncer() ctx, cancel := context.WithCancel(t.Context()) diff --git a/rpc/v9/trace_test.go b/rpc/v9/trace_test.go index 1aa8e50da8..9c70c6ec7d 100644 --- a/rpc/v9/trace_test.go +++ b/rpc/v9/trace_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -96,7 +97,7 @@ func AssertTracedBlockTransactions(t *testing.T, n *utils.Network, tests map[str return }).AnyTimes() - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() for description, test := range tests { t.Run(description, func(t *testing.T) { @@ -138,7 +139,7 @@ func TestTraceBlockTransactionsReturnsError(t *testing.T) { mockReader.EXPECT().BlockByHash(gomock.Any()).DoAndReturn(func(_ *felt.Felt) (block *core.Block, err error) { return mockReader.BlockByNumber(blockNumber) }) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound).AnyTimes() + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound).AnyTimes() // No feeder client is set handler := rpc.New(mockReader, nil, nil, nil) @@ -574,7 +575,7 @@ func TestTraceTransaction(t *testing.T) { return gateway.BlockByNumber(t.Context(), blockNumber) }).Times(2) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: 19, // Doesn't really matter for this test }, nil) @@ -690,7 +691,7 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) handler := rpc.New(chain, nil, nil, log) if description == "pre_confirmed" { @@ -1235,7 +1236,7 @@ func TestCall(t *testing.T) { t.Run("call - unknown contract", func(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, errors.New("unknown contract")) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(felt.Zero, errors.New("unknown contract")) blockID := blockIDLatest(t) res, rpcErr := handler.Call(&rpc.FunctionCall{}, &blockID) @@ -1273,7 +1274,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1324,7 +1325,7 @@ func TestCall(t *testing.T) { mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(&vm.CallInfo{ @@ -1372,7 +1373,7 @@ func TestCall(t *testing.T) { } mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) - mockState.EXPECT().ContractClassHash(contractAddr).Return(classHash, nil) + mockState.EXPECT().ContractClassHash(contractAddr).Return(*classHash, nil) mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: &cairoClass}, nil) mockReader.EXPECT().Network().Return(n) mockVM.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedRes, nil) diff --git a/rpc/v9/transaction_test.go b/rpc/v9/transaction_test.go index db0810a116..82bf1e1e7f 100644 --- a/rpc/v9/transaction_test.go +++ b/rpc/v9/transaction_test.go @@ -621,7 +621,7 @@ func TestTransactionByBlockIdAndIndex(t *testing.T) { index := rand.Intn(int(latestBlock.TransactionCount)) mockReader.EXPECT().L1Head().Return( - &core.L1Head{ + core.L1Head{ BlockNumber: latestBlockNumber, BlockHash: latestBlockHash, StateRoot: latestBlock.GlobalStateRoot, @@ -775,7 +775,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[test.index].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[test.index], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[test.index], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txHash, test.expected) }) @@ -849,7 +849,7 @@ func TestTransactionReceiptByHash(t *testing.T) { txHash := block0.Transactions[i].Hash() mockReader.EXPECT().TransactionByHash(txHash).Return(block0.Transactions[i], nil) mockReader.EXPECT().Receipt(txHash).Return(block0.Receipts[i], block0.Hash, block0.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block0.Number, BlockHash: block0.Hash, StateRoot: block0.GlobalStateRoot, @@ -888,7 +888,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(revertedTxnHash).Return(blockWithRevertedTxn.Transactions[revertedTxnIdx], nil) mockReader.EXPECT().Receipt(revertedTxnHash).Return(blockWithRevertedTxn.Receipts[revertedTxnIdx], blockWithRevertedTxn.Hash, blockWithRevertedTxn.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, revertedTxnHash, expected) }) @@ -952,7 +952,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -1002,7 +1002,7 @@ func TestTransactionReceiptByHash(t *testing.T) { mockReader.EXPECT().TransactionByHash(txnHash).Return(block.Transactions[index], nil) mockReader.EXPECT().Receipt(txnHash).Return(block.Receipts[index], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, db.ErrKeyNotFound) checkTxReceipt(t, txnHash, expected) }) @@ -1459,7 +1459,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(nil, nil) + mockReader.EXPECT().L1Head().Return(core.L1Head{}, nil) handler := rpc.New(mockReader, nil, nil, nil) @@ -1475,7 +1475,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) @@ -1493,7 +1493,7 @@ func TestTransactionStatus(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockReader.EXPECT().TransactionByHash(tx.Hash()).Return(tx, nil) mockReader.EXPECT().Receipt(tx.Hash()).Return(block.Receipts[0], block.Hash, block.Number, nil) - mockReader.EXPECT().L1Head().Return(&core.L1Head{ + mockReader.EXPECT().L1Head().Return(core.L1Head{ BlockNumber: block.Number + 1, }, nil) diff --git a/sequencer/sequencer_test.go b/sequencer/sequencer_test.go index 9b7953a06d..3521c0a121 100644 --- a/sequencer/sequencer_test.go +++ b/sequencer/sequencer_test.go @@ -3,6 +3,7 @@ package sequencer_test import ( "context" "crypto/rand" + "os" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -23,13 +25,18 @@ import ( "go.uber.org/mock/gomock" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func getEmptySequencer(t *testing.T, blockTime time.Duration, seqAddr *felt.Felt) (sequencer.Sequencer, *blockchain.Blockchain) { t.Helper() testDB := memory.New() mockCtrl := gomock.NewController(t) mockVM := mocks.NewMockVM(mockCtrl) network := &utils.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New(testDB, network, statetestutils.UseNewState()) emptyStateDiff := core.EmptyStateDiff() require.NoError(t, bc.StoreGenesis(&emptyStateDiff, nil)) privKey, err := ecdsa.GenerateKey(rand.Reader) @@ -103,7 +110,7 @@ func getGenesisSequencer( testDB := memory.New() network := &utils.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New(testDB, network, statetestutils.UseNewState()) log := utils.NewNopZapLogger() privKey, err := ecdsa.GenerateKey(rand.Reader) require.NoError(t, err) diff --git a/sync/pending_test.go b/sync/pending_test.go index 172d8e53a7..3e8f46bc67 100644 --- a/sync/pending_test.go +++ b/sync/pending_test.go @@ -64,44 +64,44 @@ func TestPendingState(t *testing.T) { t.Run("deployed", func(t *testing.T) { cH, cErr := state.ContractClassHash(deployedAddr) require.NoError(t, cErr) - assert.Equal(t, deployedClassHash, cH) + assert.Equal(t, *deployedClassHash, cH) cH, cErr = state.ContractClassHash(deployedAddr2) require.NoError(t, cErr) - assert.Equal(t, deployedClassHash, cH) + assert.Equal(t, *deployedClassHash, cH) }) t.Run("replaced", func(t *testing.T) { cH, cErr := state.ContractClassHash(replacedAddr) require.NoError(t, cErr) - assert.Equal(t, replacedClassHash, cH) + assert.Equal(t, *replacedClassHash, cH) }) }) t.Run("from head", func(t *testing.T) { expectedClassHash := new(felt.Felt).SetUint64(37) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(expectedClassHash, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(*expectedClassHash, nil) cH, cErr := state.ContractClassHash(&felt.Zero) require.NoError(t, cErr) - assert.Equal(t, expectedClassHash, cH) + assert.Equal(t, *expectedClassHash, cH) }) }) t.Run("ContractNonce", func(t *testing.T) { t.Run("from pending", func(t *testing.T) { cN, cErr := state.ContractNonce(deployedAddr) require.NoError(t, cErr) - assert.Equal(t, new(felt.Felt).SetUint64(44), cN) + assert.Equal(t, new(felt.Felt).SetUint64(44), &cN) cN, cErr = state.ContractNonce(deployedAddr2) require.NoError(t, cErr) - assert.Equal(t, &felt.Zero, cN) + assert.Equal(t, felt.Zero, cN) }) t.Run("from head", func(t *testing.T) { expectedNonce := new(felt.Felt).SetUint64(1337) - mockState.EXPECT().ContractNonce(gomock.Any()).Return(expectedNonce, nil) + mockState.EXPECT().ContractNonce(gomock.Any()).Return(*expectedNonce, nil) cN, cErr := state.ContractNonce(&felt.Zero) require.NoError(t, cErr) - assert.Equal(t, expectedNonce, cN) + assert.Equal(t, *expectedNonce, cN) }) }) t.Run("ContractStorage", func(t *testing.T) { @@ -109,23 +109,23 @@ func TestPendingState(t *testing.T) { expectedValue := new(felt.Felt).SetUint64(37) cV, cErr := state.ContractStorage(deployedAddr, new(felt.Felt).SetUint64(44)) require.NoError(t, cErr) - assert.Equal(t, expectedValue, cV) + assert.Equal(t, *expectedValue, cV) cV, cErr = state.ContractStorage(deployedAddr, new(felt.Felt).SetUint64(0xDEADBEEF)) require.NoError(t, cErr) - assert.Equal(t, &felt.Zero, cV) + assert.Equal(t, felt.Zero, cV) cV, cErr = state.ContractStorage(deployedAddr2, new(felt.Felt).SetUint64(0xDEADBEEF)) require.NoError(t, cErr) - assert.Equal(t, &felt.Zero, cV) + assert.Equal(t, felt.Zero, cV) }) t.Run("from head", func(t *testing.T) { expectedValue := new(felt.Felt).SetUint64(0xDEADBEEF) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedValue, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(*expectedValue, nil) cV, cErr := state.ContractStorage(&felt.Zero, &felt.Zero) require.NoError(t, cErr) - assert.Equal(t, expectedValue, cV) + assert.Equal(t, *expectedValue, cV) }) }) t.Run("Class", func(t *testing.T) { diff --git a/sync/sync_test.go b/sync/sync_test.go index 53f01f005e..cf177d5051 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -3,6 +3,7 @@ package sync_test import ( "context" "errors" + "os" "sync/atomic" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -23,6 +25,11 @@ import ( const timeout = time.Second +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestSyncBlocks(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) @@ -55,7 +62,7 @@ func TestSyncBlocks(t *testing.T) { log := utils.NewNopZapLogger() t.Run("sync multiple blocks in an empty db", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, log, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), timeout) @@ -68,7 +75,7 @@ func TestSyncBlocks(t *testing.T) { t.Run("sync multiple blocks in a non-empty db", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) b0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) s0, err := gw.StateUpdate(t.Context(), 0) @@ -87,7 +94,7 @@ func TestSyncBlocks(t *testing.T) { t.Run("sync multiple blocks, with an unreliable gw", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) mockSNData := mocks.NewMockStarknetData(mockCtrl) @@ -149,7 +156,7 @@ func TestReorg(t *testing.T) { testDB := memory.New() // sync to Sepolia for 2 blocks - bc := blockchain.New(testDB, &utils.Sepolia) + bc := blockchain.New(testDB, &utils.Sepolia, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, sepoliaGw) synchronizer := sync.New(bc, dataSource, utils.NewNopZapLogger(), 0, 0, false, testDB) @@ -160,7 +167,7 @@ func TestReorg(t *testing.T) { require.NoError(t, bc.Stop()) t.Run("resync to mainnet with the same db", func(t *testing.T) { - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) // Ensure current head is Sepolia head head, err := bc.HeadsHeader() @@ -199,7 +206,7 @@ func TestPendingData(t *testing.T) { t.Run("starknet version <= 0.14.0", func(t *testing.T) { var synchronizer *sync.Synchronizer testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(chain, gw) synchronizer = sync.New(chain, dataSource, utils.NewNopZapLogger(), 0, 0, false, testDB) @@ -274,7 +281,7 @@ func TestPendingData(t *testing.T) { t.Run("starknet version > 0.14.0", func(t *testing.T) { var synchronizer *sync.Synchronizer testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(chain, gw) synchronizer = sync.New(chain, dataSource, utils.NewNopZapLogger(), 0, 0, false, testDB) @@ -346,7 +353,7 @@ func TestPendingData(t *testing.T) { t.Run("get pending state before index", func(t *testing.T) { var synchronizer *sync.Synchronizer testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(chain, gw) synchronizer = sync.New(chain, dataSource, utils.NewNopZapLogger(), 0, 0, false, testDB) @@ -373,7 +380,7 @@ func TestPendingData(t *testing.T) { require.NoError(t, err) expectedVal, err := new(felt.Felt).SetString("0x1d057bfbd3cadebffd74") require.NoError(t, err) - require.Equal(t, expectedVal, val) + require.Equal(t, *expectedVal, val) t.Cleanup(func() { require.NoError(t, pendingStateCloser()) }) @@ -388,7 +395,7 @@ func TestPendingData(t *testing.T) { require.NoError(t, err) expectedVal, err = new(felt.Felt).SetString("0x1d057bfbd3df63f5dd54") require.NoError(t, err) - require.Equal(t, expectedVal, val) + require.Equal(t, *expectedVal, val) t.Cleanup(func() { require.NoError(t, pendingStateCloser()) }) @@ -401,7 +408,7 @@ func TestSubscribeNewHeads(t *testing.T) { testDB := memory.New() log := utils.NewNopZapLogger() network := utils.Mainnet - chain := blockchain.New(testDB, &network) + chain := blockchain.New(testDB, &network, statetestutils.UseNewState()) feeder := feeder.NewTestClient(t, &network) gw := adaptfeeder.New(feeder) dataSource := sync.NewFeederGatewayDataSource(chain, gw) @@ -430,7 +437,7 @@ func TestSubscribePending(t *testing.T) { testDB := memory.New() log := utils.NewNopZapLogger() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New( bc, diff --git a/vm/vm_test.go b/vm/vm_test.go index 7f586bba48..d7927b2374 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/state/commonstate" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/rpc/rpccore" @@ -18,6 +19,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + statetestutils.Parse() + os.Exit(m.Run()) +} + func TestCallDeprecatedCairo(t *testing.T) { testDB := memory.New() txn := testDB.NewIndexedBatch() @@ -33,54 +39,51 @@ func TestCallDeprecatedCairo(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - for _, stateVersion := range []bool{true, false} { - stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) - require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(t, err) - - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, - }, + stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) - - entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") - - ret, err := New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) - - require.NoError(t, testState.Update(1, &core.StateUpdate{ - OldRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - NewRoot: utils.HexToFelt(t, "0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), - StateDiff: &core.StateDiff{ - StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ - *contractAddr: { - *utils.HexToFelt(t, "0x206f38f7e4f15e87567361213c28f235cccdaa1d7fd34c9db1dfe9489c6a091"): new(felt.Felt).SetUint64(1337), - }, + }, + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) + + entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") + + ret, err := New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + + require.NoError(t, testState.Update(1, &core.StateUpdate{ + OldRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + NewRoot: utils.HexToFelt(t, "0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), + StateDiff: &core.StateDiff{ + StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ + *contractAddr: { + *utils.HexToFelt(t, "0x206f38f7e4f15e87567361213c28f235cccdaa1d7fd34c9db1dfe9489c6a091"): new(felt.Felt).SetUint64(1337), }, }, - }, nil, false)) - - ret, err = New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret.Result) - } + }, + }, nil, false)) + + ret, err = New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Mainnet, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret.Result) } func TestCallDeprecatedCairoMaxSteps(t *testing.T) { @@ -98,33 +101,31 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - for _, stateVersion := range []bool{true, false} { - stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) - require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(t, err) + stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, - }, + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) - - entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") - - _, err = New(false, nil).Call(&CallInfo{ - ContractAddress: contractAddr, - ClassHash: classHash, - Selector: entryPoint, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 0, simpleClass.SierraVersion(), false, false) - assert.ErrorContains(t, err, "RunResources has no remaining steps") - } + }, + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) + + entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") + + _, err = New(false, nil).Call(&CallInfo{ + ContractAddress: contractAddr, + ClassHash: classHash, + Selector: entryPoint, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Mainnet, 0, simpleClass.SierraVersion(), false, false) + assert.ErrorContains(t, err, "RunResources has no remaining steps") } func TestCallCairo(t *testing.T) { @@ -142,62 +143,60 @@ func TestCallCairo(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - for _, stateVersion := range []bool{true, false} { - stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) - require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(t, err) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, - }, + stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) + }, + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) - logLevel := utils.NewLogLevel(utils.ERROR) - log, err := utils.NewZapLogger(logLevel, false) - require.NoError(t, err) + logLevel := utils.NewLogLevel(utils.ERROR) + log, err := utils.NewZapLogger(logLevel, false) + require.NoError(t, err) - // test_storage_read - entryPoint := utils.HexToFelt(t, "0x5df99ae77df976b4f0e5cf28c7dcfe09bd6e81aab787b19ac0c08e03d928cf") - storageLocation := utils.HexToFelt(t, "0x44") - ret, err := New(false, log).Call(&CallInfo{ - ContractAddress: contractAddr, - Selector: entryPoint, - Calldata: []felt.Felt{ - *storageLocation, - }, - }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) - - require.NoError(t, testState.Update(1, &core.StateUpdate{ - OldRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), - NewRoot: utils.HexToFelt(t, "0x7a9da0a7471a8d5118d3eefb8c26a6acbe204eb1eaa934606f4757a595fe552"), - StateDiff: &core.StateDiff{ - StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ - *contractAddr: { - *storageLocation: new(felt.Felt).SetUint64(37), - }, + // test_storage_read + entryPoint := utils.HexToFelt(t, "0x5df99ae77df976b4f0e5cf28c7dcfe09bd6e81aab787b19ac0c08e03d928cf") + storageLocation := utils.HexToFelt(t, "0x44") + ret, err := New(false, log).Call(&CallInfo{ + ContractAddress: contractAddr, + Selector: entryPoint, + Calldata: []felt.Felt{ + *storageLocation, + }, + }, &BlockInfo{Header: &core.Header{}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + + require.NoError(t, testState.Update(1, &core.StateUpdate{ + OldRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), + NewRoot: utils.HexToFelt(t, "0x7a9da0a7471a8d5118d3eefb8c26a6acbe204eb1eaa934606f4757a595fe552"), + StateDiff: &core.StateDiff{ + StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{ + *contractAddr: { + *storageLocation: new(felt.Felt).SetUint64(37), }, }, - }, nil, false)) - - ret, err = New(false, log).Call(&CallInfo{ - ContractAddress: contractAddr, - Selector: entryPoint, - Calldata: []felt.Felt{ - *storageLocation, - }, - }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) - require.NoError(t, err) - assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(37)}, ret.Result) - } + }, + }, nil, false)) + + ret, err = New(false, log).Call(&CallInfo{ + ContractAddress: contractAddr, + Selector: entryPoint, + Calldata: []felt.Felt{ + *storageLocation, + }, + }, &BlockInfo{Header: &core.Header{Number: 1}}, testState, &utils.Goerli, 1_000_000, simpleClass.SierraVersion(), false, false) + require.NoError(t, err) + assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(37)}, ret.Result) } func TestCallInfoErrorHandling(t *testing.T) { @@ -214,50 +213,48 @@ func TestCallInfoErrorHandling(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - for _, stateVersion := range []bool{true, false} { - stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) - require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(t, err) - require.NoError(t, testState.Update(0, &core.StateUpdate{ - OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0xa6258de574e5540253c4a52742137d58b9e8ad8f584115bee46d9d18255c42"), - StateDiff: &core.StateDiff{ - DeployedContracts: map[felt.Felt]*felt.Felt{ - *contractAddr: classHash, - }, + stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + require.NoError(t, err) + testState, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + require.NoError(t, testState.Update(0, &core.StateUpdate{ + OldRoot: &felt.Zero, + NewRoot: utils.HexToFelt(t, "0xa6258de574e5540253c4a52742137d58b9e8ad8f584115bee46d9d18255c42"), + StateDiff: &core.StateDiff{ + DeployedContracts: map[felt.Felt]*felt.Felt{ + *contractAddr: classHash, }, - }, map[felt.Felt]core.Class{ - *classHash: simpleClass, - }, false)) + }, + }, map[felt.Felt]core.Class{ + *classHash: simpleClass, + }, false)) - logLevel := utils.NewLogLevel(utils.ERROR) - log, err := utils.NewZapLogger(logLevel, false) - require.NoError(t, err) + logLevel := utils.NewLogLevel(utils.ERROR) + log, err := utils.NewZapLogger(logLevel, false) + require.NoError(t, err) - callInfo := &CallInfo{ - ClassHash: classHash, - ContractAddress: contractAddr, - Selector: utils.HexToFelt(t, "0x123"), // doesn't exist - Calldata: []felt.Felt{}, - } - - // Starknet version <0.13.4 should return an error - ret, err := New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ - ProtocolVersion: "0.13.0", - }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) - require.Equal(t, CallResult{}, ret) - require.ErrorContains(t, err, "not found in contract") - - // Starknet version 0.13.4 should return an "error" in the CallInfo - ret, err = New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ - ProtocolVersion: "0.13.4", - }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) - require.True(t, ret.ExecutionFailed) - require.Equal(t, len(ret.Result), 1) - require.Equal(t, ret.Result[0].String(), rpccore.EntrypointNotFoundFelt) - require.Nil(t, err) + callInfo := &CallInfo{ + ClassHash: classHash, + ContractAddress: contractAddr, + Selector: utils.HexToFelt(t, "0x123"), // doesn't exist + Calldata: []felt.Felt{}, } + + // Starknet version <0.13.4 should return an error + ret, err := New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ + ProtocolVersion: "0.13.0", + }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) + require.Equal(t, CallResult{}, ret) + require.ErrorContains(t, err, "not found in contract") + + // Starknet version 0.13.4 should return an "error" in the CallInfo + ret, err = New(false, log).Call(callInfo, &BlockInfo{Header: &core.Header{ + ProtocolVersion: "0.13.4", + }}, testState, &utils.Sepolia, 1_000_000, simpleClass.SierraVersion(), false, false) + require.True(t, ret.ExecutionFailed) + require.Equal(t, len(ret.Result), 1) + require.Equal(t, ret.Result[0].String(), rpccore.EntrypointNotFoundFelt) + require.Nil(t, err) } func TestExecute(t *testing.T) { @@ -269,35 +266,33 @@ func TestExecute(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - for _, stateVersion := range []bool{true, false} { - stateFactory, err := commonstate.NewStateFactory(stateVersion, triedb, stateDB) + stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + require.NoError(t, err) + state, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(t, err) + + t.Run("empty transaction list", func(t *testing.T) { + _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ + Header: &core.Header{ + Timestamp: 1666877926, + SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), + L1GasPriceETH: &felt.Zero, + L1GasPriceSTRK: &felt.Zero, + }, + }, state, + &network, false, false, false, false, false) require.NoError(t, err) - state, err := stateFactory.NewState(&felt.Zero, txn) + }) + t.Run("zero data", func(t *testing.T) { + _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ + Header: &core.Header{ + SequencerAddress: &felt.Zero, + L1GasPriceETH: &felt.Zero, + L1GasPriceSTRK: &felt.Zero, + }, + }, state, &network, false, false, false, false, false) require.NoError(t, err) - - t.Run("empty transaction list", func(t *testing.T) { - _, err := New(false, nil).Execute([]core.Transaction{}, []core.Class{}, []*felt.Felt{}, &BlockInfo{ - Header: &core.Header{ - Timestamp: 1666877926, - SequencerAddress: utils.HexToFelt(t, "0x46a89ae102987331d369645031b49c27738ed096f2789c24449966da4c6de6b"), - L1GasPriceETH: &felt.Zero, - L1GasPriceSTRK: &felt.Zero, - }, - }, state, - &network, false, false, false, false, false) - require.NoError(t, err) - }) - t.Run("zero data", func(t *testing.T) { - _, err := New(false, nil).Execute(nil, nil, []*felt.Felt{}, &BlockInfo{ - Header: &core.Header{ - SequencerAddress: &felt.Zero, - L1GasPriceETH: &felt.Zero, - L1GasPriceSTRK: &felt.Zero, - }, - }, state, &network, false, false, false, false, false) - require.NoError(t, err) - }) - } + }) } func TestSetVersionedConstants(t *testing.T) { From 885787fcd89fd72c2414eac84a0491f52bc4a1a0 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 4 Aug 2025 01:22:55 +0200 Subject: [PATCH 04/47] cleanups --- core/running_event_filter_test.go | 10 ++-------- core/state/state_test_utils/new_state_flag.go | 2 -- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/core/running_event_filter_test.go b/core/running_event_filter_test.go index bd42ccfa2c..c98f10b799 100644 --- a/core/running_event_filter_test.go +++ b/core/running_event_filter_test.go @@ -1,13 +1,13 @@ package core_test import ( - "flag" "testing" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/encoder" @@ -18,12 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -var newState bool - -func init() { - flag.BoolVar(&newState, "use-new-state", false, "...") -} - func testBloomWithRandomKeys(t *testing.T, numKeys uint) *bloom.BloomFilter { t.Helper() filter := bloom.New(core.EventsBloomLength, core.EventsBloomHashFuncs) @@ -74,7 +68,7 @@ func TestRunningEventFilter_LazyInitialization_EmptyDB(t *testing.T) { func TestRunningEventFilter_LazyInitialization_Preload(t *testing.T) { testDB := memory.New() n := &utils.Sepolia - chain := blockchain.New(testDB, n, newState) + chain := blockchain.New(testDB, n, statetestutils.UseNewState()) client := feeder.NewTestClient(t, n) gw := adaptfeeder.New(client) diff --git a/core/state/state_test_utils/new_state_flag.go b/core/state/state_test_utils/new_state_flag.go index 606dac39d3..a34a357c91 100644 --- a/core/state/state_test_utils/new_state_flag.go +++ b/core/state/state_test_utils/new_state_flag.go @@ -2,7 +2,6 @@ package statetestutils import ( "flag" - "fmt" "os" "strings" "sync" @@ -16,7 +15,6 @@ var ( func parseFlags() { flag.BoolVar(&useNewState, "use-new-state", false, "use new state implementation") - fmt.Println("use-new-state", useNewState) cleanArgs := []string{os.Args[0]} for i := 1; i < len(os.Args); i++ { From 56cb030ebfd186ae0b873b6718546f92a9cd1e3e Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 4 Aug 2025 17:32:34 +0200 Subject: [PATCH 05/47] fix remaining unit tests --- rpc/v6/helpers.go | 12 ++++++------ rpc/v7/helpers.go | 15 +++++++++------ rpc/v8/helpers.go | 12 ++++++------ rpc/v9/helpers.go | 12 ++++++------ 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index 490562bcbd..c59c9a6ac4 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -16,17 +16,17 @@ import ( "github.com/NethermindEth/juno/sync" ) -func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { +func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { l1Head, err := h.bcReader.L1Head() if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) + return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // nil is returned if l1 head doesn't exist - return &l1Head, nil + // empty l1Head is returned if l1 head doesn't exist + return l1Head, nil } -func isL1Verified(n uint64, l1 *core.L1Head) bool { - if l1 != nil && l1.BlockNumber >= n { +func isL1Verified(n uint64, l1 core.L1Head) bool { + if l1 != (core.L1Head{}) && l1.BlockNumber >= n { return true } return false diff --git a/rpc/v7/helpers.go b/rpc/v7/helpers.go index b6a21a8b01..dca0c560f7 100644 --- a/rpc/v7/helpers.go +++ b/rpc/v7/helpers.go @@ -16,17 +16,20 @@ import ( "github.com/NethermindEth/juno/sync" ) -func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { +func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { l1Head, err := h.bcReader.L1Head() if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) + return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // nil is returned if l1 head doesn't exist - return &l1Head, nil + if errors.Is(err, db.ErrKeyNotFound) { + return core.L1Head{}, nil + } + // empty l1Head is returned if l1 head doesn't exist + return l1Head, nil } -func isL1Verified(n uint64, l1 *core.L1Head) bool { - if l1 != nil && l1.BlockNumber >= n { +func isL1Verified(n uint64, l1 core.L1Head) bool { + if l1 != (core.L1Head{}) && l1.BlockNumber >= n { return true } return false diff --git a/rpc/v8/helpers.go b/rpc/v8/helpers.go index 7c7f0f723d..44a71963af 100644 --- a/rpc/v8/helpers.go +++ b/rpc/v8/helpers.go @@ -16,17 +16,17 @@ import ( "github.com/NethermindEth/juno/sync" ) -func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { +func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { l1Head, err := h.bcReader.L1Head() if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) + return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // nil is returned if l1 head doesn't exist - return &l1Head, nil + // empty l1Head is returned if l1 head doesn't exist + return l1Head, nil } -func isL1Verified(n uint64, l1 *core.L1Head) bool { - if l1 != nil && l1.BlockNumber >= n { +func isL1Verified(n uint64, l1 core.L1Head) bool { + if l1 != (core.L1Head{}) && l1.BlockNumber >= n { return true } return false diff --git a/rpc/v9/helpers.go b/rpc/v9/helpers.go index bb152427e8..a66e7c28c3 100644 --- a/rpc/v9/helpers.go +++ b/rpc/v9/helpers.go @@ -16,17 +16,17 @@ import ( "github.com/NethermindEth/juno/sync" ) -func (h *Handler) l1Head() (*core.L1Head, *jsonrpc.Error) { +func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { l1Head, err := h.bcReader.L1Head() if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return nil, jsonrpc.Err(jsonrpc.InternalError, err.Error()) + return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // nil is returned if l1 head doesn't exist - return &l1Head, nil + // empty l1Head is returned if l1 head doesn't exist + return l1Head, nil } -func isL1Verified(n uint64, l1 *core.L1Head) bool { - if l1 != nil && l1.BlockNumber >= n { +func isL1Verified(n uint64, l1 core.L1Head) bool { + if l1 != (core.L1Head{}) && l1.BlockNumber >= n { return true } return false From 4892f07876e95c205edb567003ddb59aba5d80af Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 5 Aug 2025 10:36:07 +0200 Subject: [PATCH 06/47] integrate env flag with the unit tests --- Makefile | 10 +- blockchain/blockchain_test.go | 6 -- cmd/juno/dbcmd_test.go | 6 -- consensus/integtest/integ_test.go | 6 -- .../p2p/validator/empty_fixtures_test.go | 6 -- consensus/proposer/proposer_test.go | 6 -- core/state/state_test_utils/new_state_flag.go | 35 ++---- l1/l1_pkg_test.go | 6 -- mempool/mempool_test.go | 5 - migration/migration_pkg_test.go | 6 -- node/node_test.go | 6 -- plugin/plugin_test.go | 6 -- rpc/v6/block_test.go | 6 -- rpc/v7/block_test.go | 6 -- rpc/v8/simulation_pkg_test.go | 7 -- rpc/v8/storage_test.go | 102 +++++++++++++----- rpc/v9/simulation_pkg_test.go | 7 -- rpc/v9/storage_test.go | 102 +++++++++++++----- sequencer/sequencer_test.go | 6 -- sync/sync_test.go | 6 -- vm/vm_test.go | 5 - 21 files changed, 165 insertions(+), 186 deletions(-) diff --git a/Makefile b/Makefile index be14c41d96..00fe60c694 100644 --- a/Makefile +++ b/Makefile @@ -71,10 +71,16 @@ clean-testcache: ## Clean Go test cache go clean -testcache test: clean-testcache rustdeps ## Run tests - go test $(GO_TAGS) ./... + go test $(GO_TAGS) -v ./... + +test: clean-testcache rustdeps ## Run tests + go test $(GO_TAGS) -v ./... + +test-new-state: clean-testcache rustdeps ## Run tests with new state + USE_NEW_STATE=true go test $(GO_TAGS) -v ./... test-cached: rustdeps ## Run cached tests - go test $(GO_TAGS) ./... + go test $(GO_TAGS) ./... -args test-race: clean-testcache rustdeps ## Run tests with race detection go test $(GO_TAGS) ./... -race $(TEST_RACE_LDFLAGS) diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index d5deb5bd72..37b94cf294 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -2,7 +2,6 @@ package blockchain_test import ( "fmt" - "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -19,11 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - var emptyCommitments = core.BlockCommitments{} func TestNew(t *testing.T) { diff --git a/cmd/juno/dbcmd_test.go b/cmd/juno/dbcmd_test.go index cdb344e145..07ef714fa2 100644 --- a/cmd/juno/dbcmd_test.go +++ b/cmd/juno/dbcmd_test.go @@ -1,7 +1,6 @@ package main_test import ( - "os" "strconv" "testing" @@ -18,11 +17,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - var emptyCommitments = core.BlockCommitments{} func TestDBCmd(t *testing.T) { diff --git a/consensus/integtest/integ_test.go b/consensus/integtest/integ_test.go index 762f10c84a..987146831e 100644 --- a/consensus/integtest/integ_test.go +++ b/consensus/integtest/integ_test.go @@ -2,7 +2,6 @@ package integtest import ( "fmt" - "os" "testing" "time" @@ -76,11 +75,6 @@ func getTimeoutFn(cfg *testConfig) func(types.Step, types.Round) time.Duration { } } -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func newDB(t *testing.T) *mocks.MockTendermintDB[starknet.Value, starknet.Hash, starknet.Address] { t.Helper() ctrl := gomock.NewController(t) diff --git a/consensus/p2p/validator/empty_fixtures_test.go b/consensus/p2p/validator/empty_fixtures_test.go index bec88e41f9..52335a5df0 100644 --- a/consensus/p2p/validator/empty_fixtures_test.go +++ b/consensus/p2p/validator/empty_fixtures_test.go @@ -2,7 +2,6 @@ package validator import ( "math/rand/v2" - "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -27,11 +26,6 @@ type EmptyTestFixture struct { Proposal *starknet.Proposal } -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func NewEmptyTestFixture( t *testing.T, executor *mockExecutor, diff --git a/consensus/proposer/proposer_test.go b/consensus/proposer/proposer_test.go index 8de6f89fae..2b6325088d 100644 --- a/consensus/proposer/proposer_test.go +++ b/consensus/proposer/proposer_test.go @@ -3,7 +3,6 @@ package proposer_test import ( "context" "fmt" - "os" "slices" "testing" "time" @@ -39,11 +38,6 @@ const ( var allBatchSizes = []int{1, 0, 3, 2, 4, 0, 1} -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestProposer(t *testing.T) { logger, err := utils.NewZapLogger(utils.NewLogLevel(logLevel), true) require.NoError(t, err) diff --git a/core/state/state_test_utils/new_state_flag.go b/core/state/state_test_utils/new_state_flag.go index a34a357c91..90bcb92c6f 100644 --- a/core/state/state_test_utils/new_state_flag.go +++ b/core/state/state_test_utils/new_state_flag.go @@ -1,42 +1,21 @@ package statetestutils import ( - "flag" "os" - "strings" + "strconv" "sync" ) var ( - once sync.Once - parsed bool useNewState bool + once sync.Once ) -func parseFlags() { - flag.BoolVar(&useNewState, "use-new-state", false, "use new state implementation") - - cleanArgs := []string{os.Args[0]} - for i := 1; i < len(os.Args); i++ { - arg := os.Args[i] - if arg == "-use-new-state" || strings.HasPrefix(arg, "-use-new-state=") { - continue - } - cleanArgs = append(cleanArgs, arg) - } - os.Args = cleanArgs - - flag.Parse() - parsed = true -} - -func Parse() { - once.Do(parseFlags) -} - func UseNewState() bool { - if !parsed { - Parse() - } + once.Do(func() { + val := os.Getenv("USE_NEW_STATE") + parsed, err := strconv.ParseBool(val) + useNewState = err == nil && parsed + }) return useNewState } diff --git a/l1/l1_pkg_test.go b/l1/l1_pkg_test.go index dbceb69c50..d19fddadc1 100644 --- a/l1/l1_pkg_test.go +++ b/l1/l1_pkg_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math/big" - "os" "testing" "time" @@ -115,11 +114,6 @@ var longSequenceOfBlocks = []*l1Block{ }, } -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestClient(t *testing.T) { t.Parallel() diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 1ac6885274..7bf9ec95c3 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -41,11 +41,6 @@ func setupDatabase(dbPath string, dltExisting bool) (db.KeyValueStore, func(), e return persistentPool, closer, nil } -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestMempool(t *testing.T) { testDB, dbCloser, err := setupDatabase("testmempool", true) log := utils.NewNopZapLogger() diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index 9fcea5dda0..dfab2f136d 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "math/rand" - "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -27,11 +26,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestMigration0000(t *testing.T) { testDB := memory.New() diff --git a/node/node_test.go b/node/node_test.go index eeab1b1e5f..b439f56e0a 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -2,7 +2,6 @@ package node_test import ( "context" - "os" "testing" "time" @@ -17,11 +16,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - // Create a new node with all services enabled. func TestNewNode(t *testing.T) { config := &node.Config{ diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index 8d5bfc89f7..3694dcdd72 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -2,7 +2,6 @@ package plugin_test import ( "context" - "os" "testing" "time" @@ -19,11 +18,6 @@ import ( "go.uber.org/mock/gomock" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestPlugin(t *testing.T) { timeout := time.Second mockCtrl := gomock.NewController(t) diff --git a/rpc/v6/block_test.go b/rpc/v6/block_test.go index 4e00a8354f..8fa1850ef5 100644 --- a/rpc/v6/block_test.go +++ b/rpc/v6/block_test.go @@ -2,7 +2,6 @@ package rpcv6_test import ( "errors" - "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -23,11 +22,6 @@ import ( "go.uber.org/mock/gomock" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestBlockId(t *testing.T) { t.Parallel() tests := map[string]struct { diff --git a/rpc/v7/block_test.go b/rpc/v7/block_test.go index ce0dbd5b99..5641cafcc0 100644 --- a/rpc/v7/block_test.go +++ b/rpc/v7/block_test.go @@ -2,7 +2,6 @@ package rpcv7_test import ( "errors" - "os" "testing" "github.com/NethermindEth/juno/blockchain" @@ -24,11 +23,6 @@ import ( "go.uber.org/mock/gomock" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestBlockId(t *testing.T) { t.Parallel() tests := map[string]struct { diff --git a/rpc/v8/simulation_pkg_test.go b/rpc/v8/simulation_pkg_test.go index b39e2efd3b..5de7b08e81 100644 --- a/rpc/v8/simulation_pkg_test.go +++ b/rpc/v8/simulation_pkg_test.go @@ -3,12 +3,10 @@ package rpcv8 import ( "encoding/json" "errors" - "os" "testing" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -16,11 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - //nolint:dupl func TestCreateSimulatedTransactions(t *testing.T) { executionResults := vm.ExecutionResults{ diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 34c4c5e0bf..77e5661fe5 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -11,8 +11,12 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commontrie" statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -164,11 +168,45 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - tempTrie := emptyTrie(t) - _, _ = tempTrie.Put(key, value) - _, _ = tempTrie.Put(key2, value2) - _ = tempTrie.Commit() - trieRoot, _ := tempTrie.Root() + var classTrie, contractTrie commontrie.Trie + trieRoot := &felt.Zero + + if !statetestutils.UseNewState() { + tempTrie := emptyTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ = tempTrie.Root() + classTrie = commontrie.NewTrieAdapter(tempTrie) + contractTrie = commontrie.NewTrieAdapter(tempTrie) + } else { + newComm := new(felt.Felt).SetUint64(1) + createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { + tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) + _ = tr.Update(key, value) + _ = tr.Update(key2, value2) + require.NoError(t, err) + _, nodes := tr.Commit() + err = trieDB.Update(newComm, &felt.Zero, trienode.NewMergeNodeSet(nodes)) + require.NoError(t, err) + return tr + } + + // TODO(weiihann): should have a better way of testing + trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) + createTrie(t, trieutils.NewClassTrieID(felt.Zero), &trieDB) + contractTrie2 := createTrie(t, trieutils.NewContractTrieID(felt.Zero), &trieDB) + tmpTrieRoot := contractTrie2.Hash() + trieRoot = &tmpTrieRoot + + // recreate because the previous ones are committed + classTrie2, err := trie2.New(trieutils.NewClassTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + require.NoError(t, err) + contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + require.NoError(t, err) + classTrie = commontrie.NewTrie2Adapter(classTrie2) + contractTrie = commontrie.NewTrie2Adapter(contractTrie2) + } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -178,8 +216,8 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() - mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() - mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(classTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(contractTrie, nil).AnyTimes() log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, log) @@ -275,14 +313,14 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, nil, nil) require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key, *key2}, nil, nil) @@ -295,12 +333,12 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { - mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(1) - mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(0) + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) // TODO(maksym): after integration change to state.ErrContractNotDeployed + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) // TODO(maksym): after integration change to state.ErrContractNotDeployed proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*noSuchKey}, nil) require.Nil(t, rpcErr) @@ -308,13 +346,13 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) - mockState.EXPECT().ContractNonce(key).Return(nonce, nil).Times(1) - classHasah := new(felt.Felt).SetUint64(1234) - mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil).Times(1) + mockState.EXPECT().ContractNonce(key).Return(*nonce, nil).Times(1) + classHash := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil).Times(1) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -324,13 +362,13 @@ func TestStorageProof(t *testing.T) { require.NotNil(t, proof.ContractsProof.LeavesData[0]) ld := proof.ContractsProof.LeavesData[0] require.Equal(t, nonce, ld.Nonce) - require.Equal(t, classHasah, ld.ClassHash) + require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, contractTrie.HashFn()) }) t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xdead") - mockState.EXPECT().ContractStorageTrie(contract).Return(emptyTrie(t), nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(emptyCommonTrie(t), nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -342,7 +380,7 @@ func TestStorageProof(t *testing.T) { //nolint:dupl t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xabcd") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -351,12 +389,12 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xadd0") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -365,13 +403,13 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) - mockState.EXPECT().ContractNonce(key).Return(nonce, nil) - classHasah := new(felt.Felt).SetUint64(1234) - mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) + mockState.EXPECT().ContractNonce(key).Return(*nonce, nil) + classHash := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil) proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -786,6 +824,16 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } +func emptyCommonTrie(t *testing.T) commontrie.Trie { + if statetestutils.UseNewState() { + tempTrie, err := trie2.NewEmptyPedersen() + require.NoError(t, err) + return commontrie.NewTrie2Adapter(tempTrie) + } else { + return commontrie.NewTrieAdapter(emptyTrie(t)) + } +} + func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) if classRoot.IsZero() { diff --git a/rpc/v9/simulation_pkg_test.go b/rpc/v9/simulation_pkg_test.go index c68a64a0e7..84affce3ab 100644 --- a/rpc/v9/simulation_pkg_test.go +++ b/rpc/v9/simulation_pkg_test.go @@ -3,12 +3,10 @@ package rpcv9 import ( "encoding/json" "errors" - "os" "testing" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -16,11 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestCreateSimulatedTransactions(t *testing.T) { executionResults := vm.ExecutionResults{ OverallFees: []*felt.Felt{new(felt.Felt).SetUint64(10), new(felt.Felt).SetUint64(20)}, diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index d3a7fc317b..b3eea096a9 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -11,8 +11,12 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commontrie" statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -182,11 +186,45 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - tempTrie := emptyTrie(t) - _, _ = tempTrie.Put(key, value) - _, _ = tempTrie.Put(key2, value2) - _ = tempTrie.Commit() - trieRoot, _ := tempTrie.Root() + var classTrie, contractTrie commontrie.Trie + trieRoot := &felt.Zero + + if !statetestutils.UseNewState() { + tempTrie := emptyTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ = tempTrie.Root() + classTrie = commontrie.NewTrieAdapter(tempTrie) + contractTrie = commontrie.NewTrieAdapter(tempTrie) + } else { + newComm := new(felt.Felt).SetUint64(1) + createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { + tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) + _ = tr.Update(key, value) + _ = tr.Update(key2, value2) + require.NoError(t, err) + _, nodes := tr.Commit() + err = trieDB.Update(newComm, &felt.Zero, trienode.NewMergeNodeSet(nodes)) + require.NoError(t, err) + return tr + } + + // TODO(weiihann): should have a better way of testing + trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) + createTrie(t, trieutils.NewClassTrieID(felt.Zero), &trieDB) + contractTrie2 := createTrie(t, trieutils.NewContractTrieID(felt.Zero), &trieDB) + tmpTrieRoot := contractTrie2.Hash() + trieRoot = &tmpTrieRoot + + // recreate because the previous ones are committed + classTrie2, err := trie2.New(trieutils.NewClassTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + require.NoError(t, err) + contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + require.NoError(t, err) + classTrie = commontrie.NewTrie2Adapter(classTrie2) + contractTrie = commontrie.NewTrie2Adapter(contractTrie2) + } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -196,8 +234,8 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() - mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() - mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(classTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(contractTrie, nil).AnyTimes() log := utils.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, log) @@ -293,14 +331,14 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, nil, nil) require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key, *key2}, nil, nil) @@ -313,12 +351,12 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { - mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(1) - mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(0) + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) // TODO(maksym): after integration change to state.ErrContractNotDeployed + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) // TODO(maksym): after integration change to state.ErrContractNotDeployed proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*noSuchKey}, nil) require.Nil(t, rpcErr) @@ -326,13 +364,13 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) - mockState.EXPECT().ContractNonce(key).Return(nonce, nil).Times(1) - classHasah := new(felt.Felt).SetUint64(1234) - mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil).Times(1) + mockState.EXPECT().ContractNonce(key).Return(*nonce, nil).Times(1) + classHash := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil).Times(1) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -342,13 +380,13 @@ func TestStorageProof(t *testing.T) { require.NotNil(t, proof.ContractsProof.LeavesData[0]) ld := proof.ContractsProof.LeavesData[0] require.Equal(t, nonce, ld.Nonce) - require.Equal(t, classHasah, ld.ClassHash) + require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, contractTrie.HashFn()) }) t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xdead") - mockState.EXPECT().ContractStorageTrie(contract).Return(emptyTrie(t), nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(emptyCommonTrie(t), nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -360,7 +398,7 @@ func TestStorageProof(t *testing.T) { //nolint:dupl t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xabcd") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -369,12 +407,12 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { contract := utils.HexToFelt(t, "0xadd0") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -383,13 +421,13 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) - mockState.EXPECT().ContractNonce(key).Return(nonce, nil) - classHasah := new(felt.Felt).SetUint64(1234) - mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) + mockState.EXPECT().ContractNonce(key).Return(*nonce, nil) + classHash := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil) proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -804,6 +842,16 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } +func emptyCommonTrie(t *testing.T) commontrie.Trie { + if statetestutils.UseNewState() { + tempTrie, err := trie2.NewEmptyPedersen() + require.NoError(t, err) + return commontrie.NewTrie2Adapter(tempTrie) + } else { + return commontrie.NewTrieAdapter(emptyTrie(t)) + } +} + func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) if classRoot.IsZero() { diff --git a/sequencer/sequencer_test.go b/sequencer/sequencer_test.go index 3521c0a121..bdc8d7955b 100644 --- a/sequencer/sequencer_test.go +++ b/sequencer/sequencer_test.go @@ -3,7 +3,6 @@ package sequencer_test import ( "context" "crypto/rand" - "os" "testing" "time" @@ -25,11 +24,6 @@ import ( "go.uber.org/mock/gomock" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func getEmptySequencer(t *testing.T, blockTime time.Duration, seqAddr *felt.Felt) (sequencer.Sequencer, *blockchain.Blockchain) { t.Helper() testDB := memory.New() diff --git a/sync/sync_test.go b/sync/sync_test.go index cf177d5051..5aefe9bc7d 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -3,7 +3,6 @@ package sync_test import ( "context" "errors" - "os" "sync/atomic" "testing" "time" @@ -25,11 +24,6 @@ import ( const timeout = time.Second -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestSyncBlocks(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) diff --git a/vm/vm_test.go b/vm/vm_test.go index d7927b2374..3604ab99e1 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -19,11 +19,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - statetestutils.Parse() - os.Exit(m.Run()) -} - func TestCallDeprecatedCairo(t *testing.T) { testDB := memory.New() txn := testDB.NewIndexedBatch() From 636275399ccca5c65d1f3e1a567b402268e73d9a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 5 Aug 2025 11:44:34 +0200 Subject: [PATCH 07/47] add failing tests, to be fixed in a moment --- core/state/commonstate/state_test.go | 224 ++++++++++++++++++ core/state/commontrie/trie_test.go | 76 ++++++ .../state_test_utils/new_state_flag_test.go | 26 ++ 3 files changed, 326 insertions(+) create mode 100644 core/state/commonstate/state_test.go create mode 100644 core/state/commontrie/trie_test.go create mode 100644 core/state/state_test_utils/new_state_flag_test.go diff --git a/core/state/commonstate/state_test.go b/core/state/commonstate/state_test.go new file mode 100644 index 0000000000..5c2585f59e --- /dev/null +++ b/core/state/commonstate/state_test.go @@ -0,0 +1,224 @@ +package commonstate_test + +import ( + "testing" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/trie2/triedb" + "github.com/NethermindEth/juno/db/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCoreStateAdapter_ContractStorageAt(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + key := &felt.Felt{} + blockNumber := uint64(0) + + value, err := coreStateAdapter.ContractStorageAt(addr, key, blockNumber) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func TestCoreStateAdapter_ContractNonceAt(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + blockNumber := uint64(0) + + nonce, err := coreStateAdapter.ContractNonceAt(addr, blockNumber) + require.NoError(t, err) + assert.Equal(t, felt.Zero, nonce) +} + +func TestCoreStateAdapter_ContractClassHashAt(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + blockNumber := uint64(0) + + classHash, err := coreStateAdapter.ContractClassHashAt(addr, blockNumber) + require.NoError(t, err) + assert.Equal(t, felt.Zero, classHash) +} + +func TestCoreStateAdapter_ContractDeployedAt(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + blockNumber := uint64(0) + + deployed, err := coreStateAdapter.ContractDeployedAt(addr, blockNumber) + require.NoError(t, err) + assert.False(t, deployed) +} + +func TestCoreStateAdapter_ContractClassHash(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + + classHash, err := coreStateAdapter.ContractClassHash(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, classHash) +} + +func TestCoreStateAdapter_ContractNonce(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + + nonce, err := coreStateAdapter.ContractNonce(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, nonce) +} + +func TestCoreStateAdapter_ContractStorage(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + key := &felt.Felt{} + + value, err := coreStateAdapter.ContractStorage(addr, key) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func TestCoreStateAdapter_Class(t *testing.T) { + coreStateAdapter := setupCoreStateAdapter(t) + classHash := &felt.Felt{} + + class, err := coreStateAdapter.Class(classHash) + require.NoError(t, err) + assert.Nil(t, class) +} + +func TestStateAdapter_ClassTrie(t *testing.T) { + stateAdapter := setupStateAdapter(t) + + trie, err := stateAdapter.ClassTrie() + require.NoError(t, err) + assert.Nil(t, trie) +} + +func TestStateAdapter_ContractTrie(t *testing.T) { + stateAdapter := setupStateAdapter(t) + + trie, err := stateAdapter.ContractTrie() + require.NoError(t, err) + assert.Nil(t, trie) +} + +func TestStateAdapter_ContractStorageTrie(t *testing.T) { + stateAdapter := setupStateAdapter(t) + addr := &felt.Felt{} + + trie, err := stateAdapter.ContractStorageTrie(addr) + require.NoError(t, err) + assert.Nil(t, trie) +} + +func TestCoreStateReaderAdapter_Class(t *testing.T) { + coreStateReaderAdapter := setupStateAdapter(t) + classHash := &felt.Felt{} + + class, err := coreStateReaderAdapter.Class(classHash) + require.NoError(t, err) + assert.Nil(t, class) +} + +func TestCoreStateReaderAdapter_ContractClassHash(t *testing.T) { + coreStateReaderAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + + classHash, err := coreStateReaderAdapter.ContractClassHash(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, classHash) +} + +func TestCoreStateReaderAdapter_ContractNonce(t *testing.T) { + coreStateReaderAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + + nonce, err := coreStateReaderAdapter.ContractNonce(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, nonce) +} + +func TestCoreStateReaderAdapter_ContractStorage(t *testing.T) { + coreStateReaderAdapter := setupCoreStateAdapter(t) + addr := &felt.Felt{} + key := &felt.Felt{} + + value, err := coreStateReaderAdapter.ContractStorage(addr, key) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func TestStateReaderAdapter_Class(t *testing.T) { + stateReaderAdapter := setupStateAdapter(t) + classHash := &felt.Felt{} + + class, err := stateReaderAdapter.Class(classHash) + require.NoError(t, err) + assert.Nil(t, class) +} + +func TestStateReaderAdapter_ContractClassHash(t *testing.T) { + stateReaderAdapter := setupStateAdapter(t) + addr := &felt.Felt{} + + classHash, err := stateReaderAdapter.ContractClassHash(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, classHash) +} + +func TestStateReaderAdapter_ContractNonce(t *testing.T) { + stateReaderAdapter := setupStateAdapter(t) + addr := &felt.Felt{} + + nonce, err := stateReaderAdapter.ContractNonce(addr) + require.NoError(t, err) + assert.Equal(t, felt.Zero, nonce) +} + +func TestStateReaderAdapter_ContractStorage(t *testing.T) { + stateReaderAdapter := setupStateAdapter() + addr := &felt.Felt{} + key := &felt.Felt{} + + value, err := stateReaderAdapter.ContractStorage(addr, key) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func setupCoreStateAdapter(t *testing.T) *commonstate.CoreStateAdapter { + state := setupCoreState() + return commonstate.NewCoreStateAdapter(state) +} + +func setupCoreState() *core.State { + testDB := memory.New() + txn := testDB.NewIndexedBatch() + + return core.NewState(txn) +} + +func setupStateAdapter() *commonstate.StateAdapter { + state := setupState([]*core.StateUpdate{}, 0) + return commonstate.NewStateAdapter(state) +} + +func setupState(t *testing.T, stateUpdates []*core.StateUpdate, blocks uint64) *state.State { + stateDB := newTestStateDB() + state, err := state.New(&felt.Zero, stateDB) + require.NoError(t, err) + return state +} + +func newTestStateDB() *state.StateDB { + memDB := memory.New() + db, err := triedb.New(memDB, nil) + if err != nil { + panic(err) + } + return state.NewStateDB(memDB, db) +} diff --git a/core/state/commontrie/trie_test.go b/core/state/commontrie/trie_test.go new file mode 100644 index 0000000000..57842f4a56 --- /dev/null +++ b/core/state/commontrie/trie_test.go @@ -0,0 +1,76 @@ +package commontrie_test + +import ( + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commontrie" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrieAdapter_Update(t *testing.T) { + trie := &commontrie.TrieAdapter{} + key := &felt.Felt{} + value := &felt.Felt{} + + err := trie.Update(key, value) + require.NoError(t, err) +} + +func TestTrieAdapter_Get(t *testing.T) { + trie := &commontrie.TrieAdapter{} + key := &felt.Felt{} + + value, err := trie.Get(key) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func TestTrieAdapter_Hash(t *testing.T) { + trie := &commontrie.TrieAdapter{} + + hash, err := trie.Hash() + require.NoError(t, err) + assert.Equal(t, felt.Zero, hash) +} + +func TestTrieAdapter_HashFn(t *testing.T) { + trie := &commontrie.TrieAdapter{} + + hashFn := trie.HashFn() + assert.NotNil(t, hashFn) +} + +func TestTrie2Adapter_Update(t *testing.T) { + trie := &commontrie.Trie2Adapter{} + key := &felt.Felt{} + value := &felt.Felt{} + + err := trie.Update(key, value) + require.NoError(t, err) +} + +func TestTrie2Adapter_Get(t *testing.T) { + trie := &commontrie.Trie2Adapter{} + key := &felt.Felt{} + + value, err := trie.Get(key) + require.NoError(t, err) + assert.Equal(t, felt.Zero, value) +} + +func TestTrie2Adapter_Hash(t *testing.T) { + trie := &commontrie.Trie2Adapter{} + + hash, err := trie.Hash() + require.NoError(t, err) + assert.Equal(t, felt.Zero, hash) +} + +func TestTrie2Adapter_HashFn(t *testing.T) { + trie := &commontrie.Trie2Adapter{} + + hashFn := trie.HashFn() + assert.NotNil(t, hashFn) +} diff --git a/core/state/state_test_utils/new_state_flag_test.go b/core/state/state_test_utils/new_state_flag_test.go new file mode 100644 index 0000000000..809e4c932b --- /dev/null +++ b/core/state/state_test_utils/new_state_flag_test.go @@ -0,0 +1,26 @@ +package statetestutils_test + +import ( + "os" + "testing" + + statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + "github.com/stretchr/testify/assert" +) + +func TestUseNewState(t *testing.T) { + t.Run("default false", func(t *testing.T) { + os.Unsetenv("USE_NEW_STATE") + assert.False(t, statetestutils.UseNewState()) + }) + + t.Run("env true", func(t *testing.T) { + os.Setenv("USE_NEW_STATE", "true") + assert.True(t, statetestutils.UseNewState()) + }) + + t.Run("env false", func(t *testing.T) { + os.Setenv("USE_NEW_STATE", "false") + assert.False(t, statetestutils.UseNewState()) + }) +} From 1faaf815dfaef9acc877d96bdfca67f13bf8e592 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 5 Aug 2025 15:39:08 +0200 Subject: [PATCH 08/47] small fix for failing unit tests --- Makefile | 3 --- node/metrics_test.go | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 00fe60c694..5d0ba863f3 100644 --- a/Makefile +++ b/Makefile @@ -73,9 +73,6 @@ clean-testcache: ## Clean Go test cache test: clean-testcache rustdeps ## Run tests go test $(GO_TAGS) -v ./... -test: clean-testcache rustdeps ## Run tests - go test $(GO_TAGS) -v ./... - test-new-state: clean-testcache rustdeps ## Run tests with new state USE_NEW_STATE=true go test $(GO_TAGS) -v ./... diff --git a/node/metrics_test.go b/node/metrics_test.go index b5c01c41a1..b65bc3bb9a 100644 --- a/node/metrics_test.go +++ b/node/metrics_test.go @@ -45,7 +45,7 @@ func TestMakeL1Metrics(t *testing.T) { reg := prometheus.NewRegistry() prometheus.DefaultRegisterer = reg - head := &core.L1Head{BlockNumber: 42} + head := core.L1Head{BlockNumber: 42} mockBCReader.EXPECT().L1Head().Return(head, nil).AnyTimes() mockSubscriber.EXPECT().FinalisedHeight(gomock.Any()).Return(uint64(100), nil).AnyTimes() mockSubscriber.EXPECT().LatestHeight(gomock.Any()).Return(uint64(101), nil).AnyTimes() @@ -69,7 +69,7 @@ func TestMakeL1Metrics(t *testing.T) { reg := prometheus.NewRegistry() prometheus.DefaultRegisterer = reg - mockBCReader.EXPECT().L1Head().Return(nil, errors.New("err")).AnyTimes() + mockBCReader.EXPECT().L1Head().Return(core.L1Head{}, errors.New("err")).AnyTimes() mockSubscriber.EXPECT().FinalisedHeight(gomock.Any()).Return(uint64(0), errors.New("err")).AnyTimes() mockSubscriber.EXPECT().LatestHeight(gomock.Any()).Return(uint64(0), errors.New("err")).AnyTimes() From 595fdc24ed4081e13562d3bc070d794733de711a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 19:38:01 +0200 Subject: [PATCH 09/47] address comments, linter --- blockchain/blockchain.go | 10 ++--- .../state_test_utils/new_state_flag_test.go | 4 +- rpc/v8/storage.go | 43 +++++++++---------- rpc/v8/storage_test.go | 12 +++--- rpc/v9/storage.go | 42 +++++++++--------- rpc/v9/storage_test.go | 12 +++--- 6 files changed, 61 insertions(+), 62 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 977c632530..a2f1689889 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -90,7 +90,7 @@ type Blockchain struct { network *utils.Network database db.KeyValueStore trieDB *triedb.Database - StateDB *state.StateDB // TODO(weiihann): not sure if it's a good idea to expose this + stateDB *state.StateDB // TODO(weiihann): not sure if it's a good idea to expose this listener EventListener l1HeadFeed *feed.Feed[*core.L1Head] cachedFilters *AggregatedBloomFilterCache @@ -121,7 +121,7 @@ func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) * return &Blockchain{ database: database, trieDB: trieDB, - StateDB: stateDB, + stateDB: stateDB, network: network, listener: &SelectiveListener{}, l1HeadFeed: feed.New[*core.L1Head](), @@ -215,7 +215,7 @@ func (b *Blockchain) StateUpdateByHash(hash *felt.Felt) (*core.StateUpdate, erro func (b *Blockchain) L1HandlerTxnHash(msgHash *common.Hash) (felt.Felt, error) { b.listener.OnRead("L1HandlerTxnHash") - return core.GetL1HandlerTxnHashByMsgHash(b.database, msgHash.Bytes()) // TODO: return felt value + return core.GetL1HandlerTxnHashByMsgHash(b.database, msgHash.Bytes()) } // TransactionByBlockNumberAndIndex gets the transaction for a given block number and index. @@ -554,7 +554,7 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) { if err != nil { return ret, err } - state, err := state.New(stateUpdate.NewRoot, b.StateDB) + state, err := state.New(stateUpdate.NewRoot, b.stateDB) if err != nil { return ret, err } @@ -635,7 +635,7 @@ func (b *Blockchain) revertHead() error { if err != nil { return err } - state, err := state.New(stateUpdate.NewRoot, b.StateDB) + state, err := state.New(stateUpdate.NewRoot, b.stateDB) if err != nil { return err } diff --git a/core/state/state_test_utils/new_state_flag_test.go b/core/state/state_test_utils/new_state_flag_test.go index 809e4c932b..a61d45718a 100644 --- a/core/state/state_test_utils/new_state_flag_test.go +++ b/core/state/state_test_utils/new_state_flag_test.go @@ -15,12 +15,12 @@ func TestUseNewState(t *testing.T) { }) t.Run("env true", func(t *testing.T) { - os.Setenv("USE_NEW_STATE", "true") + t.Setenv("USE_NEW_STATE", "true") assert.True(t, statetestutils.UseNewState()) }) t.Run("env false", func(t *testing.T) { - os.Setenv("USE_NEW_STATE", "false") + t.Setenv("USE_NEW_STATE", "false") assert.False(t, statetestutils.UseNewState()) }) } diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index ee7310d3e5..88f2012ea8 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -213,22 +213,22 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { switch t := tr.(type) { - case *commontrie.TrieAdapter: + case *commontrie.DeprecatedTrieAdapter: classProof := trie.NewProofNodeSet() for _, class := range classes { - if err := t.Trie.Prove(&class, classProof); err != nil { + if err := (*trie.Trie)(t).Prove(&class, classProof); err != nil { return nil, err } } - return adaptTrie1ProofNodes(classProof), nil - case *commontrie.Trie2Adapter: + return adaptDeprecatedTrieProofNodes(classProof), nil + case *commontrie.TrieAdapter: classProof := trie2.NewProofNodeSet() for _, class := range classes { - if err := t.Trie.Prove(&class, classProof); err != nil { + if err := (*trie2.Trie)(t).Prove(&class, classProof); err != nil { return nil, err } } - return adaptTrie2ProofNodes(classProof), nil + return adaptTrieProofNodes(classProof), nil default: return nil, fmt.Errorf("unknown trie type: %T", tr) } @@ -236,16 +236,16 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { switch t := tr.(type) { + case *commontrie.DeprecatedTrieAdapter: + return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) case *commontrie.TrieAdapter: - return getContractProofWithTrie(t.Trie, state, contracts) - case *commontrie.Trie2Adapter: - return getContractProofWithTrie2(t.Trie, state, contracts) + return getContractProofWithTrie((*trie2.Trie)(t), state, contracts) default: return nil, fmt.Errorf("unknown trie type: %T", tr) } } -func getContractProofWithTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -280,12 +280,12 @@ func getContractProofWithTrie(tr *trie.Trie, state commonstate.StateReader, cont } return &ContractProof{ - Nodes: adaptTrie1ProofNodes(contractProof), + Nodes: adaptDeprecatedTrieProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } -func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -317,7 +317,7 @@ func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, co } return &ContractProof{ - Nodes: adaptTrie2ProofNodes(contractProof), + Nodes: adaptTrieProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } @@ -331,22 +331,22 @@ func getContractStorageProof(state commonstate.StateReader, storageKeys []Storag } switch t := contractStorageTrie.(type) { - case *commontrie.TrieAdapter: + case *commontrie.DeprecatedTrieAdapter: contractStorageProof := trie.NewProofNodeSet() for _, key := range storageKey.Keys { - if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + if err := (*trie.Trie)(t).Prove(&key, contractStorageProof); err != nil { return nil, err } } - contractStorageRes[i] = adaptTrie1ProofNodes(contractStorageProof) - case *commontrie.Trie2Adapter: + contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) + case *commontrie.TrieAdapter: contractStorageProof := trie2.NewProofNodeSet() for _, key := range storageKey.Keys { - if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + if err := (*trie2.Trie)(t).Prove(&key, contractStorageProof); err != nil { return nil, err } } - contractStorageRes[i] = adaptTrie2ProofNodes(contractStorageProof) + contractStorageRes[i] = adaptTrieProofNodes(contractStorageProof) default: return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } @@ -355,7 +355,7 @@ func getContractStorageProof(state commonstate.StateReader, storageKeys []Storag return contractStorageRes, nil } -func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptDeprecatedTrieProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -385,7 +385,7 @@ func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } -func adaptTrie2ProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { +func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -467,7 +467,6 @@ type ContractProof struct { Nodes []*HashToNode `json:"nodes"` LeavesData []*LeafData `json:"contract_leaves_data"` } - type GlobalRoots struct { ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 77e5661fe5..474a7fb370 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -177,8 +177,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = commontrie.NewTrieAdapter(tempTrie) - contractTrie = commontrie.NewTrieAdapter(tempTrie) + classTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) + contractTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { @@ -204,8 +204,8 @@ func TestStorageProof(t *testing.T) { require.NoError(t, err) contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) require.NoError(t, err) - classTrie = commontrie.NewTrie2Adapter(classTrie2) - contractTrie = commontrie.NewTrie2Adapter(contractTrie2) + classTrie = commontrie.NewTrieAdapter(classTrie2) + contractTrie = commontrie.NewTrieAdapter(contractTrie2) } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -828,9 +828,9 @@ func emptyCommonTrie(t *testing.T) commontrie.Trie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return commontrie.NewTrie2Adapter(tempTrie) + return commontrie.NewTrieAdapter(tempTrie) } else { - return commontrie.NewTrieAdapter(emptyTrie(t)) + return commontrie.NewDeprecatedTrieAdapter(emptyTrie(t)) } } diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 76a8a4d339..42247fbf9b 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -211,22 +211,22 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { switch t := tr.(type) { - case *commontrie.TrieAdapter: + case *commontrie.DeprecatedTrieAdapter: classProof := trie.NewProofNodeSet() for _, class := range classes { - if err := t.Trie.Prove(&class, classProof); err != nil { + if err := (*trie.Trie)(t).Prove(&class, classProof); err != nil { return nil, err } } - return adaptTrie1ProofNodes(classProof), nil - case *commontrie.Trie2Adapter: + return adaptDeprecatedTrieProofNodes(classProof), nil + case *commontrie.TrieAdapter: classProof := trie2.NewProofNodeSet() for _, class := range classes { - if err := t.Trie.Prove(&class, classProof); err != nil { + if err := (*trie2.Trie)(t).Prove(&class, classProof); err != nil { return nil, err } } - return adaptTrie2ProofNodes(classProof), nil + return adaptTrieProofNodes(classProof), nil default: return nil, fmt.Errorf("unknown trie type: %T", tr) } @@ -234,16 +234,16 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { switch t := tr.(type) { + case *commontrie.DeprecatedTrieAdapter: + return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) case *commontrie.TrieAdapter: - return getContractProofWithTrie1(t.Trie, state, contracts) - case *commontrie.Trie2Adapter: - return getContractProofWithTrie2(t.Trie, state, contracts) + return getContractProofWithTrie((*trie2.Trie)(t), state, contracts) default: return nil, fmt.Errorf("unknown trie type: %T", tr) } } -func getContractProofWithTrie1(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -278,12 +278,12 @@ func getContractProofWithTrie1(tr *trie.Trie, state commonstate.StateReader, con } return &ContractProof{ - Nodes: adaptTrie1ProofNodes(contractProof), + Nodes: adaptDeprecatedTrieProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } -func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -315,7 +315,7 @@ func getContractProofWithTrie2(tr *trie2.Trie, state commonstate.StateReader, co } return &ContractProof{ - Nodes: adaptTrie2ProofNodes(contractProof), + Nodes: adaptTrieProofNodes(contractProof), LeavesData: contractLeavesData, }, nil } @@ -329,22 +329,22 @@ func getContractStorageProof(state commonstate.StateReader, storageKeys []Storag } switch t := contractStorageTrie.(type) { - case *commontrie.TrieAdapter: + case *commontrie.DeprecatedTrieAdapter: contractStorageProof := trie.NewProofNodeSet() for _, key := range storageKey.Keys { - if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + if err := (*trie.Trie)(t).Prove(&key, contractStorageProof); err != nil { return nil, err } } - contractStorageRes[i] = adaptTrie1ProofNodes(contractStorageProof) - case *commontrie.Trie2Adapter: + contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) + case *commontrie.TrieAdapter: contractStorageProof := trie2.NewProofNodeSet() for _, key := range storageKey.Keys { - if err := t.Trie.Prove(&key, contractStorageProof); err != nil { + if err := (*trie2.Trie)(t).Prove(&key, contractStorageProof); err != nil { return nil, err } } - contractStorageRes[i] = adaptTrie2ProofNodes(contractStorageProof) + contractStorageRes[i] = adaptTrieProofNodes(contractStorageProof) default: return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } @@ -353,7 +353,7 @@ func getContractStorageProof(state commonstate.StateReader, storageKeys []Storag return contractStorageRes, nil } -func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptDeprecatedTrieProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -383,7 +383,7 @@ func adaptTrie1ProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } -func adaptTrie2ProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { +func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index b3eea096a9..45622bf944 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -195,8 +195,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = commontrie.NewTrieAdapter(tempTrie) - contractTrie = commontrie.NewTrieAdapter(tempTrie) + classTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) + contractTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { @@ -222,8 +222,8 @@ func TestStorageProof(t *testing.T) { require.NoError(t, err) contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) require.NoError(t, err) - classTrie = commontrie.NewTrie2Adapter(classTrie2) - contractTrie = commontrie.NewTrie2Adapter(contractTrie2) + classTrie = commontrie.NewTrieAdapter(classTrie2) + contractTrie = commontrie.NewTrieAdapter(contractTrie2) } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -846,9 +846,9 @@ func emptyCommonTrie(t *testing.T) commontrie.Trie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return commontrie.NewTrie2Adapter(tempTrie) + return commontrie.NewTrieAdapter(tempTrie) } else { - return commontrie.NewTrieAdapter(emptyTrie(t)) + return commontrie.NewDeprecatedTrieAdapter(emptyTrie(t)) } } From 40ebb7c4b1a533fdac7aac6d7a0107144adcc414 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 19:38:28 +0200 Subject: [PATCH 10/47] remove todo --- blockchain/blockchain.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index a2f1689889..c12b98f841 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -90,7 +90,7 @@ type Blockchain struct { network *utils.Network database db.KeyValueStore trieDB *triedb.Database - stateDB *state.StateDB // TODO(weiihann): not sure if it's a good idea to expose this + stateDB *state.StateDB listener EventListener l1HeadFeed *feed.Feed[*core.L1Head] cachedFilters *AggregatedBloomFilterCache From 5b3a1b8a30925629fd63d37a75770666d7a9c46c Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 19:40:00 +0200 Subject: [PATCH 11/47] fixes in the makefile --- Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 5d0ba863f3..531b792fd0 100644 --- a/Makefile +++ b/Makefile @@ -71,13 +71,13 @@ clean-testcache: ## Clean Go test cache go clean -testcache test: clean-testcache rustdeps ## Run tests - go test $(GO_TAGS) -v ./... + go test $(GO_TAGS) ./... test-new-state: clean-testcache rustdeps ## Run tests with new state - USE_NEW_STATE=true go test $(GO_TAGS) -v ./... + USE_NEW_STATE=true go test $(GO_TAGS) ./... test-cached: rustdeps ## Run cached tests - go test $(GO_TAGS) ./... -args + go test $(GO_TAGS) ./... test-race: clean-testcache rustdeps ## Run tests with race detection go test $(GO_TAGS) ./... -race $(TEST_RACE_LDFLAGS) From 81bc7a96458ef923c10025f17ae75a65fbf90654 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 22:03:32 +0200 Subject: [PATCH 12/47] unit tests fixes --- rpc/v8/storage.go | 31 +++++++++++++++++++++---------- rpc/v8/storage_test.go | 13 ++++++++++--- rpc/v9/storage.go | 25 +++++++++++++++++++------ rpc/v9/storage_test.go | 10 ++++++++-- vm/vm_test.go | 18 ++++++++++++++++-- 5 files changed, 74 insertions(+), 23 deletions(-) diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 88f2012ea8..8a689c217e 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -285,26 +285,26 @@ func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateRe }, nil } -func getContractProofWithTrie(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) - for i, contract := range contracts { if err := tr.Prove(&contract, contractProof); err != nil { return nil, err } + fmt.Println("contractProof", *contractProof) root := tr.Hash() - nonce, err := state.ContractNonce(&contract) + nonce, err := st.ContractNonce(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + if errors.Is(err, state.ErrContractNotDeployed) { // contract does not exist, skip getting leaf data continue } return nil, err } - classHash, err := state.ContractClassHash(&contract) + classHash, err := st.ContractClassHash(&contract) if err != nil { return nil, err } @@ -394,15 +394,15 @@ func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { switch n := nodeList[i].(type) { case *trienode.BinaryNode: node = &BinaryNode{ - Left: &hash, - Right: &hash, + Left: nodeFelt(n.Children[0]), + Right: nodeFelt(n.Children[1]), } case *trienode.EdgeNode: - path := n.Path.Felt() + pathFelt := n.Path.Felt() node = &EdgeNode{ - Path: path.String(), + Path: pathFelt.String(), Length: int(n.Path.Len()), - Child: &hash, + Child: nodeFelt(n.Child), } } @@ -415,6 +415,17 @@ func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { return nodes } +func nodeFelt(n trienode.Node) *felt.Felt { + switch n := n.(type) { + case *trienode.HashNode: + return (*felt.Felt)(n) + case *trienode.ValueNode: + return (*felt.Felt)(n) + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 474a7fb370..23bff0809f 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -3,6 +3,7 @@ package rpcv8_test import ( "context" "errors" + "fmt" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/state/commontrie" statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" @@ -337,9 +339,13 @@ func TestStorageProof(t *testing.T) { verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { - mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) // TODO(maksym): after integration change to state.ErrContractNotDeployed - mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) // TODO(maksym): after integration change to state.ErrContractNotDeployed - + if statetestutils.UseNewState() { + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(0) + } else { + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) + } proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*noSuchKey}, nil) require.Nil(t, rpcErr) require.NotNil(t, proof) @@ -802,6 +808,7 @@ func verifyIf( proofSet := trie.NewProofNodeSet() for _, hn := range proof { + fmt.Println("hn", hn, "hash", hn.Hash, "node", hn.Node) proofSet.Put(*hn.Hash, hn.Node.AsProofNode()) } diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 42247fbf9b..e72f7bccc9 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie" @@ -40,7 +41,8 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(address) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } h.log.Errorw("Failed to get contract nonce", "err", err) @@ -392,15 +394,15 @@ func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { switch n := nodeList[i].(type) { case *trienode.BinaryNode: node = &BinaryNode{ - Left: &hash, - Right: &hash, + Left: nodeFelt(n.Children[0]), + Right: nodeFelt(n.Children[1]), } case *trienode.EdgeNode: - path := n.Path.Felt() + pathFelt := n.Path.Felt() node = &EdgeNode{ - Path: path.String(), + Path: pathFelt.String(), Length: int(n.Path.Len()), - Child: &hash, + Child: nodeFelt(n.Child), } } @@ -413,6 +415,17 @@ func adaptTrieProofNodes(proof *trie2.ProofNodeSet) []*HashToNode { return nodes } +func nodeFelt(n trienode.Node) *felt.Felt { + switch n := n.(type) { + case *trienode.HashNode: + return (*felt.Felt)(n) + case *trienode.ValueNode: + return (*felt.Felt)(n) + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 45622bf944..2691bfeed3 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/state/commontrie" statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/core/trie" @@ -355,8 +356,13 @@ func TestStorageProof(t *testing.T) { verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { - mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) // TODO(maksym): after integration change to state.ErrContractNotDeployed - mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) // TODO(maksym): after integration change to state.ErrContractNotDeployed + if statetestutils.UseNewState() { + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(0) + } else { + mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) + } proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*noSuchKey}, nil) require.Nil(t, rpcErr) diff --git a/vm/vm_test.go b/vm/vm_test.go index 3604ab99e1..a7889cb246 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -38,9 +38,10 @@ func TestCallDeprecatedCairo(t *testing.T) { require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) + newRoot := utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8") require.NoError(t, testState.Update(0, &core.StateUpdate{ OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), + NewRoot: newRoot, StateDiff: &core.StateDiff{ DeployedContracts: map[felt.Felt]*felt.Felt{ *contractAddr: classHash, @@ -60,6 +61,12 @@ func TestCallDeprecatedCairo(t *testing.T) { require.NoError(t, err) assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + // if new state, we need to create a new state with the new root + if statetestutils.UseNewState() { + testState, err = stateFactory.NewState(newRoot, txn) + require.NoError(t, err) + } + require.NoError(t, testState.Update(1, &core.StateUpdate{ OldRoot: utils.HexToFelt(t, "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), NewRoot: utils.HexToFelt(t, "0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), @@ -142,9 +149,10 @@ func TestCallCairo(t *testing.T) { require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) + newRoot := utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84") require.NoError(t, testState.Update(0, &core.StateUpdate{ OldRoot: &felt.Zero, - NewRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), + NewRoot: newRoot, StateDiff: &core.StateDiff{ DeployedContracts: map[felt.Felt]*felt.Felt{ *contractAddr: classHash, @@ -171,6 +179,12 @@ func TestCallCairo(t *testing.T) { require.NoError(t, err) assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + // if new state, we need to create a new state with the new root + if statetestutils.UseNewState() { + testState, err = stateFactory.NewState(newRoot, txn) + require.NoError(t, err) + } + require.NoError(t, testState.Update(1, &core.StateUpdate{ OldRoot: utils.HexToFelt(t, "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84"), NewRoot: utils.HexToFelt(t, "0x7a9da0a7471a8d5118d3eefb8c26a6acbe204eb1eaa934606f4757a595fe552"), From af861b6b7e00f4d56a00b9780ade7f1c3769f293 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 22:10:37 +0200 Subject: [PATCH 13/47] fixes for the unit tests, lint --- rpc/v8/storage.go | 1 - rpc/v9/storage.go | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 8a689c217e..60017f6150 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -292,7 +292,6 @@ func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contra if err := tr.Prove(&contract, contractProof); err != nil { return nil, err } - fmt.Println("contractProof", *contractProof) root := tr.Hash() diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index e72f7bccc9..c332c40a2a 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -285,7 +285,7 @@ func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateRe }, nil } -func getContractProofWithTrie(tr *trie2.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -296,15 +296,15 @@ func getContractProofWithTrie(tr *trie2.Trie, state commonstate.StateReader, con root := tr.Hash() - nonce, err := state.ContractNonce(&contract) + nonce, err := st.ContractNonce(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + if errors.Is(err, state.ErrContractNotDeployed) { // contract does not exist, skip getting leaf data continue } return nil, err } - classHash, err := state.ContractClassHash(&contract) + classHash, err := st.ContractClassHash(&contract) if err != nil { return nil, err } From 2b330723c5ae462170fdee4913fb0924c9c90cc7 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 6 Aug 2025 22:30:10 +0200 Subject: [PATCH 14/47] remove flag test --- .../state_test_utils/new_state_flag_test.go | 26 ------------------- 1 file changed, 26 deletions(-) delete mode 100644 core/state/state_test_utils/new_state_flag_test.go diff --git a/core/state/state_test_utils/new_state_flag_test.go b/core/state/state_test_utils/new_state_flag_test.go deleted file mode 100644 index a61d45718a..0000000000 --- a/core/state/state_test_utils/new_state_flag_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package statetestutils_test - -import ( - "os" - "testing" - - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" - "github.com/stretchr/testify/assert" -) - -func TestUseNewState(t *testing.T) { - t.Run("default false", func(t *testing.T) { - os.Unsetenv("USE_NEW_STATE") - assert.False(t, statetestutils.UseNewState()) - }) - - t.Run("env true", func(t *testing.T) { - t.Setenv("USE_NEW_STATE", "true") - assert.True(t, statetestutils.UseNewState()) - }) - - t.Run("env false", func(t *testing.T) { - t.Setenv("USE_NEW_STATE", "false") - assert.False(t, statetestutils.UseNewState()) - }) -} From a08c88e31d1eaafe69ddaf072a7f1f4f84fe4bcd Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 8 Aug 2025 18:56:16 +0200 Subject: [PATCH 15/47] unit tests pass!!!!! --- blockchain/blockchain.go | 14 ++++---- cmd/juno/dbcmd.go | 8 ++--- cmd/juno/dbcmd_test.go | 7 ++++ core/state.go | 1 + core/state/commonstate/state.go | 7 +++- core/state/history_test.go | 4 +-- core/state/state.go | 8 +++-- core/state/state_test.go | 54 ++++++++++++++--------------- core/state_test.go | 60 ++++++++++++++++----------------- rpc/v8/storage_test.go | 2 -- vm/vm.go | 8 ----- vm/vm_test.go | 12 +++---- 12 files changed, 95 insertions(+), 90 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index c12b98f841..a4eebc024e 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -290,7 +290,7 @@ func (b *Blockchain) storeCoreState( } state := core.NewState(txn) - if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { + if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil { return err } if err := core.WriteBlockHeader(txn, block.Header); err != nil { @@ -342,7 +342,7 @@ func (b *Blockchain) store( if err != nil { return err } - if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { + if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil { return err } batch := b.database.NewBatch() @@ -693,8 +693,7 @@ func (b *Blockchain) Simulate( // Simulate without commit txn := b.database.NewIndexedBatch() defer txn.Reset() - - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { + if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, false); err != nil { return SimulateResult{}, err } @@ -729,7 +728,7 @@ func (b *Blockchain) Finalise( ) error { if !b.StateFactory.UseNewState { err := b.database.Update(func(txn db.IndexedBatch) error { - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { + if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, true); err != nil { return err } commitments, err := b.updateBlockHash(block, stateUpdate) @@ -751,7 +750,7 @@ func (b *Blockchain) Finalise( return b.runningFilter.Insert(block.EventsBloom, block.Number) } else { batch := b.database.NewBatch() - if err := b.updateStateRoots(nil, block, stateUpdate, newClasses); err != nil { + if err := b.updateStateRoots(nil, block, stateUpdate, newClasses, true); err != nil { return err } commitments, err := b.updateBlockHash(block, stateUpdate) @@ -780,6 +779,7 @@ func (b *Blockchain) updateStateRoots( block *core.Block, stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.Class, + flushChanges bool, ) error { var height uint64 var err error @@ -808,7 +808,7 @@ func (b *Blockchain) updateStateRoots( stateUpdate.OldRoot = &oldStateRoot // Apply state update - if err = state.Update(block.Number, stateUpdate, newClasses, true); err != nil { + if err = state.Update(block.Number, stateUpdate, newClasses, true, flushChanges); err != nil { return err } diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index 3ff88c4695..35102b939f 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -83,7 +83,7 @@ func dbInfo(cmd *cobra.Command, args []string) error { return err } - stateVersion, err := cmd.Flags().GetBool(newStateF) + newState, err := cmd.Flags().GetBool(newStateF) if err != nil { return err } @@ -94,7 +94,7 @@ func dbInfo(cmd *cobra.Command, args []string) error { } defer database.Close() - chain := blockchain.New(database, nil, stateVersion) + chain := blockchain.New(database, nil, newState) var info DBInfo // Get the latest block information @@ -154,7 +154,7 @@ func dbRevert(cmd *cobra.Command, args []string) error { return fmt.Errorf("--%v cannot be 0", dbRevertToBlockF) } - stateVersion, err := cmd.Flags().GetBool(newStateF) + newState, err := cmd.Flags().GetBool(newStateF) if err != nil { return err } @@ -166,7 +166,7 @@ func dbRevert(cmd *cobra.Command, args []string) error { defer database.Close() for { - chain := blockchain.New(database, nil, stateVersion) + chain := blockchain.New(database, nil, newState) head, err := chain.Head() if err != nil { return fmt.Errorf("failed to get the latest block information: %v", err) diff --git a/cmd/juno/dbcmd_test.go b/cmd/juno/dbcmd_test.go index 07ef714fa2..fbef9038b4 100644 --- a/cmd/juno/dbcmd_test.go +++ b/cmd/juno/dbcmd_test.go @@ -22,6 +22,9 @@ var emptyCommitments = core.BlockCommitments{} func TestDBCmd(t *testing.T) { t.Run("retrieve info when db contains block0", func(t *testing.T) { cmd := juno.DBInfoCmd() + if statetestutils.UseNewState() { + require.NoError(t, cmd.Flags().Set("new-state", "true")) + } executeCmdInDB(t, cmd) }) @@ -45,6 +48,9 @@ func TestDBCmd(t *testing.T) { require.NoError(t, cmd.Flags().Set("db-path", dbPath)) require.NoError(t, cmd.Flags().Set("to-block", strconv.Itoa(int(revertToBlock)))) + if statetestutils.UseNewState() { + require.NoError(t, cmd.Flags().Set("new-state", "true")) + } require.NoError(t, cmd.Execute()) // unfortunately we cannot use blockchain from prepareDB because @@ -91,6 +97,7 @@ func prepareDB(t *testing.T, network *utils.Network, syncToBlock uint64) string require.NoError(t, chain.Store(block, &emptyCommitments, stateUpdate, nil)) } + require.NoError(t, chain.Stop()) require.NoError(t, testDB.Close()) return dbPath diff --git a/core/state.go b/core/state.go index 12d5690454..576b56370c 100644 --- a/core/state.go +++ b/core/state.go @@ -233,6 +233,7 @@ func (s *State) Update( update *StateUpdate, declaredClasses map[felt.Felt]Class, skipVerifyNewRoot bool, + flushChanges bool, // TODO(maksym): added to satisfy the interface, but not used ) error { err := s.verifyStateUpdateRoot(update.OldRoot) if err != nil { diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index c7707ee4fe..504c9bb497 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -19,7 +19,12 @@ type State interface { ContractClassHashAt(addr *felt.Felt, blockNumber uint64) (felt.Felt, error) ContractDeployedAt(addr *felt.Felt, blockNumber uint64) (bool, error) - Update(blockNum uint64, update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool) error + Update(blockNum uint64, + update *core.StateUpdate, + declaredClasses map[felt.Felt]core.Class, + skipVerifyNewRoot bool, + flushChanges bool, + ) error Revert(blockNum uint64, update *core.StateUpdate) error Commitment() (felt.Felt, error) } diff --git a/core/state/history_test.go b/core/state/history_test.go index bcfbff70b1..d16cb83b2d 100644 --- a/core/state/history_test.go +++ b/core/state/history_test.go @@ -126,7 +126,7 @@ func TestStateHistoryClassOperations(t *testing.T) { } state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - err = state.Update(0, stateUpdate, classes, false) + err = state.Update(0, stateUpdate, classes, false, true) require.NoError(t, err) stateComm, err := state.Commitment() require.NoError(t, err) @@ -142,7 +142,7 @@ func TestStateHistoryClassOperations(t *testing.T) { state, err = New(&stateComm, stateDB) require.NoError(t, err) - err = state.Update(1, stateUpdate, classes2, false) + err = state.Update(1, stateUpdate, classes2, false, true) require.NoError(t, err) historyBlock0, err := NewStateHistory(0, &felt.Zero, stateDB) diff --git a/core/state/state.go b/core/state/state.go index 2272868a39..307260b790 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -168,6 +168,7 @@ func (s *State) Update( update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool, + flushChanges bool, ) error { if err := s.verifyComm(update.OldRoot); err != nil { return err @@ -232,8 +233,10 @@ func (s *State) Update( deployedContracts: update.StateDiff.ReplacedClasses, }) - if err := s.flush(blockNum, &stateUpdate, dirtyClasses, true); err != nil { - return err + if flushChanges { + if err := s.flush(blockNum, &stateUpdate, dirtyClasses, true); err != nil { + return err + } } return nil @@ -568,7 +571,6 @@ func (s *State) verifyComm(comm *felt.Felt) error { if err != nil { return err } - if !curComm.Equal(comm) { return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", comm, &curComm) } diff --git a/core/state/state_test.go b/core/state/state_test.go index 8634769ecc..d5fd9f78e3 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -80,7 +80,7 @@ func TestUpdate(t *testing.T) { stateDB := setupState(t, stateUpdates, 0) state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - err = state.Update(block0, su, nil, false) + err = state.Update(block0, su, nil, false, true) require.Error(t, err) }) @@ -94,7 +94,7 @@ func TestUpdate(t *testing.T) { stateDB := setupState(t, stateUpdates, 0) state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - err = state.Update(block0, su, nil, false) + err = state.Update(block0, su, nil, false, true) require.Error(t, err) }) @@ -103,13 +103,13 @@ func TestUpdate(t *testing.T) { stateDB := setupState(t, stateUpdates, 3) state, err := New(stateUpdates[3].OldRoot, stateDB) require.NoError(t, err) - require.Error(t, state.Update(block3, stateUpdates[3], nil, false)) + require.Error(t, state.Update(block3, stateUpdates[3], nil, false, true)) }) t.Run("with class definition", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 3) state, err := New(stateUpdates[3].OldRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block3, su3, su3DeclaredClasses(), false)) + require.NoError(t, state.Update(block3, su3, su3DeclaredClasses(), false, true)) }) }) @@ -143,7 +143,7 @@ func TestUpdate(t *testing.T) { stateDB := setupState(t, stateUpdates, 5) state, err := New(stateUpdates[4].NewRoot, stateDB) require.NoError(t, err) - require.ErrorIs(t, state.Update(block5, su5, nil, false), ErrContractNotDeployed) + require.ErrorIs(t, state.Update(block5, su5, nil, false, true), ErrContractNotDeployed) }) } @@ -187,7 +187,7 @@ func TestNonce(t *testing.T) { stateDB := setupState(t, nil, 0) state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, state.Update(block0, su0, nil, false, true)) gotNonce, err := state.ContractNonce(addr) require.NoError(t, err) @@ -198,7 +198,7 @@ func TestNonce(t *testing.T) { stateDB := setupState(t, nil, 0) state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, state.Update(block0, su0, nil, false, true)) expectedNonce := new(felt.Felt).SetUint64(1) su1 := &core.StateUpdate{ @@ -211,7 +211,7 @@ func TestNonce(t *testing.T) { state1, err := New(su1.OldRoot, stateDB) require.NoError(t, err) - require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, state1.Update(block1, su1, nil, false, true)) gotNonce, err := state1.ContractNonce(addr) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestClass(t *testing.T) { require.NoError(t, state.Update(0, su0, map[felt.Felt]core.Class{ *cairo0Hash: cairo0Class, *cairo1Hash: cairo1Class, - }, false)) + }, false, true)) gotCairo1Class, err := state.Class(cairo1Hash) require.NoError(t, err) @@ -320,7 +320,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, replaceStateUpdate, nil, false)) + require.NoError(t, state.Update(block2, replaceStateUpdate, nil, false, true)) gotClassHash, err := state.ContractClassHash(&su1FirstDeployedAddress) require.NoError(t, err) @@ -351,7 +351,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, nonceStateUpdate, nil, false)) + require.NoError(t, state.Update(block2, nonceStateUpdate, nil, false, true)) gotNonce, err := state.ContractNonce(&su1FirstDeployedAddress) require.NoError(t, err) @@ -385,7 +385,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, storageStateUpdate, nil, false)) + require.NoError(t, state.Update(block2, storageStateUpdate, nil, false, true)) gotStorage, err := state.ContractStorage(&su1FirstDeployedAddress, replacedVal) require.NoError(t, err) assert.Equal(t, *replacedVal, gotStorage) @@ -448,7 +448,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, declaredClassesStateUpdate, classesM, false)) + require.NoError(t, state.Update(block2, declaredClassesStateUpdate, classesM, false, true)) state, err = New(declaredClassesStateUpdate.NewRoot, stateDB) require.NoError(t, err) @@ -469,7 +469,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, su2, nil, false)) + require.NoError(t, state.Update(block2, su2, nil, false, true)) state, err = New(su2.NewRoot, stateDB) require.NoError(t, err) @@ -477,7 +477,7 @@ func TestRevert(t *testing.T) { state, err = New(su1.NewRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block2, su2, nil, false)) + require.NoError(t, state.Update(block2, su2, nil, false, true)) }) t.Run("should be able to revert all the updates", func(t *testing.T) { @@ -513,7 +513,7 @@ func TestRevert(t *testing.T) { state, err := New(su1.OldRoot, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block1, &su1, nil, false)) + require.NoError(t, state.Update(block1, &su1, nil, false, true)) state, err = New(su1.NewRoot, stateDB) require.NoError(t, err) @@ -543,7 +543,7 @@ func TestRevert(t *testing.T) { state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block0, declareDiff, newClasses, false)) + require.NoError(t, state.Update(block0, declareDiff, newClasses, false, true)) declaredClass, err := state.Class(classHash) require.NoError(t, err) @@ -555,7 +555,7 @@ func TestRevert(t *testing.T) { state, err = New(declareDiff.NewRoot, stateDB) require.NoError(t, err) declareDiff.OldRoot = declareDiff.NewRoot - require.NoError(t, state.Update(block1, declareDiff, newClasses, false)) + require.NoError(t, state.Update(block1, declareDiff, newClasses, false, true)) // Redeclaring should not change the declared at block number declaredClass, err = state.Class(classHash) @@ -611,7 +611,7 @@ func TestRevert(t *testing.T) { state, err := New(&felt.Zero, stateDB) require.NoError(t, err) - require.NoError(t, state.Update(block0, su, nil, false)) + require.NoError(t, state.Update(block0, su, nil, false, true)) state, err = New(su.NewRoot, stateDB) require.NoError(t, err) @@ -695,7 +695,7 @@ func TestContractHistory(t *testing.T) { }, } - require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, state.Update(block0, su0, nil, false, true)) gotNonce, err := state.ContractNonceAt(addr, block0) require.NoError(t, err) @@ -715,12 +715,12 @@ func TestContractHistory(t *testing.T) { state0, err := New(&felt.Zero, stateDB) require.NoError(t, err) su0 := emptyStateUpdate - require.NoError(t, state0.Update(block0, su0, nil, false)) + require.NoError(t, state0.Update(block0, su0, nil, false, true)) state1, err := New(su0.NewRoot, stateDB) require.NoError(t, err) su1 := su - require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, state1.Update(block1, su1, nil, false, true)) gotNonce, err := state1.ContractNonceAt(addr, block0) require.NoError(t, err) @@ -740,7 +740,7 @@ func TestContractHistory(t *testing.T) { state0, err := New(&felt.Zero, stateDB) require.NoError(t, err) su0 := su - require.NoError(t, state0.Update(block0, su0, nil, false)) + require.NoError(t, state0.Update(block0, su0, nil, false, true)) state1, err := New(su0.NewRoot, stateDB) require.NoError(t, err) @@ -749,7 +749,7 @@ func TestContractHistory(t *testing.T) { NewRoot: su0.NewRoot, StateDiff: &core.StateDiff{}, } - require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, state1.Update(block1, su1, nil, false, true)) state2, err := New(su1.NewRoot, stateDB) require.NoError(t, err) @@ -766,7 +766,7 @@ func TestContractHistory(t *testing.T) { }, }, } - require.NoError(t, state2.Update(block2, su2, nil, false)) + require.NoError(t, state2.Update(block2, su2, nil, false, true)) gotNonce, err := state2.ContractNonceAt(addr, block1) require.NoError(t, err) @@ -802,7 +802,7 @@ func BenchmarkStateUpdate(b *testing.B) { for i, su := range stateUpdates { state, err := New(su.OldRoot, stateDB) require.NoError(b, err) - err = state.Update(uint64(i), su, nil, false) + err = state.Update(uint64(i), su, nil, false, true) if err != nil { b.Fatalf("Error updating state: %v", err) } @@ -840,7 +840,7 @@ func setupState(t *testing.T, stateUpdates []*core.StateUpdate, blocks uint64) * if i == 3 { declaredClasses = su3DeclaredClasses() } - require.NoError(t, state.Update(uint64(i), su, declaredClasses, false), "failed to update state for block %d", i) + require.NoError(t, state.Update(uint64(i), su, declaredClasses, false, true), "failed to update state for block %d", i) newComm, err := state.Commitment() require.NoError(t, err) assert.Equal(t, *su.NewRoot, newComm) diff --git a/core/state_test.go b/core/state_test.go index bdeb975f63..8906aeb64c 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -41,7 +41,7 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) t.Run("empty state updated with mainnet block 0 state update", func(t *testing.T) { - require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) gotNewRoot, rerr := state.Root() require.NoError(t, rerr) assert.Equal(t, su0.NewRoot, gotNewRoot) @@ -53,7 +53,7 @@ func TestUpdate(t *testing.T) { OldRoot: oldRoot, } expectedErr := fmt.Sprintf("state's current root: %s does not match the expected root: %s", su0.NewRoot, oldRoot) - require.EqualError(t, state.Update(1, su, nil, false), expectedErr) + require.EqualError(t, state.Update(1, su, nil, false, true), expectedErr) }) t.Run("error when state new root doesn't match state update's new root", func(t *testing.T) { @@ -64,16 +64,16 @@ func TestUpdate(t *testing.T) { StateDiff: new(core.StateDiff), } expectedErr := fmt.Sprintf("state's current root: %s does not match the expected root: %s", su0.NewRoot, newRoot) - require.EqualError(t, state.Update(1, su, nil, false), expectedErr) + require.EqualError(t, state.Update(1, su, nil, false, true), expectedErr) }) t.Run("non-empty state updated multiple times", func(t *testing.T) { - require.NoError(t, state.Update(1, su1, nil, false)) + require.NoError(t, state.Update(1, su1, nil, false, true)) gotNewRoot, rerr := state.Root() require.NoError(t, rerr) assert.Equal(t, su1.NewRoot, gotNewRoot) - require.NoError(t, state.Update(2, su2, nil, false)) + require.NoError(t, state.Update(2, su2, nil, false, true)) gotNewRoot, err = state.Root() require.NoError(t, err) assert.Equal(t, su2.NewRoot, gotNewRoot) @@ -91,11 +91,11 @@ func TestUpdate(t *testing.T) { t.Run("post v0.11.0 declared classes affect root", func(t *testing.T) { t.Run("without class definition", func(t *testing.T) { - require.Error(t, state.Update(3, su3, nil, false)) + require.Error(t, state.Update(3, su3, nil, false, true)) }) require.NoError(t, state.Update(3, su3, map[felt.Felt]core.Class{ *utils.HexToFelt(t, "0xDEADBEEF"): &core.Cairo1Class{}, - }, false)) + }, false, true)) assert.NotEqual(t, su3.NewRoot, su3.OldRoot) }) @@ -114,7 +114,7 @@ func TestUpdate(t *testing.T) { } t.Run("update systemContracts storage", func(t *testing.T) { - require.NoError(t, state.Update(4, su4, nil, false)) + require.NoError(t, state.Update(4, su4, nil, false, true)) gotValue, err := state.ContractStorage(scAddr, scKey) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestUpdate(t *testing.T) { StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{*scAddr2: {*scKey: scValue}}, }, } - assert.ErrorIs(t, state.Update(5, su5, nil, false), core.ErrContractNotDeployed) + assert.ErrorIs(t, state.Update(5, su5, nil, false, true), core.ErrContractNotDeployed) }) } @@ -160,8 +160,8 @@ func TestContractClassHash(t *testing.T) { su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false)) - require.NoError(t, state.Update(1, su1, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(1, su1, nil, false, true)) allDeployedContracts := make(map[felt.Felt]*felt.Felt) @@ -187,7 +187,7 @@ func TestContractClassHash(t *testing.T) { }, } - require.NoError(t, state.Update(2, replaceUpdate, nil, false)) + require.NoError(t, state.Update(2, replaceUpdate, nil, false, true)) gotClassHash, err := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, err) @@ -214,7 +214,7 @@ func TestNonce(t *testing.T) { }, } - require.NoError(t, state.Update(0, su, nil, false)) + require.NoError(t, state.Update(0, su, nil, false, true)) t.Run("newly deployed contract has zero nonce", func(t *testing.T) { nonce, err := state.ContractNonce(addr) @@ -232,7 +232,7 @@ func TestNonce(t *testing.T) { }, } - require.NoError(t, state.Update(1, su, nil, false)) + require.NoError(t, state.Update(1, su, nil, false, true)) gotNonce, err := state.ContractNonce(addr) require.NoError(t, err) @@ -249,7 +249,7 @@ func TestStateHistory(t *testing.T) { state := core.NewState(txn) su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) contractAddr := utils.HexToFelt(t, "0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") changedLoc := utils.HexToFelt(t, "0x5") @@ -275,7 +275,7 @@ func TestStateHistory(t *testing.T) { }, }, } - require.NoError(t, state.Update(1, su, nil, false)) + require.NoError(t, state.Update(1, su, nil, false, true)) t.Run("should give old value for a location that changed after the given height", func(t *testing.T) { oldValue, err := state.ContractStorageAt(contractAddr, changedLoc, 0) @@ -299,8 +299,8 @@ func TestContractIsDeployedAt(t *testing.T) { su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false)) - require.NoError(t, state.Update(1, su1, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(1, su1, nil, false, true)) t.Run("deployed on genesis", func(t *testing.T) { deployedOn0 := utils.HexToFelt(t, "0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") @@ -351,7 +351,7 @@ func TestClass(t *testing.T) { require.NoError(t, state.Update(0, su0, map[felt.Felt]core.Class{ *cairo0Hash: cairo0Class, *cairo1Hash: cairo1Class, - }, false)) + }, false, true)) gotCairo1Class, err := state.Class(cairo1Hash) require.NoError(t, err) @@ -373,10 +373,10 @@ func TestRevert(t *testing.T) { state := core.NewState(txn) su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(1, su1, nil, false)) + require.NoError(t, state.Update(1, su1, nil, false, true)) t.Run("revert a replaced class", func(t *testing.T) { replaceStateUpdate := &core.StateUpdate{ @@ -389,7 +389,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, replaceStateUpdate, nil, false)) + require.NoError(t, state.Update(2, replaceStateUpdate, nil, false, true)) require.NoError(t, state.Revert(2, replaceStateUpdate)) classHash, sErr := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -407,7 +407,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, nonceStateUpdate, nil, false)) + require.NoError(t, state.Update(2, nonceStateUpdate, nil, false, true)) require.NoError(t, state.Revert(2, nonceStateUpdate)) nonce, sErr := state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -459,7 +459,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, declaredClassesStateUpdate, classesM, false)) + require.NoError(t, state.Update(2, declaredClassesStateUpdate, classesM, false, true)) require.NoError(t, state.Revert(2, declaredClassesStateUpdate)) var decClass *core.DeclaredClass @@ -475,7 +475,7 @@ func TestRevert(t *testing.T) { su2, err := gw.StateUpdate(t.Context(), 2) require.NoError(t, err) t.Run("should be able to apply new update after a Revert", func(t *testing.T) { - require.NoError(t, state.Update(2, su2, nil, false)) + require.NoError(t, state.Update(2, su2, nil, false, true)) }) t.Run("should be able to revert all the state", func(t *testing.T) { @@ -522,7 +522,7 @@ func TestRevertGenesisStateDiff(t *testing.T) { }, }, } - require.NoError(t, state.Update(0, su, nil, false)) + require.NoError(t, state.Update(0, su, nil, false, true)) require.NoError(t, state.Revert(0, su)) } @@ -538,7 +538,7 @@ func TestRevertSystemContracts(t *testing.T) { su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(0, su0, nil, false, true)) su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) @@ -554,7 +554,7 @@ func TestRevertSystemContracts(t *testing.T) { su1.StateDiff.StorageDiffs[*scAddr] = map[felt.Felt]*felt.Felt{*scKey: scValue} - require.NoError(t, state.Update(1, su1, nil, false)) + require.NoError(t, state.Update(1, su1, nil, false, true)) require.NoError(t, state.Revert(1, su1)) @@ -587,7 +587,7 @@ func TestRevertDeclaredClasses(t *testing.T) { *sierraHash: &core.Cairo1Class{}, } - require.NoError(t, state.Update(0, declareDiff, newClasses, false)) + require.NoError(t, state.Update(0, declareDiff, newClasses, false, true)) declaredClass, err := state.Class(classHash) require.NoError(t, err) assert.Equal(t, uint64(0), declaredClass.At) @@ -596,7 +596,7 @@ func TestRevertDeclaredClasses(t *testing.T) { assert.Equal(t, uint64(0), sierraClass.At) declareDiff.OldRoot = declareDiff.NewRoot - require.NoError(t, state.Update(1, declareDiff, newClasses, false)) + require.NoError(t, state.Update(1, declareDiff, newClasses, false, true)) t.Run("re-declaring a class shouldnt change it's DeclaredAt attribute", func(t *testing.T) { declaredClass, err = state.Class(classHash) diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 23bff0809f..74e65ef448 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -3,7 +3,6 @@ package rpcv8_test import ( "context" "errors" - "fmt" "testing" "time" @@ -808,7 +807,6 @@ func verifyIf( proofSet := trie.NewProofNodeSet() for _, hn := range proof { - fmt.Println("hn", hn, "hash", hn.Hash, "node", hn.Node) proofSet.Put(*hn.Hash, hn.Node.AsProofNode()) } diff --git a/vm/vm.go b/vm/vm.go index 9348f341e7..145b9d5d8c 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -233,14 +233,6 @@ func makeCBlockInfo(blockInfo *BlockInfo) C.BlockInfo { return cBlockInfo } -func makeByteFromBool(b bool) byte { - var boolByte byte - if b { - boolByte = 1 - } - return boolByte -} - func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state commonstate.StateReader, network *utils.Network, maxSteps uint64, sierraVersion string, structuredErrStack, returnStateDiff bool, ) (CallResult, error) { diff --git a/vm/vm_test.go b/vm/vm_test.go index a7889cb246..8f460dc192 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -49,7 +49,7 @@ func TestCallDeprecatedCairo(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false)) + }, false, true)) entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") @@ -77,7 +77,7 @@ func TestCallDeprecatedCairo(t *testing.T) { }, }, }, - }, nil, false)) + }, nil, false, true)) ret, err = New(false, nil).Call(&CallInfo{ ContractAddress: contractAddr, @@ -118,7 +118,7 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false)) + }, false, true)) entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") @@ -160,7 +160,7 @@ func TestCallCairo(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false)) + }, false, true)) logLevel := utils.NewLogLevel(utils.ERROR) log, err := utils.NewZapLogger(logLevel, false) @@ -195,7 +195,7 @@ func TestCallCairo(t *testing.T) { }, }, }, - }, nil, false)) + }, nil, false, true)) ret, err = New(false, log).Call(&CallInfo{ ContractAddress: contractAddr, @@ -236,7 +236,7 @@ func TestCallInfoErrorHandling(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false)) + }, false, true)) logLevel := utils.NewLogLevel(utils.ERROR) log, err := utils.NewZapLogger(logLevel, false) From 16d3dd8726e44c9e188978e1f80887c5b06ab8e3 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sun, 10 Aug 2025 19:52:34 +0200 Subject: [PATCH 16/47] add some benchmarks --- consensus/integtest/integ_test.go | 3 +- core/state/commonstate/state_test.go | 94 +++++++++++++++++++++++++--- core/state/commontrie/trie_test.go | 73 +++++++++++++++++++++ 3 files changed, 158 insertions(+), 12 deletions(-) diff --git a/consensus/integtest/integ_test.go b/consensus/integtest/integ_test.go index ffec46e69d..fa8357ae97 100644 --- a/consensus/integtest/integ_test.go +++ b/consensus/integtest/integ_test.go @@ -57,9 +57,8 @@ func getBlockchain(t *testing.T, genesisDiff core.StateDiff, genesisClasses map[ t.Helper() testDB := memory.New() network := &utils.Mainnet - log := utils.NewNopZapLogger() - bc := bc.New(testDB, network, statetestutils.UseNewState()) + bc := blockchain.New(testDB, network, statetestutils.UseNewState()) require.NoError(t, bc.StoreGenesis(&genesisDiff, genesisClasses)) return bc } diff --git a/core/state/commonstate/state_test.go b/core/state/commonstate/state_test.go index b4f9c3d164..d8ceda17b0 100644 --- a/core/state/commonstate/state_test.go +++ b/core/state/commonstate/state_test.go @@ -1,13 +1,17 @@ package commonstate import ( + "context" "testing" + "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" + adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" + "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,16 +48,86 @@ func TestCoreStateReaderAdapter(t *testing.T) { assert.NotNil(t, coreStateReaderAdapter) } -func TestStateReaderAdapter(t *testing.T) { - memDB := memory.New() - db, err := triedb.New(memDB, nil) - if err != nil { - panic(err) +func fetchStateUpdates(samples int) ([]*core.StateUpdate, error) { + client := feeder.NewClient(utils.Mainnet.FeederURL). + WithAPIKey("YOUR_API_KEY") + gw := adaptfeeder.New(client) + + suList := make([]*core.StateUpdate, samples) + for i := 0; i < samples; i++ { + su, err := gw.StateUpdate(context.Background(), uint64(i)) + if err != nil { + return nil, err + } + suList[i] = su } - stateDB := state.NewStateDB(memDB, db) - state, err := state.New(&felt.Zero, stateDB) - require.NoError(t, err) + return suList, nil +} + +func BenchmarkStateUpdateNewState(b *testing.B) { + suList, err := fetchStateUpdates(100) + require.NoError(b, err) + + for n := 0; n < b.N; n++ { + b.Run("NewState", func(b *testing.B) { + b.ReportAllocs() + + b.StopTimer() + + memDB := memory.New() + trieDB, err := triedb.New(memDB, nil) + require.NoError(b, err) + stateDB := state.NewStateDB(memDB, trieDB) + txn := memDB.NewIndexedBatch() + stateFactory, err := NewStateFactory(true, trieDB, stateDB) + require.NoError(b, err) - stateReaderAdapter := NewStateReaderAdapter(state) - assert.NotNil(t, stateReaderAdapter) + state, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(b, err) + + b.StartTimer() + + for i := 0; i < len(suList); i++ { + declaredClasses := make(map[felt.Felt]core.Class) + if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { + b.Fatalf("Update failed: %v", err) + } + state, err = stateFactory.NewState(suList[i].NewRoot, txn) + require.NoError(b, err) + } + }) + } +} + +func BenchmarkStateUpdateOldState(b *testing.B) { + suList, err := fetchStateUpdates(300) + require.NoError(b, err) + + for n := 0; n < b.N; n++ { + b.Run("OldState", func(b *testing.B) { + b.ReportAllocs() + + b.StopTimer() + + memDB := memory.New() + trieDB, err := triedb.New(memDB, nil) + require.NoError(b, err) + stateDB := state.NewStateDB(memDB, trieDB) + txn := memDB.NewIndexedBatch() + stateFactory, err := NewStateFactory(false, trieDB, stateDB) + require.NoError(b, err) + + state, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(b, err) + + b.StartTimer() + + for i := 0; i < len(suList); i++ { + declaredClasses := make(map[felt.Felt]core.Class) + if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { + b.Fatalf("Update failed: %v", err) + } + } + }) + } } diff --git a/core/state/commontrie/trie_test.go b/core/state/commontrie/trie_test.go index 3c7dc2de94..4e942c5289 100644 --- a/core/state/commontrie/trie_test.go +++ b/core/state/commontrie/trie_test.go @@ -76,3 +76,76 @@ func TestTrieAdapter(t *testing.T) { assert.NotNil(t, hashFn) }) } + +func BenchmarkDeprecatedTrieAdapter(b *testing.B) { + memDB := memory.New() + txn := memDB.NewIndexedBatch() + storage := trie.NewStorage(txn, db.ContractStorage.Key([]byte{0})) + trie, err := trie.NewTriePedersen(storage, 251) // Set a suitable height + if err != nil { + b.Fatalf("Failed to create trie: %v", err) + } + adapter := NewDeprecatedTrieAdapter(trie) + + b.Run("Update", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := felt.FromUint64(uint64(i)) + value := felt.FromUint64(uint64(i)) + if err := adapter.Update(&key, &value); err != nil { + b.Fatalf("Update failed: %v", err) + } + } + }) + + b.Run("Get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := felt.FromUint64(uint64(i)) + if _, err := adapter.Get(&key); err != nil { + b.Fatalf("Get failed: %v", err) + } + } + }) + + b.Run("Hash", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := adapter.Hash(); err != nil { + b.Fatalf("Hash failed: %v", err) + } + } + }) +} + +func BenchmarkTrieAdapter(b *testing.B) { + trie, err := trie2.NewEmptyPedersen() + if err != nil { + b.Fatalf("Failed to create trie: %v", err) + } + adapter := NewTrieAdapter(trie) + + b.Run("Update", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := felt.FromUint64(uint64(i)) + value := felt.FromUint64(uint64(i)) + if err := adapter.Update(&key, &value); err != nil { + b.Fatalf("Update failed: %v", err) + } + } + }) + + b.Run("Get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := felt.FromUint64(uint64(i)) + if _, err := adapter.Get(&key); err != nil { + b.Fatalf("Get failed: %v", err) + } + } + }) + + b.Run("Hash", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := adapter.Hash(); err != nil { + b.Fatalf("Hash failed: %v", err) + } + } + }) +} From 7a0bd3ba8199dea93b57925ab6f60c4fde2ade70 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 11 Aug 2025 16:50:27 +0200 Subject: [PATCH 17/47] add benchmarks --- blockchain/blockchain_test.go | 53 +++++++++++++++++++++ core/state/commonstate/state_test.go | 70 ++++++++++++++-------------- 2 files changed, 87 insertions(+), 36 deletions(-) diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index 37b94cf294..9880cf05bc 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -1,6 +1,7 @@ package blockchain_test import ( + "context" "fmt" "testing" @@ -696,3 +697,55 @@ func TestSubscribeL1Head(t *testing.T) { require.True(t, ok) assert.Equal(t, l1Head, got) } + +func fetchStateUpdatesAndBlocks(samples int) ([]*core.StateUpdate, []*core.Block, error) { + client := feeder.NewClient(utils.Mainnet.FeederURL) + gw := adaptfeeder.New(client) + + suList := make([]*core.StateUpdate, samples) + blocks := make([]*core.Block, samples) + for i := 0; i < samples; i++ { + fmt.Println("fetching", i) + su, err := gw.StateUpdate(context.Background(), uint64(i)) + if err != nil { + return nil, nil, err + } + suList[i] = su + block, err := gw.BlockByNumber(context.Background(), uint64(i)) + if err != nil { + return nil, nil, err + } + blocks[i] = block + } + return suList, blocks, nil +} + +func BenchmarkBlockchainStore(b *testing.B) { + samples := 100 + stateUpdates, blocks, err := fetchStateUpdatesAndBlocks(samples) + require.NoError(b, err) + + b.Run("new", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + chain := blockchain.New(memory.New(), &utils.Mainnet, true) + b.StartTimer() + + for j := 0; j < samples; j++ { + chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil) + } + } + }) + + b.Run("old", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + chain := blockchain.New(memory.New(), &utils.Mainnet, false) + b.StartTimer() + + for j := range samples { + chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil) + } + } + }) +} diff --git a/core/state/commonstate/state_test.go b/core/state/commonstate/state_test.go index d8ceda17b0..f5873f5f02 100644 --- a/core/state/commonstate/state_test.go +++ b/core/state/commonstate/state_test.go @@ -2,6 +2,7 @@ package commonstate import ( "context" + "fmt" "testing" "github.com/NethermindEth/juno/clients/feeder" @@ -9,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/trie2/triedb" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" @@ -55,6 +57,7 @@ func fetchStateUpdates(samples int) ([]*core.StateUpdate, error) { suList := make([]*core.StateUpdate, samples) for i := 0; i < samples; i++ { + fmt.Println("fetching", i) su, err := gw.StateUpdate(context.Background(), uint64(i)) if err != nil { return nil, err @@ -64,30 +67,23 @@ func fetchStateUpdates(samples int) ([]*core.StateUpdate, error) { return suList, nil } -func BenchmarkStateUpdateNewState(b *testing.B) { - suList, err := fetchStateUpdates(100) +func BenchmarkStateUpdate(b *testing.B) { + samples := 50 + suList, err := fetchStateUpdates(samples) require.NoError(b, err) - for n := 0; n < b.N; n++ { - b.Run("NewState", func(b *testing.B) { + b.Run("NewState", func(b *testing.B) { + for n := 0; n < b.N; n++ { + b.ReportAllocs() b.StopTimer() - memDB := memory.New() - trieDB, err := triedb.New(memDB, nil) - require.NoError(b, err) - stateDB := state.NewStateDB(memDB, trieDB) - txn := memDB.NewIndexedBatch() - stateFactory, err := NewStateFactory(true, trieDB, stateDB) - require.NoError(b, err) - - state, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(b, err) + state, stateFactory, txn := prepareState(b, true) b.StartTimer() - for i := 0; i < len(suList); i++ { + for i := 0; i < samples; i++ { declaredClasses := make(map[felt.Felt]core.Class) if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { b.Fatalf("Update failed: %v", err) @@ -95,39 +91,41 @@ func BenchmarkStateUpdateNewState(b *testing.B) { state, err = stateFactory.NewState(suList[i].NewRoot, txn) require.NoError(b, err) } - }) - } -} + } + }) -func BenchmarkStateUpdateOldState(b *testing.B) { - suList, err := fetchStateUpdates(300) - require.NoError(b, err) + b.Run("OldState", func(b *testing.B) { + for n := 0; n < b.N; n++ { - for n := 0; n < b.N; n++ { - b.Run("OldState", func(b *testing.B) { b.ReportAllocs() b.StopTimer() - memDB := memory.New() - trieDB, err := triedb.New(memDB, nil) - require.NoError(b, err) - stateDB := state.NewStateDB(memDB, trieDB) - txn := memDB.NewIndexedBatch() - stateFactory, err := NewStateFactory(false, trieDB, stateDB) - require.NoError(b, err) - - state, err := stateFactory.NewState(&felt.Zero, txn) - require.NoError(b, err) + state, _, _ := prepareState(b, false) b.StartTimer() - for i := 0; i < len(suList); i++ { + for i := 0; i < samples; i++ { declaredClasses := make(map[felt.Felt]core.Class) if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { b.Fatalf("Update failed: %v", err) } } - }) - } + } + }) +} + +func prepareState(b *testing.B, newState bool) (State, *StateFactory, db.IndexedBatch) { + memDB := memory.New() + trieDB, err := triedb.New(memDB, nil) + require.NoError(b, err) + stateDB := state.NewStateDB(memDB, trieDB) + txn := memDB.NewIndexedBatch() + stateFactory, err := NewStateFactory(newState, trieDB, stateDB) + require.NoError(b, err) + + state, err := stateFactory.NewState(&felt.Zero, txn) + require.NoError(b, err) + + return state, stateFactory, txn } From f4aee63ba03b74ed5829690babf10119d33a5749 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 25 Aug 2025 10:21:01 +0200 Subject: [PATCH 18/47] fixes for e2e tests --- blockchain/blockchain.go | 47 +++++++++++++++++++++++++++++---- core/state/commonstate/state.go | 14 ---------- core/state/history.go | 21 +++++++++++---- core/state/state.go | 3 --- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index a4eebc024e..7cce2e1f62 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -461,14 +461,32 @@ func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (commonstate.StateRe b.listener.OnRead("StateAtBlockNumber") txn := b.database.NewIndexedBatch() - header, err := core.GetBlockHeaderByNumber(txn, blockNumber) + _, err := core.GetBlockHeaderByNumber(txn, blockNumber) + if err != nil { + return nil, nil, err + } + + if !b.StateFactory.UseNewState { + coreState := core.NewState(txn) + snapshot := core.NewStateSnapshot(coreState, blockNumber) + return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil + } + + height, err := core.GetChainHeight(txn) if err != nil { return nil, nil, err } - stateReader, err := b.StateFactory.NewStateReader(header.GlobalStateRoot, txn, blockNumber) + header, err := core.GetBlockHeaderByNumber(txn, height) + if err != nil { + return nil, nil, err + } - return stateReader, noopStateCloser, err + history, err := state.NewStateHistory(blockNumber, header.GlobalStateRoot, b.stateDB) + if err != nil { + return nil, nil, err + } + return commonstate.NewStateReaderAdapter(&history), noopStateCloser, nil } // StateAtBlockHash returns a StateReader that provides a stable view to the state at the given block hash @@ -484,8 +502,27 @@ func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateRe if err != nil { return nil, nil, err } - stateReader, err := b.StateFactory.NewStateReader(header.GlobalStateRoot, txn, header.Number) - return stateReader, noopStateCloser, err + if !b.StateFactory.UseNewState { + coreState := core.NewState(txn) + snapshot := core.NewStateSnapshot(coreState, header.Number) + return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil + } + + height, err := core.GetChainHeight(txn) + if err != nil { + return nil, nil, err + } + + headHeader, err := core.GetBlockHeaderByNumber(txn, height) + if err != nil { + return nil, nil, err + } + + history, err := state.NewStateHistory(header.Number, headHeader.GlobalStateRoot, b.stateDB) + if err != nil { + return nil, nil, err + } + return commonstate.NewStateReaderAdapter(&history), noopStateCloser, nil } // EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 504c9bb497..aa637216d9 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -309,20 +309,6 @@ func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (Sta return NewStateAdapter(stateState), nil } -func (sf *StateFactory) NewStateReader(stateRoot *felt.Felt, txn db.IndexedBatch, blockNumber uint64) (StateReader, error) { - if !sf.UseNewState { - coreState := core.NewState(txn) - snapshot := core.NewStateSnapshot(coreState, blockNumber) - return NewDeprecatedStateReaderAdapter(snapshot), nil - } - - history, err := state.NewStateHistory(blockNumber, stateRoot, sf.stateDB) - if err != nil { - return nil, err - } - return NewStateReaderAdapter(&history), nil -} - func (sf *StateFactory) EmptyState() (StateReader, error) { if !sf.UseNewState { memDB := memory.New() diff --git a/core/state/history.go b/core/state/history.go index af680d686e..a2e24a980e 100644 --- a/core/state/history.go +++ b/core/state/history.go @@ -1,6 +1,8 @@ package state import ( + "errors" + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie2" @@ -16,13 +18,13 @@ type StateHistory struct { } func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (StateHistory, error) { + state, err := New(stateRoot, db) + if err != nil { + return StateHistory{}, err + } return StateHistory{ blockNum: blockNum, - state: &State{ - initRoot: *stateRoot, - db: db, - stateObjects: make(map[felt.Felt]*stateObject), - }, + state: state, }, nil } @@ -32,6 +34,9 @@ func (s *StateHistory) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { } ret, err := s.state.ContractClassHashAt(addr, s.blockNum) if err != nil { + if errors.Is(err, ErrNoHistoryValue) { + return s.state.ContractClassHash(addr) + } return felt.Felt{}, err } return ret, nil @@ -43,6 +48,9 @@ func (s *StateHistory) ContractNonce(addr *felt.Felt) (felt.Felt, error) { } ret, err := s.state.ContractNonceAt(addr, s.blockNum) if err != nil { + if errors.Is(err, ErrNoHistoryValue) { + return s.state.ContractNonce(addr) + } return felt.Felt{}, err } return ret, nil @@ -54,6 +62,9 @@ func (s *StateHistory) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) } ret, err := s.state.ContractStorageAt(addr, key, s.blockNum) if err != nil { + if errors.Is(err, ErrNoHistoryValue) { + return s.state.ContractStorage(addr, key) + } return felt.Felt{}, err } return ret, nil diff --git a/core/state/state.go b/core/state/state.go index 307260b790..5605425c83 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -681,9 +681,6 @@ func (s *State) getHistoricalValue(prefix []byte, blockNum uint64) (felt.Felt, e return nil }) if err != nil { - if errors.Is(err, ErrNoHistoryValue) { - return felt.Zero, nil - } return felt.Zero, err } From c95c8344b64e53bb046ab3781b7a4fc619b2a5cf Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 26 Aug 2025 10:11:02 +0200 Subject: [PATCH 19/47] lint, unit tests --- blockchain/blockchain_test.go | 12 ++++++------ core/state/commonstate/state_test.go | 12 +++++------- core/state/commontrie/trie_test.go | 15 +++++++++------ rpc/v9/subscriptions_test.go | 23 +++++++++++++---------- 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index 9880cf05bc..9e98827b31 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -704,7 +704,7 @@ func fetchStateUpdatesAndBlocks(samples int) ([]*core.StateUpdate, []*core.Block suList := make([]*core.StateUpdate, samples) blocks := make([]*core.Block, samples) - for i := 0; i < samples; i++ { + for i := range samples { fmt.Println("fetching", i) su, err := gw.StateUpdate(context.Background(), uint64(i)) if err != nil { @@ -726,25 +726,25 @@ func BenchmarkBlockchainStore(b *testing.B) { require.NoError(b, err) b.Run("new", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for b.Loop() { b.StopTimer() chain := blockchain.New(memory.New(), &utils.Mainnet, true) b.StartTimer() - for j := 0; j < samples; j++ { - chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil) + for j := range samples { + require.NoError(b, chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil)) } } }) b.Run("old", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for b.Loop() { b.StopTimer() chain := blockchain.New(memory.New(), &utils.Mainnet, false) b.StartTimer() for j := range samples { - chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil) + require.NoError(b, chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil)) } } }) diff --git a/core/state/commonstate/state_test.go b/core/state/commonstate/state_test.go index b344139339..af6a41116b 100644 --- a/core/state/commonstate/state_test.go +++ b/core/state/commonstate/state_test.go @@ -47,7 +47,7 @@ func fetchStateUpdates(samples int) ([]*core.StateUpdate, error) { gw := adaptfeeder.New(client) suList := make([]*core.StateUpdate, samples) - for i := 0; i < samples; i++ { + for i := range samples { fmt.Println("fetching", i) su, err := gw.StateUpdate(context.Background(), uint64(i)) if err != nil { @@ -64,8 +64,7 @@ func BenchmarkStateUpdate(b *testing.B) { require.NoError(b, err) b.Run("NewState", func(b *testing.B) { - for n := 0; n < b.N; n++ { - + for b.Loop() { b.ReportAllocs() b.StopTimer() @@ -74,7 +73,7 @@ func BenchmarkStateUpdate(b *testing.B) { b.StartTimer() - for i := 0; i < samples; i++ { + for i := range samples { declaredClasses := make(map[felt.Felt]core.Class) if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { b.Fatalf("Update failed: %v", err) @@ -86,8 +85,7 @@ func BenchmarkStateUpdate(b *testing.B) { }) b.Run("OldState", func(b *testing.B) { - for n := 0; n < b.N; n++ { - + for b.Loop() { b.ReportAllocs() b.StopTimer() @@ -96,7 +94,7 @@ func BenchmarkStateUpdate(b *testing.B) { b.StartTimer() - for i := 0; i < samples; i++ { + for i := range samples { declaredClasses := make(map[felt.Felt]core.Class) if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { b.Fatalf("Update failed: %v", err) diff --git a/core/state/commontrie/trie_test.go b/core/state/commontrie/trie_test.go index e699f8375d..9002521100 100644 --- a/core/state/commontrie/trie_test.go +++ b/core/state/commontrie/trie_test.go @@ -4,7 +4,10 @@ import ( "testing" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -51,7 +54,7 @@ func BenchmarkDeprecatedTrieAdapter(b *testing.B) { adapter := NewDeprecatedTrieAdapter(trie) b.Run("Update", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := range b.N { key := felt.FromUint64(uint64(i)) value := felt.FromUint64(uint64(i)) if err := adapter.Update(&key, &value); err != nil { @@ -61,7 +64,7 @@ func BenchmarkDeprecatedTrieAdapter(b *testing.B) { }) b.Run("Get", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := range b.N { key := felt.FromUint64(uint64(i)) if _, err := adapter.Get(&key); err != nil { b.Fatalf("Get failed: %v", err) @@ -70,7 +73,7 @@ func BenchmarkDeprecatedTrieAdapter(b *testing.B) { }) b.Run("Hash", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { if _, err := adapter.Hash(); err != nil { b.Fatalf("Hash failed: %v", err) } @@ -86,7 +89,7 @@ func BenchmarkTrieAdapter(b *testing.B) { adapter := NewTrieAdapter(trie) b.Run("Update", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := range b.N { key := felt.FromUint64(uint64(i)) value := felt.FromUint64(uint64(i)) if err := adapter.Update(&key, &value); err != nil { @@ -96,7 +99,7 @@ func BenchmarkTrieAdapter(b *testing.B) { }) b.Run("Get", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for i := range b.N { key := felt.FromUint64(uint64(i)) if _, err := adapter.Get(&key); err != nil { b.Fatalf("Get failed: %v", err) @@ -105,7 +108,7 @@ func BenchmarkTrieAdapter(b *testing.B) { }) b.Run("Hash", func(b *testing.B) { - for i := 0; i < b.N; i++ { + for b.Loop() { if _, err := adapter.Hash(); err != nil { b.Fatalf("Hash failed: %v", err) } diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index 0090f74a2a..25d4fa2a29 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/commonstate" statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" @@ -88,9 +89,11 @@ func (fs *fakeSyncer) PendingData() (core.PendingData, error) { } func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { return nil, nil, nil } +func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, error) { + return nil, nil, nil +} -func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (core.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { return nil, nil, nil } @@ -277,7 +280,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b1.Header, nil).Times(1) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) @@ -320,7 +323,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b1.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(b1Filtered, nil, nil) @@ -363,7 +366,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b1.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(append(b1Filtered, preConfirmed1Filtered...), nil, nil) @@ -398,7 +401,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b2.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: b1.Header.Number}, + core.L1Head{BlockNumber: b1.Header.Number}, nil, ) mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil) @@ -425,7 +428,7 @@ func TestSubscribeEvents(t *testing.T) { mockChain.EXPECT().HeadsHeader().Return(b2.Header, nil) mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) cToken := new(blockchain.ContinuationToken) @@ -482,7 +485,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b1.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(append(b1FilteredBySenders, preConfirmedFilteredBySenders...), nil, nil) @@ -554,7 +557,7 @@ func TestSubscribeEvents(t *testing.T) { setupMocks: func() { mockChain.EXPECT().HeadsHeader().Return(b1.Header, nil) mockChain.EXPECT().L1Head().Return( - &core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, + core.L1Head{BlockNumber: uint64(max(0, int(b1.Header.Number)-1))}, nil, ) mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(append(b1FilteredByFromAddressAndKey, preConfirmedFilteredBySendersAndKey...), nil, nil) @@ -998,7 +1001,7 @@ func TestSubscriptionReorg(t *testing.T) { mockChain := mocks.NewMockReader(mockCtrl) l1Feed := feed.New[*core.L1Head]() mockChain.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) - mockChain.EXPECT().L1Head().Return(&core.L1Head{BlockNumber: 0}, nil) + mockChain.EXPECT().L1Head().Return(core.L1Head{BlockNumber: 0}, nil) syncer := newFakeSyncer() handler, server := setupRPC(t, ctx, mockChain, syncer) From cd1c9473582a01e5448ae3468d877bed8a5a7b72 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 26 Aug 2025 11:22:13 +0200 Subject: [PATCH 20/47] comments --- blockchain/blockchain.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 7cce2e1f62..9fac007fa9 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -105,7 +105,7 @@ func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) * } stateDB := state.NewStateDB(database, trieDB) - StateFactory, err := commonstate.NewStateFactory(stateVersion, trieDB, stateDB) + stateFactory, err := commonstate.NewStateFactory(stateVersion, trieDB, stateDB) if err != nil { panic(err) } @@ -127,7 +127,7 @@ func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) * l1HeadFeed: feed.New[*core.L1Head](), cachedFilters: &cachedFilters, runningFilter: runningFilter, - StateFactory: StateFactory, + StateFactory: stateFactory, } } @@ -272,13 +272,13 @@ func (b *Blockchain) Store(block *core.Block, blockCommitments *core.BlockCommit // old state // TODO(maksymmalick): remove this once we have a new state implementation if !b.StateFactory.UseNewState { - return b.storeCoreState(block, blockCommitments, stateUpdate, newClasses) + return b.deprecatedStore(block, blockCommitments, stateUpdate, newClasses) } return b.store(block, blockCommitments, stateUpdate, newClasses) } -func (b *Blockchain) storeCoreState( +func (b *Blockchain) deprecatedStore( block *core.Block, blockCommitments *core.BlockCommitments, stateUpdate *core.StateUpdate, @@ -467,8 +467,8 @@ func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (commonstate.StateRe } if !b.StateFactory.UseNewState { - coreState := core.NewState(txn) - snapshot := core.NewStateSnapshot(coreState, blockNumber) + deprecatedState := core.NewState(txn) + snapshot := core.NewStateSnapshot(deprecatedState, blockNumber) return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil } @@ -503,8 +503,8 @@ func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateRe return nil, nil, err } if !b.StateFactory.UseNewState { - coreState := core.NewState(txn) - snapshot := core.NewStateSnapshot(coreState, header.Number) + deprecatedState := core.NewState(txn) + snapshot := core.NewStateSnapshot(deprecatedState, header.Number) return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil } @@ -539,14 +539,14 @@ func (b *Blockchain) EventFilter(from *felt.Felt, keys [][]felt.Felt, pendingBlo // RevertHead reverts the head block func (b *Blockchain) RevertHead() error { if !b.StateFactory.UseNewState { - return b.database.Update(b.revertHeadCoreState) + return b.database.Update(b.deprecatedRevertHead) } return b.revertHead() } func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) { if !b.StateFactory.UseNewState { - reverseStateDiff, err := b.getReverseStateDiffCoreState() + reverseStateDiff, err := b.deprecatedGetReverseStateDiff() if err != nil { return core.StateDiff{}, err } @@ -557,7 +557,7 @@ func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) { } // TODO(maksymmalick): remove this once we have a new state integrated -func (b *Blockchain) getReverseStateDiffCoreState() (*core.StateDiff, error) { +func (b *Blockchain) deprecatedGetReverseStateDiff() (*core.StateDiff, error) { var reverseStateDiff *core.StateDiff txn := b.database.NewIndexedBatch() @@ -604,7 +604,7 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) { return ret, nil } -func (b *Blockchain) revertHeadCoreState(txn db.IndexedBatch) error { +func (b *Blockchain) deprecatedRevertHead(txn db.IndexedBatch) error { blockNumber, err := core.GetChainHeight(txn) if err != nil { return err From a7aeb400348049cff6f18cab81b007c63e98b03e Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 27 Aug 2025 23:05:50 +0200 Subject: [PATCH 21/47] fix for tx tracing --- core/state/commontrie/trie_test.go | 137 ++++++++++++++++------------- core/state/errors.go | 1 + core/state/history.go | 6 +- core/state/state.go | 5 +- 4 files changed, 86 insertions(+), 63 deletions(-) diff --git a/core/state/commontrie/trie_test.go b/core/state/commontrie/trie_test.go index 9002521100..fc74e8699d 100644 --- a/core/state/commontrie/trie_test.go +++ b/core/state/commontrie/trie_test.go @@ -43,75 +43,94 @@ func TestTrieAdapter(t *testing.T) { }) } -func BenchmarkDeprecatedTrieAdapter(b *testing.B) { - memDB := memory.New() - txn := memDB.NewIndexedBatch() - storage := trie.NewStorage(txn, db.ContractStorage.Key([]byte{0})) - trie, err := trie.NewTriePedersen(storage, 251) // Set a suitable height - if err != nil { - b.Fatalf("Failed to create trie: %v", err) +func BenchmarkTrieAdapters(b *testing.B) { + benchmarkData := make(map[felt.Felt]felt.Felt) + for i := 0; i < 10000; i++ { + key, err := new(felt.Felt).SetRandom() + require.NoError(b, err) + value, err := new(felt.Felt).SetRandom() + require.NoError(b, err) + benchmarkData[*key] = *value } - adapter := NewDeprecatedTrieAdapter(trie) - - b.Run("Update", func(b *testing.B) { - for i := range b.N { - key := felt.FromUint64(uint64(i)) - value := felt.FromUint64(uint64(i)) - if err := adapter.Update(&key, &value); err != nil { - b.Fatalf("Update failed: %v", err) - } - } - }) - b.Run("Get", func(b *testing.B) { - for i := range b.N { - key := felt.FromUint64(uint64(i)) - if _, err := adapter.Get(&key); err != nil { - b.Fatalf("Get failed: %v", err) - } + b.Run("DeprecatedTrieAdapter", func(b *testing.B) { + memDB := memory.New() + txn := memDB.NewIndexedBatch() + storage := trie.NewStorage(txn, db.ContractStorage.Key([]byte{0})) + trie, err := trie.NewTriePedersen(storage, 251) + if err != nil { + b.Fatalf("Failed to create trie: %v", err) } - }) - - b.Run("Hash", func(b *testing.B) { - for range b.N { - if _, err := adapter.Hash(); err != nil { - b.Fatalf("Hash failed: %v", err) + adapter := NewDeprecatedTrieAdapter(trie) + + b.Run("Update", func(b *testing.B) { + for b.Loop() { + for key, value := range benchmarkData { + if key.Uint64()%20 == 0 { + value = felt.FromUint64(0) + } + if err := adapter.Update(&key, &value); err != nil { + b.Fatalf("Update failed: %v", err) + } + } } - } - }) -} - -func BenchmarkTrieAdapter(b *testing.B) { - trie, err := trie2.NewEmptyPedersen() - if err != nil { - b.Fatalf("Failed to create trie: %v", err) - } - adapter := NewTrieAdapter(trie) + }) + + b.Run("Get", func(b *testing.B) { + for b.Loop() { + for key := range benchmarkData { + if _, err := adapter.Get(&key); err != nil { + b.Fatalf("Get failed: %v", err) + } + } + } + }) - b.Run("Update", func(b *testing.B) { - for i := range b.N { - key := felt.FromUint64(uint64(i)) - value := felt.FromUint64(uint64(i)) - if err := adapter.Update(&key, &value); err != nil { - b.Fatalf("Update failed: %v", err) + b.Run("Hash", func(b *testing.B) { + for b.Loop() { + if _, err := adapter.Hash(); err != nil { + b.Fatalf("Hash failed: %v", err) + } } - } + }) }) - b.Run("Get", func(b *testing.B) { - for i := range b.N { - key := felt.FromUint64(uint64(i)) - if _, err := adapter.Get(&key); err != nil { - b.Fatalf("Get failed: %v", err) - } + b.Run("TrieAdapter", func(b *testing.B) { + trie, err := trie2.NewEmptyPedersen() + if err != nil { + b.Fatalf("Failed to create trie: %v", err) } - }) + adapter := NewTrieAdapter(trie) + + b.Run("Update", func(b *testing.B) { + for b.Loop() { + for key, value := range benchmarkData { + if key.Uint64()%20 == 0 { + value = felt.FromUint64(0) + } + if err := adapter.Update(&key, &value); err != nil { + b.Fatalf("Update failed: %v", err) + } + } + } + }) + + b.Run("Get", func(b *testing.B) { + for b.Loop() { + for key := range benchmarkData { + if _, err := adapter.Get(&key); err != nil { + b.Fatalf("Get failed: %v", err) + } + } + } + }) - b.Run("Hash", func(b *testing.B) { - for b.Loop() { - if _, err := adapter.Hash(); err != nil { - b.Fatalf("Hash failed: %v", err) + b.Run("Hash", func(b *testing.B) { + for b.Loop() { + if _, err := adapter.Hash(); err != nil { + b.Fatalf("Hash failed: %v", err) + } } - } + }) }) } diff --git a/core/state/errors.go b/core/state/errors.go index 6e2c50cad7..ddb154b9e5 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -8,5 +8,6 @@ var ( ErrContractNotDeployed = errors.New("contract not deployed") ErrContractAlreadyDeployed = errors.New("contract already deployed") ErrNoHistoryValue = errors.New("no history value found") + ErrCheckHeadState = errors.New("check head state") ErrHistoricalTrieNotSupported = errors.New("cannot support historical trie") ) diff --git a/core/state/history.go b/core/state/history.go index a2e24a980e..fc1b129e3a 100644 --- a/core/state/history.go +++ b/core/state/history.go @@ -34,7 +34,7 @@ func (s *StateHistory) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { } ret, err := s.state.ContractClassHashAt(addr, s.blockNum) if err != nil { - if errors.Is(err, ErrNoHistoryValue) { + if errors.Is(err, ErrCheckHeadState) { return s.state.ContractClassHash(addr) } return felt.Felt{}, err @@ -48,7 +48,7 @@ func (s *StateHistory) ContractNonce(addr *felt.Felt) (felt.Felt, error) { } ret, err := s.state.ContractNonceAt(addr, s.blockNum) if err != nil { - if errors.Is(err, ErrNoHistoryValue) { + if errors.Is(err, ErrCheckHeadState) { return s.state.ContractNonce(addr) } return felt.Felt{}, err @@ -62,7 +62,7 @@ func (s *StateHistory) ContractStorage(addr, key *felt.Felt) (felt.Felt, error) } ret, err := s.state.ContractStorageAt(addr, key, s.blockNum) if err != nil { - if errors.Is(err, ErrNoHistoryValue) { + if errors.Is(err, ErrCheckHeadState) { return s.state.ContractStorage(addr, key) } return felt.Felt{}, err diff --git a/core/state/state.go b/core/state/state.go index 5605425c83..993c9be57b 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -681,6 +681,9 @@ func (s *State) getHistoricalValue(prefix []byte, blockNum uint64) (felt.Felt, e return nil }) if err != nil { + if errors.Is(err, ErrNoHistoryValue) { + return felt.Zero, nil + } return felt.Zero, err } @@ -701,7 +704,7 @@ func (s *State) valueAt(prefix []byte, blockNum uint64, cb func(val []byte) erro seekKey := binary.BigEndian.AppendUint64(prefix, blockNum) if !it.Seek(seekKey) { - return ErrNoHistoryValue + return ErrCheckHeadState } key := it.Key() From 10a2758bf16967f51358f5e855856b84b351ac76 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 29 Aug 2025 11:03:20 +0200 Subject: [PATCH 22/47] read storage from cache not db --- core/state/object.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/state/object.go b/core/state/object.go index 45a6d6d8d8..b267c95f34 100644 --- a/core/state/object.go +++ b/core/state/object.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/core/trie2/trieutils" - "github.com/NethermindEth/juno/db" "golang.org/x/exp/maps" ) @@ -50,8 +49,13 @@ func (s *stateObject) getStorage(key *felt.Felt) (felt.Felt, error) { return felt.Zero, err } + // TODO(maksym): test if this works instead of reading from disk path := tr.FeltToPath(key) - v, err := trieutils.GetNodeByPath(s.state.db.disk, db.ContractTrieStorage, &s.addr, &path, true) + reader, err := s.state.db.triedb.NodeReader(trieutils.NewContractStorageTrieID(s.state.initRoot, s.addr)) + if err != nil { + return felt.Zero, err + } + v, err := reader.Node(&s.addr, &path, nil, true) if err != nil { return felt.Zero, err } From c5619286e5b4e0cfe8b75213b180aff52b3ad509 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 29 Aug 2025 19:54:25 +0200 Subject: [PATCH 23/47] use trie2 static trie functions --- core/receipt.go | 17 +++++++---------- core/transaction.go | 12 ++++++------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/core/receipt.go b/core/receipt.go index 0852f2c5ce..cf1f56055e 100644 --- a/core/receipt.go +++ b/core/receipt.go @@ -6,7 +6,7 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" ) type GasConsumed struct { @@ -67,7 +67,7 @@ func messagesSentHash(messages []*L2ToL1Message) *felt.Felt { func receiptCommitment(receipts []*TransactionReceipt) (*felt.Felt, error) { return calculateCommitment( receipts, - trie.RunOnTempTriePoseidon, + trie2.RunOnTempTriePoseidon, func(receipt *TransactionReceipt) *felt.Felt { return receipt.hash() }, @@ -76,14 +76,14 @@ func receiptCommitment(receipts []*TransactionReceipt) (*felt.Felt, error) { // TODO(maksymmalick): change this to trie2 after integration done type ( - onTempTrieFunc func(uint8, func(*trie.Trie) error) error + onTempTrieFunc func(uint8, func(*trie2.Trie) error) error processFunc[T any] func(T) *felt.Felt ) // General function for parallel processing of items and calculation of a commitment func calculateCommitment[T any](items []T, runOnTempTrie onTempTrieFunc, process processFunc[T]) (*felt.Felt, error) { var commitment *felt.Felt - return commitment, runOnTempTrie(commitmentTrieHeight, func(trie *trie.Trie) error { + return commitment, runOnTempTrie(commitmentTrieHeight, func(trie *trie2.Trie) error { numWorkers := min(runtime.GOMAXPROCS(0), len(items)) results := make([]*felt.Felt, len(items)) var wg sync.WaitGroup @@ -108,16 +108,13 @@ func calculateCommitment[T any](items []T, runOnTempTrie onTempTrieFunc, process for i, res := range results { key := new(felt.Felt).SetUint64(uint64(i)) - if _, err := trie.Put(key, res); err != nil { + if err := trie.Update(key, res); err != nil { return err } } - root, err := trie.Root() - if err != nil { - return err - } - commitment = root + root := trie.Hash() + commitment = &root return nil }) diff --git a/core/transaction.go b/core/transaction.go index f290f8271a..9bf015cddb 100644 --- a/core/transaction.go +++ b/core/transaction.go @@ -10,7 +10,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/utils" "github.com/bits-and-blooms/bloom/v3" "github.com/ethereum/go-ethereum/common" @@ -668,13 +668,13 @@ func transactionCommitmentPedersen(transactions []Transaction, protocolVersion s return crypto.Pedersen(transaction.Hash(), signatureHash) } } - return calculateCommitment(transactions, trie.RunOnTempTriePedersen, hashFunc) + return calculateCommitment(transactions, trie2.RunOnTempTriePedersen, hashFunc) } // transactionCommitmentPoseidon0134 handles empty signatures compared to transactionCommitmentPoseidon0132: // empty signatures are interpreted as [] instead of [0] func transactionCommitmentPoseidon0134(transactions []Transaction) (*felt.Felt, error) { - return calculateCommitment(transactions, trie.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { + return calculateCommitment(transactions, trie2.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { var digest crypto.PoseidonDigest digest.Update(transaction.Hash()) @@ -688,7 +688,7 @@ func transactionCommitmentPoseidon0134(transactions []Transaction) (*felt.Felt, // transactionCommitmentPoseidon0132 is used to calculate tx commitment for 0.13.2 <= block.version < 0.13.4 func transactionCommitmentPoseidon0132(transactions []Transaction) (*felt.Felt, error) { - return calculateCommitment(transactions, trie.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { + return calculateCommitment(transactions, trie2.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { var digest crypto.PoseidonDigest digest.Update(transaction.Hash()) @@ -722,7 +722,7 @@ func eventCommitmentPoseidon(receipts []*TransactionReceipt) (*felt.Felt, error) }) } } - return calculateCommitment(items, trie.RunOnTempTriePoseidon, func(item *eventWithTxHash) *felt.Felt { + return calculateCommitment(items, trie2.RunOnTempTriePoseidon, func(item *eventWithTxHash) *felt.Felt { return crypto.PoseidonArray( slices.Concat( []*felt.Felt{ @@ -750,7 +750,7 @@ func eventCommitmentPedersen(receipts []*TransactionReceipt) (*felt.Felt, error) for _, receipt := range receipts { events = append(events, receipt.Events...) } - return calculateCommitment(events, trie.RunOnTempTriePedersen, func(event *Event) *felt.Felt { + return calculateCommitment(events, trie2.RunOnTempTriePedersen, func(event *Event) *felt.Felt { return crypto.PedersenArray( event.From, crypto.PedersenArray(event.Keys...), From b40d9d79b4036b118716e4be30dfe8d9c975fc2e Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 25 Sep 2025 10:51:22 +0200 Subject: [PATCH 24/47] remove benchmarks --- core/state/commontrie/trie_test.go | 95 ------------------------------ 1 file changed, 95 deletions(-) diff --git a/core/state/commontrie/trie_test.go b/core/state/commontrie/trie_test.go index fc74e8699d..08e717d058 100644 --- a/core/state/commontrie/trie_test.go +++ b/core/state/commontrie/trie_test.go @@ -4,10 +4,7 @@ import ( "testing" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,95 +39,3 @@ func TestTrieAdapter(t *testing.T) { assert.NotNil(t, hashFn) }) } - -func BenchmarkTrieAdapters(b *testing.B) { - benchmarkData := make(map[felt.Felt]felt.Felt) - for i := 0; i < 10000; i++ { - key, err := new(felt.Felt).SetRandom() - require.NoError(b, err) - value, err := new(felt.Felt).SetRandom() - require.NoError(b, err) - benchmarkData[*key] = *value - } - - b.Run("DeprecatedTrieAdapter", func(b *testing.B) { - memDB := memory.New() - txn := memDB.NewIndexedBatch() - storage := trie.NewStorage(txn, db.ContractStorage.Key([]byte{0})) - trie, err := trie.NewTriePedersen(storage, 251) - if err != nil { - b.Fatalf("Failed to create trie: %v", err) - } - adapter := NewDeprecatedTrieAdapter(trie) - - b.Run("Update", func(b *testing.B) { - for b.Loop() { - for key, value := range benchmarkData { - if key.Uint64()%20 == 0 { - value = felt.FromUint64(0) - } - if err := adapter.Update(&key, &value); err != nil { - b.Fatalf("Update failed: %v", err) - } - } - } - }) - - b.Run("Get", func(b *testing.B) { - for b.Loop() { - for key := range benchmarkData { - if _, err := adapter.Get(&key); err != nil { - b.Fatalf("Get failed: %v", err) - } - } - } - }) - - b.Run("Hash", func(b *testing.B) { - for b.Loop() { - if _, err := adapter.Hash(); err != nil { - b.Fatalf("Hash failed: %v", err) - } - } - }) - }) - - b.Run("TrieAdapter", func(b *testing.B) { - trie, err := trie2.NewEmptyPedersen() - if err != nil { - b.Fatalf("Failed to create trie: %v", err) - } - adapter := NewTrieAdapter(trie) - - b.Run("Update", func(b *testing.B) { - for b.Loop() { - for key, value := range benchmarkData { - if key.Uint64()%20 == 0 { - value = felt.FromUint64(0) - } - if err := adapter.Update(&key, &value); err != nil { - b.Fatalf("Update failed: %v", err) - } - } - } - }) - - b.Run("Get", func(b *testing.B) { - for b.Loop() { - for key := range benchmarkData { - if _, err := adapter.Get(&key); err != nil { - b.Fatalf("Get failed: %v", err) - } - } - } - }) - - b.Run("Hash", func(b *testing.B) { - for b.Loop() { - if _, err := adapter.Hash(); err != nil { - b.Fatalf("Hash failed: %v", err) - } - } - }) - }) -} From 33e0eff313062b2ce4f30fb912d1e226490eabda Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 2 Oct 2025 13:47:09 +0200 Subject: [PATCH 25/47] Fix last syncing bug --- core/state/state.go | 4 ++-- core/trie2/triedb/pathdb/disklayer.go | 6 ++++-- core/trie2/triedb/pathdb/journal.go | 9 +++++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/core/state/state.go b/core/state/state.go index 993c9be57b..4d7cef2947 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -394,7 +394,7 @@ func (s *State) commit() (felt.Felt, stateUpdate, error) { for i, addr := range keys { obj := s.stateObjects[addr] - + idx := i p.Go(func() error { // Object is marked as delete if obj == nil { @@ -410,7 +410,7 @@ func (s *State) commit() (felt.Felt, stateUpdate, error) { return err } - comms[i] = obj.commitment() + comms[idx] = obj.commitment() return nil }) } diff --git a/core/trie2/triedb/pathdb/disklayer.go b/core/trie2/triedb/pathdb/disklayer.go index 029dfc366b..dcfb30b847 100644 --- a/core/trie2/triedb/pathdb/disklayer.go +++ b/core/trie2/triedb/pathdb/disklayer.go @@ -86,9 +86,11 @@ func (dl *diskLayer) node(id trieutils.TrieID, owner *felt.Felt, path *trieutils if err != nil { return nil, err } + blobCopy := make([]byte, len(blob)) + copy(blobCopy, blob) - dl.cleans.putNode(owner, path, isClass, blob) - return blob, nil + dl.cleans.putNode(owner, path, isClass, blobCopy) + return blobCopy, nil } func (dl *diskLayer) update(root *felt.Felt, id, block uint64, nodes *nodeSet) diffLayer { diff --git a/core/trie2/triedb/pathdb/journal.go b/core/trie2/triedb/pathdb/journal.go index f4aa0e24df..39201d69d4 100644 --- a/core/trie2/triedb/pathdb/journal.go +++ b/core/trie2/triedb/pathdb/journal.go @@ -272,7 +272,7 @@ func (d *Database) loadLayers(enc []byte) (layer, error) { } func (d *Database) getStateRoot() felt.Felt { - encContractRoot, err := trieutils.GetNodeByPath( + encContractRootRaw, err := trieutils.GetNodeByPath( d.disk, db.ContractTrieContract, &felt.Zero, @@ -282,8 +282,11 @@ func (d *Database) getStateRoot() felt.Felt { if err != nil { return felt.Zero } + encContractRoot := make([]byte, len(encContractRootRaw)) + copy(encContractRoot, encContractRootRaw) + + encStorageRootRaw, err := trieutils.GetNodeByPath( - encStorageRoot, err := trieutils.GetNodeByPath( d.disk, db.ClassTrie, &felt.Zero, @@ -293,6 +296,8 @@ func (d *Database) getStateRoot() felt.Felt { if err != nil { return felt.Zero } + encStorageRoot := make([]byte, len(encStorageRootRaw)) + copy(encStorageRoot, encStorageRootRaw) contractRootNode, err := trienode.DecodeNode(encContractRoot, &felt.Zero, 0, contractClassTrieHeight) if err != nil { From c7ea439fc08755da945aa9ac952ae7a6fb565cfa Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 7 Oct 2025 10:47:34 +0200 Subject: [PATCH 26/47] fixes --- core/trie2/trieutils/accessors.go | 4 ++-- db/buckets.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/trie2/trieutils/accessors.go b/core/trie2/trieutils/accessors.go index ce9d31adc1..196d5c59d5 100644 --- a/core/trie2/trieutils/accessors.go +++ b/core/trie2/trieutils/accessors.go @@ -12,7 +12,7 @@ func GetNodeByPath(r db.KeyValueReader, bucket db.Bucket, owner *felt.Felt, path var res []byte if err := r.Get(nodeKeyByPath(bucket, owner, path, isLeaf), func(value []byte) error { - res = value + res = append([]byte(nil), value...) return nil }, ); err != nil { @@ -129,7 +129,7 @@ func GetNodeByHash(r db.KeyValueReader, bucket db.Bucket, owner *felt.Felt, path var res []byte if err := r.Get(nodeKeyByHash(bucket, owner, path, hash, isLeaf), func(value []byte) error { - res = value + res = append([]byte(nil), value...) return nil }, ); err != nil { diff --git a/db/buckets.go b/db/buckets.go index 4e20c924b1..51cf37088c 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -40,7 +40,7 @@ const ( MempoolNode ClassTrie // ClassTrie + nodetype + path + pathlength -> Trie Node ContractTrieContract // ContractTrieContract + nodetype + path + pathlength -> Trie Node - ContractTrieStorage // ContractTrieStorage + nodetype + path + pathlength -> Trie Node + ContractTrieStorage // ContractTrieStorage + nodetype + owner + path + pathlength -> Trie Node Contract // Contract + ContractAddr -> Contract StateHashToTrieRoots // StateHash -> ClassRootHash + ContractRootHash StateID // StateID + root hash -> state id From 2447d445d673cce6474b9bfcd59c1a3348b7e9b1 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 13 Oct 2025 10:46:33 +0200 Subject: [PATCH 27/47] fixes --- core/receipt.go | 5 +---- rpc/v8/storage_test.go | 16 ++++++++-------- rpc/v9/l1.go | 2 +- rpc/v9/storage_test.go | 20 ++++++++++---------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/core/receipt.go b/core/receipt.go index 65d5218a68..cf1f56055e 100644 --- a/core/receipt.go +++ b/core/receipt.go @@ -113,10 +113,7 @@ func calculateCommitment[T any](items []T, runOnTempTrie onTempTrieFunc, process } } - root, err := trie.Hash() - if err != nil { - return err - } + root := trie.Hash() commitment = &root return nil diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 80a4defd13..b1223c72d1 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -314,14 +314,14 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, nil, nil) require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key, *key2}, nil, nil) @@ -334,8 +334,8 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { if statetestutils.UseNewState() { @@ -351,7 +351,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) @@ -369,7 +369,7 @@ func TestStorageProof(t *testing.T) { require.Equal(t, nonce, ld.Nonce) require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { contract := felt.NewUnsafeFromString[felt.Felt]("0xdead") @@ -394,7 +394,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { @@ -408,7 +408,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { nonce := new(felt.Felt).SetUint64(121) diff --git a/rpc/v9/l1.go b/rpc/v9/l1.go index a128793ca0..f82c85082a 100644 --- a/rpc/v9/l1.go +++ b/rpc/v9/l1.go @@ -68,7 +68,7 @@ func (h *Handler) GetMessageStatus(ctx context.Context, l1TxnHash *common.Hash) if err != nil { return nil, jsonrpc.Err(jsonrpc.InternalError, fmt.Errorf("failed to retrieve L1 handler txn %v", err)) } - status, rpcErr := h.TransactionStatus(ctx, hash) + status, rpcErr := h.TransactionStatus(ctx, &hash) if rpcErr != nil { return nil, rpcErr } diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 2d3d6e48ba..42460e4eb8 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -188,7 +188,7 @@ func TestStorageProof(t *testing.T) { ) var classTrie, contractTrie commontrie.Trie - trieRoot := &felt.Zero + trieRoot := felt.Zero if !statetestutils.UseNewState() { tempTrie := emptyTrie(t) @@ -216,7 +216,7 @@ func TestStorageProof(t *testing.T) { createTrie(t, trieutils.NewClassTrieID(felt.Zero), &trieDB) contractTrie2 := createTrie(t, trieutils.NewContractTrieID(felt.Zero), &trieDB) tmpTrieRoot := contractTrie2.Hash() - trieRoot = &tmpTrieRoot + trieRoot = tmpTrieRoot // recreate because the previous ones are committed classTrie2, err := trie2.New(trieutils.NewClassTrieID(*newComm), 251, crypto.Pedersen, &trieDB) @@ -332,14 +332,14 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, nil, nil) require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key, *key2}, nil, nil) @@ -352,8 +352,8 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { if statetestutils.UseNewState() { @@ -370,7 +370,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { nonce := felt.NewFromUint64[felt.Felt](121) @@ -388,7 +388,7 @@ func TestStorageProof(t *testing.T) { require.Equal(t, nonce, ld.Nonce) require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { contract := felt.NewFromUint64[felt.Felt](0xdead) @@ -413,7 +413,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { @@ -427,7 +427,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { nonce := felt.NewFromUint64[felt.Felt](121) From ad74546f26d4f45ae90e250e8b6d9b02273a5157 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 13 Oct 2025 20:26:27 +0200 Subject: [PATCH 28/47] unit tests fix --- l1/l1_pkg_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/l1/l1_pkg_test.go b/l1/l1_pkg_test.go index 169a8ca415..1a45c7e1a7 100644 --- a/l1/l1_pkg_test.go +++ b/l1/l1_pkg_test.go @@ -385,7 +385,7 @@ func TestClient(t *testing.T) { BlockHash: block.expectedL2BlockHash, StateRoot: block.expectedL2BlockHash, } - assert.Equal(t, want, &got) + assert.Equal(t, want, got) } } }) @@ -463,7 +463,7 @@ func TestUnreliableSubscription(t *testing.T) { BlockHash: block.expectedL2BlockHash, StateRoot: block.expectedL2BlockHash, } - assert.Equal(t, want, &got) + assert.Equal(t, want, got) } } } From f8b7fb0996880eb203ce5de52ba1a0ed57f9de44 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 13:07:01 +0200 Subject: [PATCH 29/47] fix lint --- rpc/v8/storage_test.go | 12 ++++++------ rpc/v9/storage_test.go | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index ecd601d0b9..2ce9245da9 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -178,8 +178,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) - contractTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) + classTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) + contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { @@ -205,8 +205,8 @@ func TestStorageProof(t *testing.T) { require.NoError(t, err) contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) require.NoError(t, err) - classTrie = commontrie.NewTrieAdapter(classTrie2) - contractTrie = commontrie.NewTrieAdapter(contractTrie2) + classTrie = (*commontrie.TrieAdapter)(classTrie2) + contractTrie = (*commontrie.TrieAdapter)(contractTrie2) } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -840,9 +840,9 @@ func emptyCommonTrie(t *testing.T) commontrie.Trie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return commontrie.NewTrieAdapter(tempTrie) + return (*commontrie.TrieAdapter)(tempTrie) } else { - return commontrie.NewDeprecatedTrieAdapter(emptyTrie(t)) + return (*commontrie.DeprecatedTrieAdapter)(emptyTrie(t)) } } diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 7621c3c411..eb078bb4a5 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -196,8 +196,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) - contractTrie = commontrie.NewDeprecatedTrieAdapter(tempTrie) + classTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) + contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { @@ -223,8 +223,8 @@ func TestStorageProof(t *testing.T) { require.NoError(t, err) contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) require.NoError(t, err) - classTrie = commontrie.NewTrieAdapter(classTrie2) - contractTrie = commontrie.NewTrieAdapter(contractTrie2) + classTrie = (*commontrie.TrieAdapter)(classTrie2) + contractTrie = (*commontrie.TrieAdapter)(contractTrie2) } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -859,9 +859,9 @@ func emptyCommonTrie(t *testing.T) commontrie.Trie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return commontrie.NewTrieAdapter(tempTrie) + return (*commontrie.TrieAdapter)(tempTrie) } else { - return commontrie.NewDeprecatedTrieAdapter(emptyTrie(t)) + return (*commontrie.DeprecatedTrieAdapter)(emptyTrie(t)) } } From 37af5896fa0d8f11b58bd8f70e365b2775105313 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 13:30:25 +0200 Subject: [PATCH 30/47] minor fixes after merge --- blockchain/blockchain.go | 11 +++-------- blockchain/blockchain_test.go | 8 -------- cmd/juno/juno.go | 2 ++ node/node.go | 2 ++ rpc/v6/pending_data_wrapper.go | 1 - rpc/v7/pending_data_wrapper.go | 1 - rpc/v8/pending_data_wrapper.go | 1 - rpc/v9/events_test.go | 2 +- rpc/v9/pending_data_wrapper.go | 1 - 9 files changed, 8 insertions(+), 21 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 03d82e6b25..f25317a4c1 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -6,6 +6,9 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" "github.com/NethermindEth/juno/utils" @@ -114,14 +117,6 @@ func (b *Blockchain) Network() *utils.Network { return b.network } -// StateCommitment returns the latest block state commitment. -// If blockchain is empty zero felt is returned. -func (b *Blockchain) StateCommitment() (felt.Felt, error) { - b.listener.OnRead("StateCommitment") - batch := b.database.NewIndexedBatch() // this is a hack because we don't need to write to the db - return core.NewState(batch).Root() -} - // Height returns the latest block height. If blockchain is empty nil is returned. func (b *Blockchain) Height() (uint64, error) { b.listener.OnRead("Height") diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index 96726dc42d..3261517d08 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -252,10 +252,6 @@ func TestStore(t *testing.T) { require.NoError(t, err) assert.Equal(t, block0, headBlock) - root, err := chain.StateCommitment() - require.NoError(t, err) - assert.Equal(t, stateUpdate0.NewRoot, &root) - got0Block, err := chain.BlockByNumber(0) require.NoError(t, err) assert.Equal(t, block0, got0Block) @@ -280,10 +276,6 @@ func TestStore(t *testing.T) { require.NoError(t, err) assert.Equal(t, block1, headBlock) - root, err := chain.StateCommitment() - require.NoError(t, err) - assert.Equal(t, stateUpdate1.NewRoot, &root) - got1Block, err := chain.BlockByNumber(1) require.NoError(t, err) assert.Equal(t, block1, got1Block) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index 3865c46360..6988c60c59 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -149,6 +149,7 @@ const ( defaultSubmittedTransactionsCacheEntryTTL = 5 * time.Minute defaultNewState = false defaultDisableRPCBatchRequests = false + newStateF = "new-state" configFlagUsage = "The YAML configuration file." logLevelFlagUsage = "Options: trace, debug, info, warn, error." @@ -220,6 +221,7 @@ const ( submittedTransactionsCacheSize = "Maximum number of entries in the submitted transactions cache" submittedTransactionsCacheEntryTTL = "Time-to-live for each entry in the submitted transactions cache" disableRPCBatchRequestsUsage = "Disables handling of batched RPC requests." + newStateUsage = "EXPERIMENTAL: Use the new state package implementation" ) var Version string diff --git a/node/node.go b/node/node.go index c667195ade..6d3af3ff2e 100644 --- a/node/node.go +++ b/node/node.go @@ -118,6 +118,8 @@ type Config struct { HTTPUpdatePort uint16 `mapstructure:"http-update-port"` ForbidRPCBatchRequests bool `mapstructure:"disable-rpc-batch-requests"` + + NewState bool `mapstructure:"new-state"` } type Node struct { diff --git a/rpc/v6/pending_data_wrapper.go b/rpc/v6/pending_data_wrapper.go index 6343d394eb..e576fbcd5e 100644 --- a/rpc/v6/pending_data_wrapper.go +++ b/rpc/v6/pending_data_wrapper.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/sync" ) func (h *Handler) PendingData() (core.PendingData, error) { diff --git a/rpc/v7/pending_data_wrapper.go b/rpc/v7/pending_data_wrapper.go index 05c2b9636d..6b661adc6a 100644 --- a/rpc/v7/pending_data_wrapper.go +++ b/rpc/v7/pending_data_wrapper.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/sync" ) func (h *Handler) PendingData() (core.PendingData, error) { diff --git a/rpc/v8/pending_data_wrapper.go b/rpc/v8/pending_data_wrapper.go index 1d922cd85c..7f76bd98a8 100644 --- a/rpc/v8/pending_data_wrapper.go +++ b/rpc/v8/pending_data_wrapper.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/sync" ) func (h *Handler) PendingData() (core.PendingData, error) { diff --git a/rpc/v9/events_test.go b/rpc/v9/events_test.go index 8164c51c4a..6fcd32fb8a 100644 --- a/rpc/v9/events_test.go +++ b/rpc/v9/events_test.go @@ -206,7 +206,7 @@ func setupTestChain( ) (*blockchain.Blockchain, *adaptfeeder.Feeder) { t.Helper() testDB := memory.New() - chain := blockchain.New(testDB, network) + chain := blockchain.New(testDB, network, statetestutils.UseNewState()) client := feeder.NewTestClient(t, network) gw := adaptfeeder.New(client) diff --git a/rpc/v9/pending_data_wrapper.go b/rpc/v9/pending_data_wrapper.go index 942e38f5a9..d9b15077fb 100644 --- a/rpc/v9/pending_data_wrapper.go +++ b/rpc/v9/pending_data_wrapper.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/sync" ) func (h *Handler) PendingData() (core.PendingData, error) { From 934a6eccab8e4d08f7d6db3e59913767c007e2b4 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 14:36:37 +0200 Subject: [PATCH 31/47] refactor(juno): integrate common state with rawdb for the new state --- core/state/state.go | 76 ++++++++--------- core/trie2/triedb/database.go | 15 ++-- core/trie2/triedb/rawdb/database.go | 121 ++++++++++++++++++++++++++++ core/trie2/triedb/rawdb/reader.go | 22 +++++ core/trie2/triedb/rawdb/types.go | 15 ++++ core/trie2/trieutils/accessors.go | 2 +- sync/sync_test.go | 10 +++ 7 files changed, 211 insertions(+), 50 deletions(-) create mode 100644 core/trie2/triedb/rawdb/database.go create mode 100644 core/trie2/triedb/rawdb/reader.go create mode 100644 core/trie2/triedb/rawdb/types.go diff --git a/core/state/state.go b/core/state/state.go index f84d4d40b4..42aa18b108 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -492,64 +492,54 @@ func (s *State) flush( classes map[felt.Felt]core.Class, storeHistory bool, ) error { - p := pool.New().WithMaxGoroutines(runtime.GOMAXPROCS(0)).WithErrors() - - p.Go(func() error { - return s.db.triedb.Update(&update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes) - }) - batch := s.db.disk.NewBatch() - p.Go(func() error { - for addr, obj := range s.stateObjects { - if obj == nil { // marked as deleted - if err := DeleteContract(batch, &addr); err != nil { - return err - } - - // TODO(weiihann): handle hash-based, and there should be better ways of doing this - if err := trieutils.DeleteStorageNodesByPath(batch, addr); err != nil { - return err - } - } else { // updated - if err := WriteContract(batch, &addr, obj.contract); err != nil { - return err - } + if err := s.db.triedb.Update(&update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes, batch); err != nil { + return err + } - if storeHistory { - for key, val := range obj.dirtyStorage { - if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil { - return err - } - } + for addr, obj := range s.stateObjects { + if obj == nil { // marked as deleted + if err := DeleteContract(batch, &addr); err != nil { + return err + } - if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil { - return err - } + // TODO(weiihann): handle hash-based, and there should be better ways of doing this + if err := trieutils.DeleteStorageNodesByPath(batch, addr); err != nil { + return err + } + } else { // updated + if err := WriteContract(batch, &addr, obj.contract); err != nil { + return err + } - if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { + if storeHistory { + for key, val := range obj.dirtyStorage { + if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil { return err } } - } - } - for classHash, class := range classes { - if class == nil { // mark as deleted - if err := DeleteClass(batch, &classHash); err != nil { + if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil { return err } - } else { - if err := WriteClass(batch, &classHash, class, blockNum); err != nil { + + if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { return err } } } + } - return nil - }) - - if err := p.Wait(); err != nil { - return err + for classHash, class := range classes { + if class == nil { // mark as deleted + if err := DeleteClass(batch, &classHash); err != nil { + return err + } + } else { + if err := WriteClass(batch, &classHash, class, blockNum); err != nil { + return err + } + } } return batch.Write() diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go index 6a88c6cd45..b3b16ebfd5 100644 --- a/core/trie2/triedb/database.go +++ b/core/trie2/triedb/database.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/core/trie2/triedb/database" "github.com/NethermindEth/juno/core/trie2/triedb/hashdb" "github.com/NethermindEth/juno/core/trie2/triedb/pathdb" + "github.com/NethermindEth/juno/core/trie2/triedb/rawdb" "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" @@ -15,11 +16,13 @@ import ( const ( PathScheme string = "path" HashScheme string = "hash" + RawScheme string = "raw" ) type Config struct { PathConfig *pathdb.Config HashConfig *hashdb.Config + RawConfig *rawdb.Config } type Database struct { @@ -30,12 +33,9 @@ type Database struct { func New(disk db.KeyValueStore, config *Config) (*Database, error) { var triedb database.TrieDB var err error - // Default to path config if not provided + // Default to raw config if not provided if config == nil { - triedb, err = pathdb.New(disk, nil) - if err != nil { - return nil, err - } + triedb = rawdb.New(disk) } else if config.PathConfig != nil { triedb, err = pathdb.New(disk, config.PathConfig) if err != nil { @@ -57,12 +57,15 @@ func (d *Database) Update( blockNum uint64, mergeClassNodes, mergeContractNodes *trienode.MergeNodeSet, + w db.KeyValueWriter, ) error { switch td := d.triedb.(type) { case *pathdb.Database: return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes) case *hashdb.Database: return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes) + case *rawdb.Database: + return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes, w) default: return fmt.Errorf("unsupported trie db type: %T", td) } @@ -78,7 +81,7 @@ func (d *Database) Journal(root *felt.Felt) error { func (d *Database) Scheme() string { if d.config == nil { - return PathScheme + return RawScheme } else if d.config.PathConfig != nil { return PathScheme } diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go new file mode 100644 index 0000000000..3dd0953dfe --- /dev/null +++ b/core/trie2/triedb/rawdb/database.go @@ -0,0 +1,121 @@ +package rawdb + +import ( + "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" +) + +type Config struct { +} + +type Database struct { + disk db.KeyValueStore + + lock sync.RWMutex + log utils.SimpleLogger +} + +func New(disk db.KeyValueStore) *Database { + return &Database{ + disk: disk, + log: utils.NewNopZapLogger(), + } +} + +func (d *Database) readNode(id trieutils.TrieID, owner *felt.Felt, path *trieutils.Path, isLeaf bool) ([]byte, error) { + d.lock.RLock() + defer d.lock.RUnlock() + blob, err := trieutils.GetNodeByPath(d.disk, id.Bucket(), owner, path, isLeaf) + if err != nil { + return nil, err + } + + return blob, nil +} + +func (d *Database) NewIterator(id trieutils.TrieID) (db.Iterator, error) { + key := id.Bucket().Key() + owner := id.Owner() + if !owner.Equal(&felt.Zero) { + oBytes := owner.Bytes() + key = append(key, oBytes[:]...) + } + + return d.disk.NewIterator(key, true) +} + +func (d *Database) Commit(_ *felt.Felt) error { + return nil +} + +func (d *Database) Update( + root, + parent *felt.Felt, + blockNum uint64, + mergedClassNodes *trienode.MergeNodeSet, + mergedContractNodes *trienode.MergeNodeSet, + batch db.KeyValueWriter, +) error { + d.lock.Lock() + defer d.lock.Unlock() + + var classNodes classNodesMap + var contractNodes contractNodesMap + var contractStorageNodes contractStorageNodesMap + + if mergedClassNodes != nil { + classNodes, _ = mergedClassNodes.Flatten() + } + if mergedContractNodes != nil { + contractNodes, contractStorageNodes = mergedContractNodes.Flatten() + } + + for path, n := range classNodes { + if _, deleted := n.(*trienode.DeletedNode); deleted { + if err := trieutils.DeleteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf()); err != nil { + return err + } + } else { + if err := trieutils.WriteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { + return err + } + } + } + + for path, n := range contractNodes { + if _, deleted := n.(*trienode.DeletedNode); deleted { + if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf()); err != nil { + return err + } + } else { + if err := trieutils.WriteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { + return err + } + } + } + + for owner, nodes := range contractStorageNodes { + for path, n := range nodes { + if _, deleted := n.(*trienode.DeletedNode); deleted { + if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf()); err != nil { + return err + } + } else { + if err := trieutils.WriteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf(), n.Blob()); err != nil { + return err + } + } + } + } + + return nil +} + +func (d *Database) Close() error { + return nil +} diff --git a/core/trie2/triedb/rawdb/reader.go b/core/trie2/triedb/rawdb/reader.go new file mode 100644 index 0000000000..f4cf147a7a --- /dev/null +++ b/core/trie2/triedb/rawdb/reader.go @@ -0,0 +1,22 @@ +package rawdb + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/triedb/database" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +var _ database.NodeReader = (*reader)(nil) + +type reader struct { + id trieutils.TrieID + d *Database +} + +func (r *reader) Node(owner *felt.Felt, path *trieutils.Path, hash *felt.Felt, isLeaf bool) ([]byte, error) { + return r.d.readNode(r.id, owner, path, isLeaf) +} + +func (d *Database) NodeReader(id trieutils.TrieID) (database.NodeReader, error) { + return &reader{d: d, id: id}, nil +} diff --git a/core/trie2/triedb/rawdb/types.go b/core/trie2/triedb/rawdb/types.go new file mode 100644 index 0000000000..abebc7a3c2 --- /dev/null +++ b/core/trie2/triedb/rawdb/types.go @@ -0,0 +1,15 @@ +package rawdb + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +type ( + classNodesMap = map[trieutils.Path]trienode.TrieNode + contractNodesMap = map[trieutils.Path]trienode.TrieNode + contractStorageNodesMap = map[felt.Felt]map[trieutils.Path]trienode.TrieNode +) + +const ownerSize = felt.Bytes diff --git a/core/trie2/trieutils/accessors.go b/core/trie2/trieutils/accessors.go index 6d56c47f93..ff04845b23 100644 --- a/core/trie2/trieutils/accessors.go +++ b/core/trie2/trieutils/accessors.go @@ -13,7 +13,7 @@ func GetNodeByPath(r db.KeyValueReader, bucket db.Bucket, owner *felt.Felt, path var res []byte if err := r.Get(nodeKeyByPath(bucket, owner, path, isLeaf), func(value []byte) error { - res = append([]byte(nil), value...) + res = slices.Clone(value) return nil }, ); err != nil { diff --git a/sync/sync_test.go b/sync/sync_test.go index f6bcb65404..2d97bf0a60 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -336,6 +336,7 @@ func TestPendingData(t *testing.T) { Block: preConfirmedB, StateUpdate: &core.StateUpdate{ StateDiff: &emptyStateDiff, + OldRoot: su.OldRoot, }, } @@ -361,6 +362,11 @@ func TestPendingData(t *testing.T) { require.NoError(t, err) storageKey := &felt.One + // Get the state root after block 0 to use as OldRoot for block 1 + head, err := chain.Head() + require.NoError(t, err) + oldRoot := head.GlobalStateRoot + t.Run("Without Prelatest", func(t *testing.T) { numTxs := 10 preConfirmed := makePreConfirmedWithIncrementingCounter( @@ -369,6 +375,7 @@ func TestPendingData(t *testing.T) { contractAddress, storageKey, 0, + oldRoot, ) isWritten, err := synchronizer.StorePreConfirmed(preConfirmed) require.NoError(t, err) @@ -413,6 +420,7 @@ func TestPendingData(t *testing.T) { contractAddress, storageKey, 0, + oldRoot, ) preConfirmed.WithPreLatest(&preLatest) isWritten, err := synchronizer.StorePreConfirmed(preConfirmed) @@ -507,6 +515,7 @@ func makePreConfirmedWithIncrementingCounter( contractAddr *felt.Felt, storageKey *felt.Felt, startingNonce uint64, + oldRoot *felt.Felt, ) *core.PreConfirmed { transactions := make([]core.Transaction, numTxs) receipts := make([]*core.TransactionReceipt, numTxs) @@ -562,6 +571,7 @@ func makePreConfirmedWithIncrementingCounter( TransactionStateDiffs: stateDiffs, StateUpdate: &core.StateUpdate{ StateDiff: &aggregatedStateDiff, + OldRoot: oldRoot, }, } } From 7e80726ccecff83ab845ca772edb69d51916be0a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 14:42:32 +0200 Subject: [PATCH 32/47] add clean cache to the rawdb --- core/trie2/triedb/database.go | 4 +- core/trie2/triedb/rawdb/cache.go | 59 +++++++++++++++++++++++++++++ core/trie2/triedb/rawdb/database.go | 32 ++++++++++++++-- 3 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 core/trie2/triedb/rawdb/cache.go diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go index b3b16ebfd5..d2020c30e1 100644 --- a/core/trie2/triedb/database.go +++ b/core/trie2/triedb/database.go @@ -33,9 +33,9 @@ type Database struct { func New(disk db.KeyValueStore, config *Config) (*Database, error) { var triedb database.TrieDB var err error - // Default to raw config if not provided + // Default to path config if not provided if config == nil { - triedb = rawdb.New(disk) + triedb = rawdb.New(disk, nil) } else if config.PathConfig != nil { triedb, err = pathdb.New(disk, config.PathConfig) if err != nil { diff --git a/core/trie2/triedb/rawdb/cache.go b/core/trie2/triedb/rawdb/cache.go new file mode 100644 index 0000000000..9974b553e6 --- /dev/null +++ b/core/trie2/triedb/rawdb/cache.go @@ -0,0 +1,59 @@ +package rawdb + +import ( + "math" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/VictoriaMetrics/fastcache" +) + +const nodeCacheSize = ownerSize + trieutils.PathSize + 1 + +type trieType byte + +const ( + contract trieType = iota + class +) + +// Stores committed trie nodes in memory +type cleanCache struct { + cache *fastcache.Cache // map[nodeKey]node +} + +// Creates a new clean cache with the given size. The size is the maximum size of the cache in bytes. +func newCleanCache(size uint64) cleanCache { + if size > uint64(math.MaxInt) { + panic("cache size too large: uint64 to int conversion would overflow") + } + return cleanCache{ + cache: fastcache.New(int(size)), + } +} + +func (c *cleanCache) putNode(owner *felt.Felt, path *trieutils.Path, isClass bool, blob []byte) { + c.cache.Set(nodeKey(owner, path, isClass), blob) +} + +func (c *cleanCache) getNode(owner *felt.Felt, path *trieutils.Path, isClass bool) []byte { + return c.cache.Get(nil, nodeKey(owner, path, isClass)) +} + +func (c *cleanCache) deleteNode(owner *felt.Felt, path *trieutils.Path, isClass bool) { + c.cache.Del(nodeKey(owner, path, isClass)) +} + +// key = owner (32 bytes) + path (20 bytes) + trie type (1 byte) +func nodeKey(owner *felt.Felt, path *trieutils.Path, isClass bool) []byte { + key := make([]byte, nodeCacheSize) + ownerBytes := owner.Bytes() + copy(key[:felt.Bytes], ownerBytes[:]) + copy(key[felt.Bytes:felt.Bytes+trieutils.PathSize], path.EncodedBytes()) + + if isClass { + key[nodeCacheSize-1] = byte(class) + } + + return key +} diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index 3dd0953dfe..673a568a49 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -11,6 +11,7 @@ import ( ) type Config struct { + CleanCacheSize uint64 // Maximum size (in bytes) for caching clean nodes } type Database struct { @@ -18,23 +19,42 @@ type Database struct { lock sync.RWMutex log utils.SimpleLogger + + config Config + cleanCache *cleanCache } -func New(disk db.KeyValueStore) *Database { +func New(disk db.KeyValueStore, config *Config) *Database { + if config == nil { + config = &Config{ + CleanCacheSize: 16 * utils.Megabyte, + } + } + cleanCache := newCleanCache(config.CleanCacheSize) return &Database{ - disk: disk, - log: utils.NewNopZapLogger(), + disk: disk, + config: *config, + cleanCache: &cleanCache, + log: utils.NewNopZapLogger(), } } func (d *Database) readNode(id trieutils.TrieID, owner *felt.Felt, path *trieutils.Path, isLeaf bool) ([]byte, error) { d.lock.RLock() defer d.lock.RUnlock() + + isClass := id.Type() == trieutils.Class + blob := d.cleanCache.getNode(owner, path, isClass) + if blob != nil { + return blob, nil + } + blob, err := trieutils.GetNodeByPath(d.disk, id.Bucket(), owner, path, isLeaf) if err != nil { return nil, err } + d.cleanCache.putNode(owner, path, isClass, blob) return blob, nil } @@ -80,10 +100,12 @@ func (d *Database) Update( if err := trieutils.DeleteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf()); err != nil { return err } + d.cleanCache.deleteNode(&felt.Zero, &path, true) } else { if err := trieutils.WriteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { return err } + d.cleanCache.putNode(&felt.Zero, &path, true, n.Blob()) } } @@ -92,10 +114,12 @@ func (d *Database) Update( if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf()); err != nil { return err } + d.cleanCache.deleteNode(&felt.Zero, &path, false) } else { if err := trieutils.WriteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { return err } + d.cleanCache.putNode(&felt.Zero, &path, false, n.Blob()) } } @@ -105,10 +129,12 @@ func (d *Database) Update( if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf()); err != nil { return err } + d.cleanCache.deleteNode(&owner, &path, false) } else { if err := trieutils.WriteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf(), n.Blob()); err != nil { return err } + d.cleanCache.putNode(&owner, &path, false, n.Blob()) } } } From 1457e7d5079e3ac6017d78db67606766254f8b9e Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 16:11:00 +0200 Subject: [PATCH 33/47] linter --- core/trie2/triedb/rawdb/database.go | 52 ++++++++++++++--------------- core/trie2/triedb/rawdb/types.go | 2 -- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index 3dd0953dfe..ebe3d61409 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -10,8 +10,7 @@ import ( "github.com/NethermindEth/juno/utils" ) -type Config struct { -} +type Config struct{} type Database struct { disk db.KeyValueStore @@ -76,39 +75,21 @@ func (d *Database) Update( } for path, n := range classNodes { - if _, deleted := n.(*trienode.DeletedNode); deleted { - if err := trieutils.DeleteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf()); err != nil { - return err - } - } else { - if err := trieutils.WriteNodeByPath(batch, db.ClassTrie, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { - return err - } + if err := d.updateNode(batch, db.ClassTrie, &felt.Zero, &path, n); err != nil { + return err } } for path, n := range contractNodes { - if _, deleted := n.(*trienode.DeletedNode); deleted { - if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf()); err != nil { - return err - } - } else { - if err := trieutils.WriteNodeByPath(batch, db.ContractTrieContract, &felt.Zero, &path, n.IsLeaf(), n.Blob()); err != nil { - return err - } + if err := d.updateNode(batch, db.ContractTrieContract, &felt.Zero, &path, n); err != nil { + return err } } for owner, nodes := range contractStorageNodes { for path, n := range nodes { - if _, deleted := n.(*trienode.DeletedNode); deleted { - if err := trieutils.DeleteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf()); err != nil { - return err - } - } else { - if err := trieutils.WriteNodeByPath(batch, db.ContractTrieStorage, &owner, &path, n.IsLeaf(), n.Blob()); err != nil { - return err - } + if err := d.updateNode(batch, db.ContractTrieStorage, &owner, &path, n); err != nil { + return err } } } @@ -116,6 +97,25 @@ func (d *Database) Update( return nil } +func (d *Database) updateNode( + batch db.KeyValueWriter, + bucket db.Bucket, + owner *felt.Felt, + path *trieutils.Path, + n trienode.TrieNode, +) error { + if _, deleted := n.(*trienode.DeletedNode); deleted { + if err := trieutils.DeleteNodeByPath(batch, bucket, owner, path, n.IsLeaf()); err != nil { + return err + } + } else { + if err := trieutils.WriteNodeByPath(batch, bucket, owner, path, n.IsLeaf(), n.Blob()); err != nil { + return err + } + } + return nil +} + func (d *Database) Close() error { return nil } diff --git a/core/trie2/triedb/rawdb/types.go b/core/trie2/triedb/rawdb/types.go index abebc7a3c2..83166bf671 100644 --- a/core/trie2/triedb/rawdb/types.go +++ b/core/trie2/triedb/rawdb/types.go @@ -11,5 +11,3 @@ type ( contractNodesMap = map[trieutils.Path]trienode.TrieNode contractStorageNodesMap = map[felt.Felt]map[trieutils.Path]trienode.TrieNode ) - -const ownerSize = felt.Bytes From 566d8a7d1d7f8f4bca01e37f366ff2300aa0a9a3 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sat, 18 Oct 2025 16:14:21 +0200 Subject: [PATCH 34/47] lint --- core/trie2/triedb/rawdb/types.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/trie2/triedb/rawdb/types.go b/core/trie2/triedb/rawdb/types.go index 83166bf671..abebc7a3c2 100644 --- a/core/trie2/triedb/rawdb/types.go +++ b/core/trie2/triedb/rawdb/types.go @@ -11,3 +11,5 @@ type ( contractNodesMap = map[trieutils.Path]trienode.TrieNode contractStorageNodesMap = map[felt.Felt]map[trieutils.Path]trienode.TrieNode ) + +const ownerSize = felt.Bytes From 1d07efcd7526666ea420fd9d52e49e0dc5991d6b Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 20 Oct 2025 15:28:10 +0200 Subject: [PATCH 35/47] fix unit tests --- core/trie2/triedb/rawdb/database.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index 27c34bb2c3..99079a4a6d 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -96,20 +96,20 @@ func (d *Database) Update( } for path, n := range classNodes { - if err := d.updateNode(batch, db.ClassTrie, &felt.Zero, &path, n); err != nil { + if err := d.updateNode(batch, db.ClassTrie, &felt.Zero, &path, n, true); err != nil { return err } } for path, n := range contractNodes { - if err := d.updateNode(batch, db.ContractTrieContract, &felt.Zero, &path, n); err != nil { + if err := d.updateNode(batch, db.ContractTrieContract, &felt.Zero, &path, n, false); err != nil { return err } } for owner, nodes := range contractStorageNodes { for path, n := range nodes { - if err := d.updateNode(batch, db.ContractTrieStorage, &owner, &path, n); err != nil { + if err := d.updateNode(batch, db.ContractTrieStorage, &owner, &path, n, false); err != nil { return err } } @@ -124,17 +124,18 @@ func (d *Database) updateNode( owner *felt.Felt, path *trieutils.Path, n trienode.TrieNode, + isClass bool, ) error { if _, deleted := n.(*trienode.DeletedNode); deleted { if err := trieutils.DeleteNodeByPath(batch, bucket, owner, path, n.IsLeaf()); err != nil { return err } - d.cleanCache.deleteNode(owner, path, true) + d.cleanCache.deleteNode(owner, path, isClass) } else { if err := trieutils.WriteNodeByPath(batch, bucket, owner, path, n.IsLeaf(), n.Blob()); err != nil { return err } - d.cleanCache.putNode(owner, path, true, n.Blob()) + d.cleanCache.putNode(owner, path, isClass, n.Blob()) } return nil } From bb078d033c70e1bcd8961b266deeb430fd7c2985 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 20 Oct 2025 15:49:16 +0200 Subject: [PATCH 36/47] lint, tests --- blockchain/blockchain.go | 15 +++- blockchain/blockchain_test.go | 8 +- builder/builder.go | 4 +- .../p2p/validator/empty_fixtures_test.go | 6 +- consensus/p2p/validator/fixtures_test.go | 4 +- core/accessors.go | 7 +- core/state/commonstate/deprecated_state.go | 8 +- core/state/commonstate/state.go | 8 +- core/state/object.go | 9 ++- core/state/state.go | 9 ++- core/transaction.go | 80 ++++++++++--------- core/trie2/triedb/rawdb/database.go | 16 +++- core/trie2/triedb/rawdb/reader.go | 7 +- rpc/v6/estimate_fee_test.go | 5 +- rpc/v6/helpers.go | 4 +- rpc/v7/helpers.go | 4 +- rpc/v7/storage.go | 3 +- rpc/v8/helpers.go | 4 +- rpc/v8/simulation_test.go | 30 +++++-- rpc/v8/storage.go | 32 ++++++-- rpc/v8/storage_test.go | 30 +++++-- rpc/v8/subscriptions_test.go | 4 +- rpc/v9/helpers.go | 4 +- rpc/v9/simulation_test.go | 30 +++++-- rpc/v9/storage.go | 29 +++++-- rpc/v9/storage_test.go | 30 +++++-- rpc/v9/subscriptions_test.go | 4 +- sequencer/sequencer.go | 4 +- sync/pending.go | 6 +- sync/sync.go | 8 +- vm/vm_test.go | 8 +- 31 files changed, 320 insertions(+), 100 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index f25317a4c1..ab3cb1c89a 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -434,7 +434,9 @@ func (b *Blockchain) HeadState() (commonstate.StateReader, StateCloser, error) { } // StateAtBlockNumber returns a StateReader that provides a stable view to the state at the given block number -func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, StateCloser, error) { +func (b *Blockchain) StateAtBlockNumber( + blockNumber uint64, +) (commonstate.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") txn := b.database.NewIndexedBatch() @@ -467,7 +469,9 @@ func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (commonstate.StateRe } // StateAtBlockHash returns a StateReader that provides a stable view to the state at the given block hash -func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, StateCloser, error) { +func (b *Blockchain) StateAtBlockHash( + blockHash *felt.Felt, +) (commonstate.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockHash") if blockHash.IsZero() { emptyState, err := b.StateFactory.EmptyState() @@ -686,7 +690,12 @@ func (b *Blockchain) revertHead() error { return err } } - if err = core.DeleteTxsAndReceipts(b.database, batch, blockNumber, header.TransactionCount); err != nil { + if err = core.DeleteTxsAndReceipts( + b.database, + batch, + blockNumber, + header.TransactionCount, + ); err != nil { return err } if err = core.DeleteStateUpdateByBlockNum(batch, blockNumber); err != nil { diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index 3261517d08..dca1d77b46 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -301,7 +301,13 @@ func TestStoreL1HandlerTxnHash(t *testing.T) { l1HandlerMsgHash := common.HexToHash("0x42e76df4e3d5255262929c27132bd0d295a8d3db2cfe63d2fcd061c7a7a7ab34") l1HandlerTxnHash, err := chain.L1HandlerTxnHash(&l1HandlerMsgHash) require.NoError(t, err) - require.Equal(t, felt.NewUnsafeFromString[felt.Felt]("0x785c2ada3f53fbc66078d47715c27718f92e6e48b96372b36e5197de69b82b5"), &l1HandlerTxnHash) + require.Equal( + t, + felt.NewUnsafeFromString[felt.Felt]( + "0x785c2ada3f53fbc66078d47715c27718f92e6e48b96372b36e5197de69b82b5", + ), + &l1HandlerTxnHash, + ) } func TestBlockCommitments(t *testing.T) { diff --git a/builder/builder.go b/builder/builder.go index 6c481d10a6..6846a6783e 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -130,7 +130,9 @@ func (b *Builder) getRevealedBlockHash(blockHeight uint64) (*felt.Felt, error) { return header.Hash, nil } -func (b *Builder) PendingState(buildState *BuildState) (commonstate.StateReader, func() error, error) { +func (b *Builder) PendingState( + buildState *BuildState, +) (commonstate.StateReader, func() error, error) { if buildState.Preconfirmed == nil { return nil, nil, core.ErrPendingDataNotFound } diff --git a/consensus/p2p/validator/empty_fixtures_test.go b/consensus/p2p/validator/empty_fixtures_test.go index 32915da18b..37bcd15f36 100644 --- a/consensus/p2p/validator/empty_fixtures_test.go +++ b/consensus/p2p/validator/empty_fixtures_test.go @@ -42,7 +42,11 @@ func NewEmptyTestFixture( executor.RegisterBuildResult(&buildResult) - b := builder.New(blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor) + b := builder.New(blockchain.New( + database, + testCase.Network, + statetestutils.UseNewState(), + ), executor) proposalCommitment := EmptyProposalCommitment(headBlock, proposer, timestamp) diff --git a/consensus/p2p/validator/fixtures_test.go b/consensus/p2p/validator/fixtures_test.go index 80487ea7e1..efe7ffff68 100644 --- a/consensus/p2p/validator/fixtures_test.go +++ b/consensus/p2p/validator/fixtures_test.go @@ -109,7 +109,9 @@ func BuildTestFixture( executor.RegisterBuildResult(&buildResult) - builder := builder.New(blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor) + builder := builder.New( + blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor, + ) return TestFixture{ ProposalInit: &proposalInit, diff --git a/core/accessors.go b/core/accessors.go index 4a98ef3d31..615a711fa8 100644 --- a/core/accessors.go +++ b/core/accessors.go @@ -225,7 +225,12 @@ func GetReceiptByHash(r db.KeyValueReader, hash *felt.Felt) (*TransactionReceipt return GetReceiptByBlockNumIndexBytes(r, val) } -func DeleteTxsAndReceipts(r db.KeyValueReader, batch db.KeyValueWriter, blockNum, numTxs uint64) error { +func DeleteTxsAndReceipts( + r db.KeyValueReader, + batch db.KeyValueWriter, + blockNum, + numTxs uint64, +) error { // remove txs and receipts for i := range numTxs { txn, err := GetTxByBlockNumIndex(r, blockNum, i) diff --git a/core/state/commonstate/deprecated_state.go b/core/state/commonstate/deprecated_state.go index d3104140f5..8af3e95ab8 100644 --- a/core/state/commonstate/deprecated_state.go +++ b/core/state/commonstate/deprecated_state.go @@ -97,7 +97,13 @@ func (s *DeprecatedStateAdapter) Update( skipVerifyNewRoot bool, flushChanges bool, ) error { - return (*core.State)(s).Update(blockNumber, update, declaredClasses, skipVerifyNewRoot, flushChanges) + return (*core.State)(s).Update( + blockNumber, + update, + declaredClasses, + skipVerifyNewRoot, + flushChanges, + ) } type DeprecatedStateReaderAdapter struct { diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 69977dbe03..13580cd7cd 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -90,7 +90,13 @@ func (s *StateAdapter) Update( skipVerifyNewRoot bool, flushChanges bool, ) error { - return (*state.State)(s).Update(blockNumber, update, declaredClasses, skipVerifyNewRoot, flushChanges) + return (*state.State)(s).Update( + blockNumber, + update, + declaredClasses, + skipVerifyNewRoot, + flushChanges, + ) } func (s *StateAdapter) ContractClassHash(addr *felt.Felt) (felt.Felt, error) { diff --git a/core/state/object.go b/core/state/object.go index b267c95f34..7da7831bd8 100644 --- a/core/state/object.go +++ b/core/state/object.go @@ -51,7 +51,9 @@ func (s *stateObject) getStorage(key *felt.Felt) (felt.Felt, error) { // TODO(maksym): test if this works instead of reading from disk path := tr.FeltToPath(key) - reader, err := s.state.db.triedb.NodeReader(trieutils.NewContractStorageTrieID(s.state.initRoot, s.addr)) + reader, err := s.state.db.triedb.NodeReader( + trieutils.NewContractStorageTrieID(s.state.initRoot, s.addr), + ) if err != nil { return felt.Zero, err } @@ -71,7 +73,10 @@ func (s *stateObject) getStorageTrie() (*trie2.Trie, error) { return s.storageTrie, nil } - storageTrie, err := s.state.db.ContractStorageTrie(&s.state.initRoot, &s.addr) + storageTrie, err := s.state.db.ContractStorageTrie( + &s.state.initRoot, + &s.addr, + ) if err != nil { return nil, err } diff --git a/core/state/state.go b/core/state/state.go index 42aa18b108..906bccb839 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -493,7 +493,14 @@ func (s *State) flush( storeHistory bool, ) error { batch := s.db.disk.NewBatch() - if err := s.db.triedb.Update(&update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes, batch); err != nil { + if err := s.db.triedb.Update( + &update.curComm, + &update.prevComm, + blockNum, + update.classNodes, + update.contractNodes, + batch, + ); err != nil { return err } diff --git a/core/transaction.go b/core/transaction.go index 9bf015cddb..efa02826a2 100644 --- a/core/transaction.go +++ b/core/transaction.go @@ -674,32 +674,38 @@ func transactionCommitmentPedersen(transactions []Transaction, protocolVersion s // transactionCommitmentPoseidon0134 handles empty signatures compared to transactionCommitmentPoseidon0132: // empty signatures are interpreted as [] instead of [0] func transactionCommitmentPoseidon0134(transactions []Transaction) (*felt.Felt, error) { - return calculateCommitment(transactions, trie2.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { - var digest crypto.PoseidonDigest - digest.Update(transaction.Hash()) - - if txSignature := transaction.Signature(); len(txSignature) > 0 { - digest.Update(txSignature...) - } + return calculateCommitment( + transactions, + trie2.RunOnTempTriePoseidon, + func(transaction Transaction) *felt.Felt { + var digest crypto.PoseidonDigest + digest.Update(transaction.Hash()) + + if txSignature := transaction.Signature(); len(txSignature) > 0 { + digest.Update(txSignature...) + } - return digest.Finish() - }) + return digest.Finish() + }) } // transactionCommitmentPoseidon0132 is used to calculate tx commitment for 0.13.2 <= block.version < 0.13.4 func transactionCommitmentPoseidon0132(transactions []Transaction) (*felt.Felt, error) { - return calculateCommitment(transactions, trie2.RunOnTempTriePoseidon, func(transaction Transaction) *felt.Felt { - var digest crypto.PoseidonDigest - digest.Update(transaction.Hash()) - - if txSignature := transaction.Signature(); len(txSignature) > 0 { - digest.Update(txSignature...) - } else { - digest.Update(&felt.Zero) - } + return calculateCommitment( + transactions, + trie2.RunOnTempTriePoseidon, + func(transaction Transaction) *felt.Felt { + var digest crypto.PoseidonDigest + digest.Update(transaction.Hash()) + + if txSignature := transaction.Signature(); len(txSignature) > 0 { + digest.Update(txSignature...) + } else { + digest.Update(&felt.Zero) + } - return digest.Finish() - }) + return digest.Finish() + }) } type eventWithTxHash struct { @@ -722,22 +728,24 @@ func eventCommitmentPoseidon(receipts []*TransactionReceipt) (*felt.Felt, error) }) } } - return calculateCommitment(items, trie2.RunOnTempTriePoseidon, func(item *eventWithTxHash) *felt.Felt { - return crypto.PoseidonArray( - slices.Concat( - []*felt.Felt{ - item.Event.From, - item.TxHash, - new(felt.Felt).SetUint64(uint64(len(item.Event.Keys))), - }, - item.Event.Keys, - []*felt.Felt{ - new(felt.Felt).SetUint64(uint64(len(item.Event.Data))), - }, - item.Event.Data, - )..., - ) - }) + return calculateCommitment(items, + trie2.RunOnTempTriePoseidon, + func(item *eventWithTxHash) *felt.Felt { + return crypto.PoseidonArray( + slices.Concat( + []*felt.Felt{ + item.Event.From, + item.TxHash, + new(felt.Felt).SetUint64(uint64(len(item.Event.Keys))), + }, + item.Event.Keys, + []*felt.Felt{ + new(felt.Felt).SetUint64(uint64(len(item.Event.Data))), + }, + item.Event.Data, + )..., + ) + }) } // eventCommitmentPedersen computes the event commitment for a block. diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index ebe3d61409..6ace8dce01 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -26,7 +26,12 @@ func New(disk db.KeyValueStore) *Database { } } -func (d *Database) readNode(id trieutils.TrieID, owner *felt.Felt, path *trieutils.Path, isLeaf bool) ([]byte, error) { +func (d *Database) readNode( + id trieutils.TrieID, + owner *felt.Felt, + path *trieutils.Path, + isLeaf bool, +) ([]byte, error) { d.lock.RLock() defer d.lock.RUnlock() blob, err := trieutils.GetNodeByPath(d.disk, id.Bucket(), owner, path, isLeaf) @@ -109,7 +114,14 @@ func (d *Database) updateNode( return err } } else { - if err := trieutils.WriteNodeByPath(batch, bucket, owner, path, n.IsLeaf(), n.Blob()); err != nil { + if err := trieutils.WriteNodeByPath( + batch, + bucket, + owner, + path, + n.IsLeaf(), + n.Blob(), + ); err != nil { return err } } diff --git a/core/trie2/triedb/rawdb/reader.go b/core/trie2/triedb/rawdb/reader.go index f4cf147a7a..a1ffd1eef9 100644 --- a/core/trie2/triedb/rawdb/reader.go +++ b/core/trie2/triedb/rawdb/reader.go @@ -13,7 +13,12 @@ type reader struct { d *Database } -func (r *reader) Node(owner *felt.Felt, path *trieutils.Path, hash *felt.Felt, isLeaf bool) ([]byte, error) { +func (r *reader) Node( + owner *felt.Felt, + path *trieutils.Path, + hash *felt.Felt, + isLeaf bool, +) ([]byte, error) { return r.d.readNode(r.id, owner, path, isLeaf) } diff --git a/rpc/v6/estimate_fee_test.go b/rpc/v6/estimate_fee_test.go index 1fc043430d..75ba2e4528 100644 --- a/rpc/v6/estimate_fee_test.go +++ b/rpc/v6/estimate_fee_test.go @@ -56,8 +56,9 @@ func TestEstimateMessageFee(t *testing.T) { mockVM.EXPECT().Execute(gomock.Any(), gomock.Any(), gomock.Any(), &vm.BlockInfo{ Header: latestHeader, }, gomock.Any(), gomock.Any(), false, true, false, true).DoAndReturn( - func(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, - state commonstate.StateReader, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, + func(txns []core.Transaction, declaredClasses []core.Class, + paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state commonstate.StateReader, + skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch bool, ) (vm.ExecutionResults, error) { require.Len(t, txns, 1) assert.NotNil(t, txns[0].(*core.L1HandlerTransaction)) diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index ac63142aaf..f1ddc3f9ed 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -149,7 +149,9 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(id *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { +func (h *Handler) stateByBlockID( + id *BlockID, +) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { var reader commonstate.StateReader var closer blockchain.StateCloser var err error diff --git a/rpc/v7/helpers.go b/rpc/v7/helpers.go index afb2ed7789..e63e0c7814 100644 --- a/rpc/v7/helpers.go +++ b/rpc/v7/helpers.go @@ -159,7 +159,9 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(id *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { +func (h *Handler) stateByBlockID( + id *BlockID, +) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { var reader commonstate.StateReader var closer blockchain.StateCloser var err error diff --git a/rpc/v7/storage.go b/rpc/v7/storage.go index 2797e8d1d1..2cce3ca7f7 100644 --- a/rpc/v7/storage.go +++ b/rpc/v7/storage.go @@ -29,7 +29,8 @@ func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *js // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(&address) if err != nil { - // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. + // Remove db.ErrKeyNotFound after integration if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } diff --git a/rpc/v8/helpers.go b/rpc/v8/helpers.go index 9e0ee195a5..18f8d002b3 100644 --- a/rpc/v8/helpers.go +++ b/rpc/v8/helpers.go @@ -137,7 +137,9 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(blockID *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { +func (h *Handler) stateByBlockID( + blockID *BlockID, +) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { var reader commonstate.StateReader var closer blockchain.StateCloser var err error diff --git a/rpc/v8/simulation_test.go b/rpc/v8/simulation_test.go index c2b5ee8460..adf91f1828 100644 --- a/rpc/v8/simulation_test.go +++ b/rpc/v8/simulation_test.go @@ -36,7 +36,11 @@ func TestSimulateTransactions(t *testing.T) { PriceInFri: &felt.Zero, }, } - defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateReader) { + defaultMockBehavior := func( + mockReader *mocks.MockReader, + _ *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) @@ -52,7 +56,11 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip fee", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -71,7 +79,11 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip validate", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -89,7 +101,11 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "transaction execution error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -107,7 +123,11 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "inconsistent lengths error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 423f66d3ed..8cb856f5d9 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -41,7 +41,8 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(address) if err != nil { - // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. + // Remove db.ErrKeyNotFound after integration if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } @@ -234,7 +235,11 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } } -func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProof( + tr commontrie.Trie, + state commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { switch t := tr.(type) { case *commontrie.DeprecatedTrieAdapter: return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) @@ -245,7 +250,11 @@ func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contrac } } -func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithDeprecatedTrie( + tr *trie.Trie, + state commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -261,7 +270,8 @@ func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateRe nonce, err := state.ContractNonce(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, db.ErrKeyNotFound) { continue } return nil, err @@ -285,7 +295,11 @@ func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateRe }, nil } -func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie( + tr *trie2.Trie, + st commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) for i, contract := range contracts { @@ -297,7 +311,8 @@ func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contra nonce, err := st.ContractNonce(&contract) if err != nil { - if errors.Is(err, state.ErrContractNotDeployed) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, state.ErrContractNotDeployed) { continue } return nil, err @@ -321,7 +336,10 @@ func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contra }, nil } -func getContractStorageProof(state commonstate.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { +func getContractStorageProof( + state commonstate.StateReader, + storageKeys []StorageKeys, +) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) for i, storageKey := range storageKeys { contractStorageTrie, err := state.ContractStorageTrie(storageKey.Contract) diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 2ce9245da9..f5e3ef1a82 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -182,7 +182,11 @@ func TestStorageProof(t *testing.T) { contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) - createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { + createTrie := func( + t *testing.T, + id trieutils.TrieID, + trieDB *trie2.TestNodeDatabase, + ) *trie2.Trie { tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) _ = tr.Update(key, value) _ = tr.Update(key2, value2) @@ -201,9 +205,19 @@ func TestStorageProof(t *testing.T) { trieRoot = tmpTrieRoot // recreate because the previous ones are committed - classTrie2, err := trie2.New(trieutils.NewClassTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + classTrie2, err := trie2.New( + trieutils.NewClassTrieID(*newComm), + 251, + crypto.Pedersen, + &trieDB, + ) require.NoError(t, err) - contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + contractTrie2, err = trie2.New( + trieutils.NewContractTrieID(*newComm), + 251, + crypto.Pedersen, + &trieDB, + ) require.NoError(t, err) classTrie = (*commontrie.TrieAdapter)(classTrie2) contractTrie = (*commontrie.TrieAdapter)(contractTrie2) @@ -339,8 +353,14 @@ func TestStorageProof(t *testing.T) { }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { if statetestutils.UseNewState() { - mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(1) - mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(0) + mockState.EXPECT().ContractNonce(noSuchKey).Return( + felt.Zero, + state.ErrContractNotDeployed, + ).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return( + felt.Zero, + state.ErrContractNotDeployed, + ).Times(0) } else { mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) diff --git a/rpc/v8/subscriptions_test.go b/rpc/v8/subscriptions_test.go index 3f69b8e614..c6eaf11e52 100644 --- a/rpc/v8/subscriptions_test.go +++ b/rpc/v8/subscriptions_test.go @@ -415,7 +415,9 @@ func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, err return nil, nil, nil } -func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingStateBeforeIndex( + index int, +) (commonstate.StateReader, func() error, error) { return nil, nil, nil } diff --git a/rpc/v9/helpers.go b/rpc/v9/helpers.go index 0f7d302689..4739748dfa 100644 --- a/rpc/v9/helpers.go +++ b/rpc/v9/helpers.go @@ -151,7 +151,9 @@ func feeUnit(txn core.Transaction) FeeUnit { return feeUnit } -func (h *Handler) stateByBlockID(blockID *BlockID) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { +func (h *Handler) stateByBlockID( + blockID *BlockID, +) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { var reader commonstate.StateReader var closer blockchain.StateCloser var err error diff --git a/rpc/v9/simulation_test.go b/rpc/v9/simulation_test.go index 6ff01d4b50..e477893280 100644 --- a/rpc/v9/simulation_test.go +++ b/rpc/v9/simulation_test.go @@ -36,7 +36,11 @@ func TestSimulateTransactions(t *testing.T) { PriceInFri: &felt.Zero, }, } - defaultMockBehavior := func(mockReader *mocks.MockReader, _ *mocks.MockVM, mockState *mocks.MockStateReader) { + defaultMockBehavior := func( + mockReader *mocks.MockReader, + _ *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { mockReader.EXPECT().Network().Return(n) mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) mockReader.EXPECT().HeadsHeader().Return(headsHeader, nil) @@ -52,7 +56,11 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip fee", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -71,7 +79,11 @@ func TestSimulateTransactions(t *testing.T) { { //nolint:dupl name: "ok with zero values, skip validate", stepsUsed: 123, - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -89,7 +101,11 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "transaction execution error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, @@ -107,7 +123,11 @@ func TestSimulateTransactions(t *testing.T) { }, { name: "inconsistent lengths error", - mockBehavior: func(mockReader *mocks.MockReader, mockVM *mocks.MockVM, mockState *mocks.MockStateReader) { + mockBehavior: func( + mockReader *mocks.MockReader, + mockVM *mocks.MockVM, + mockState *mocks.MockStateReader, + ) { defaultMockBehavior(mockReader, mockVM, mockState) mockVM.EXPECT().Execute([]core.Transaction{}, nil, []*felt.Felt{}, &vm.BlockInfo{ Header: headsHeader, diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 8cda414d7b..34c8296744 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -41,7 +41,8 @@ func (h *Handler) StorageAt(address, key *felt.Felt, id *BlockID) (*felt.Felt, * // the returned value is always zero and error is nil. _, err := stateReader.ContractClassHash(address) if err != nil { - // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. Remove db.ErrKeyNotFound after integration + // TODO(maksymmalick): state.ErrContractNotDeployed is returned by new state. + // Remove db.ErrKeyNotFound after integration if errors.Is(err, db.ErrKeyNotFound) || errors.Is(err, state.ErrContractNotDeployed) { return nil, rpccore.ErrContractNotFound } @@ -234,7 +235,11 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } } -func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProof( + tr commontrie.Trie, + state commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { switch t := tr.(type) { case *commontrie.DeprecatedTrieAdapter: return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) @@ -245,7 +250,11 @@ func getContractProof(tr commontrie.Trie, state commonstate.StateReader, contrac } } -func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithDeprecatedTrie( + tr *trie.Trie, + state commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -285,7 +294,11 @@ func getContractProofWithDeprecatedTrie(tr *trie.Trie, state commonstate.StateRe }, nil } -func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contracts []felt.Felt) (*ContractProof, error) { +func getContractProofWithTrie( + tr *trie2.Trie, + st commonstate.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) @@ -298,7 +311,8 @@ func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contra nonce, err := st.ContractNonce(&contract) if err != nil { - if errors.Is(err, state.ErrContractNotDeployed) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, state.ErrContractNotDeployed) { continue } return nil, err @@ -322,7 +336,10 @@ func getContractProofWithTrie(tr *trie2.Trie, st commonstate.StateReader, contra }, nil } -func getContractStorageProof(state commonstate.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { +func getContractStorageProof( + state commonstate.StateReader, + storageKeys []StorageKeys, +) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) for i, storageKey := range storageKeys { contractStorageTrie, err := state.ContractStorageTrie(storageKey.Contract) diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index eb078bb4a5..6c81467100 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -200,7 +200,11 @@ func TestStorageProof(t *testing.T) { contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) } else { newComm := new(felt.Felt).SetUint64(1) - createTrie := func(t *testing.T, id trieutils.TrieID, trieDB *trie2.TestNodeDatabase) *trie2.Trie { + createTrie := func( + t *testing.T, + id trieutils.TrieID, + trieDB *trie2.TestNodeDatabase, + ) *trie2.Trie { tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) _ = tr.Update(key, value) _ = tr.Update(key2, value2) @@ -219,9 +223,19 @@ func TestStorageProof(t *testing.T) { trieRoot = tmpTrieRoot // recreate because the previous ones are committed - classTrie2, err := trie2.New(trieutils.NewClassTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + classTrie2, err := trie2.New( + trieutils.NewClassTrieID(*newComm), + 251, + crypto.Pedersen, + &trieDB, + ) require.NoError(t, err) - contractTrie2, err = trie2.New(trieutils.NewContractTrieID(*newComm), 251, crypto.Pedersen, &trieDB) + contractTrie2, err = trie2.New( + trieutils.NewContractTrieID(*newComm), + 251, + crypto.Pedersen, + &trieDB, + ) require.NoError(t, err) classTrie = (*commontrie.TrieAdapter)(classTrie2) contractTrie = (*commontrie.TrieAdapter)(contractTrie2) @@ -357,8 +371,14 @@ func TestStorageProof(t *testing.T) { }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { if statetestutils.UseNewState() { - mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(1) - mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, state.ErrContractNotDeployed).Times(0) + mockState.EXPECT().ContractNonce(noSuchKey).Return( + felt.Zero, + state.ErrContractNotDeployed, + ).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return( + felt.Zero, + state.ErrContractNotDeployed, + ).Times(0) } else { mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index 6ca4522043..b06a4a2433 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -93,7 +93,9 @@ func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, err return nil, nil, nil } -func (fs *fakeSyncer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingStateBeforeIndex( + index int, +) (commonstate.StateReader, func() error, error) { return nil, nil, nil } diff --git a/sequencer/sequencer.go b/sequencer/sequencer.go index 31c1b5e63d..a14ec71e78 100644 --- a/sequencer/sequencer.go +++ b/sequencer/sequencer.go @@ -215,7 +215,9 @@ func (s *Sequencer) PendingState() (commonstate.StateReader, func() error, error return s.builder.PendingState(s.buildState) } -func (s *Sequencer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { +func (s *Sequencer) PendingStateBeforeIndex( + index int, +) (commonstate.StateReader, func() error, error) { return nil, nil, errors.ErrUnsupported } diff --git a/sync/pending.go b/sync/pending.go index 7a6792fa50..327d00890a 100644 --- a/sync/pending.go +++ b/sync/pending.go @@ -20,7 +20,11 @@ type PendingState struct { head commonstate.StateReader } -func NewPendingState(stateDiff *core.StateDiff, newClasses map[felt.Felt]core.Class, head commonstate.StateReader) *PendingState { +func NewPendingState( + stateDiff *core.StateDiff, + newClasses map[felt.Felt]core.Class, + head commonstate.StateReader, +) *PendingState { return &PendingState{ stateDiff: stateDiff, newClasses: newClasses, diff --git a/sync/sync.go b/sync/sync.go index cd343cb0c7..37f4c6f4e0 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -115,7 +115,9 @@ func (n *NoopSynchronizer) PendingState() (commonstate.StateReader, func() error return nil, nil, errors.New("PendingState() not implemented") } -func (n *NoopSynchronizer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { +func (n *NoopSynchronizer) PendingStateBeforeIndex( + index int, +) (commonstate.StateReader, func() error, error) { return nil, nil, errors.New("PendingStateBeforeIndex() not implemented") } @@ -679,7 +681,9 @@ func (s *Synchronizer) PendingState() (commonstate.StateReader, func() error, er // PendingStateAfterIndex returns the state obtained by applying all transaction state diffs // up to given index in the pre-confirmed block. -func (s *Synchronizer) PendingStateBeforeIndex(index int) (commonstate.StateReader, func() error, error) { +func (s *Synchronizer) PendingStateBeforeIndex( + index int, +) (commonstate.StateReader, func() error, error) { txn := s.db.NewIndexedBatch() pendingPtr := s.pendingData.Load() diff --git a/vm/vm_test.go b/vm/vm_test.go index 4884d5d2af..be96b8190e 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -38,7 +38,9 @@ func TestCallDeprecatedCairo(t *testing.T) { require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) - newRoot := felt.NewUnsafeFromString[felt.Felt]("0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8") + newRoot := felt.NewUnsafeFromString[felt.Felt]( + "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8", + ) require.NoError(t, testState.Update(0, &core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: newRoot, @@ -186,7 +188,9 @@ func TestCallCairo(t *testing.T) { require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) - newRoot := felt.NewUnsafeFromString[felt.Felt]("0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84") + newRoot := felt.NewUnsafeFromString[felt.Felt]( + "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84", + ) require.NoError(t, testState.Update(0, &core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: newRoot, From cae623ccfc8448db35dc483df4cc45af2fa70951 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 20 Oct 2025 16:06:29 +0200 Subject: [PATCH 37/47] lint --- core/trie2/triedb/rawdb/cache.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/trie2/triedb/rawdb/cache.go b/core/trie2/triedb/rawdb/cache.go index 9974b553e6..3c165cc6f2 100644 --- a/core/trie2/triedb/rawdb/cache.go +++ b/core/trie2/triedb/rawdb/cache.go @@ -22,7 +22,8 @@ type cleanCache struct { cache *fastcache.Cache // map[nodeKey]node } -// Creates a new clean cache with the given size. The size is the maximum size of the cache in bytes. +// Creates a new clean cache with the given size. +// The size is the maximum size of the cache in bytes. func newCleanCache(size uint64) cleanCache { if size > uint64(math.MaxInt) { panic("cache size too large: uint64 to int conversion would overflow") From 953891a8a70a35f411f4755346d30173003a1ea2 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 20 Oct 2025 18:03:18 +0200 Subject: [PATCH 38/47] small order refactor --- blockchain/blockchain.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index ab3cb1c89a..bae8925d4d 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -348,11 +348,11 @@ func (b *Blockchain) store( return err } - if err := b.runningFilter.Insert(block.EventsBloom, block.Number); err != nil { + if err := batch.Write(); err != nil { return err } - return batch.Write() + return b.runningFilter.Insert(block.EventsBloom, block.Number) } // VerifyBlock assumes the block has already been sanity-checked. From 8618bf63e2d5867034f1f8f55a7d33c6c97f5770 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 21 Oct 2025 23:41:58 +0200 Subject: [PATCH 39/47] remove batch flushing from the state --- blockchain/blockchain.go | 30 +-- core/state.go | 1 - core/state/commonstate/deprecated_state.go | 2 - core/state/commonstate/state.go | 9 +- core/state/commonstate/state_test.go | 10 +- core/state/history.go | 2 +- core/state/history_test.go | 14 +- core/state/state.go | 61 +++--- core/state/state_test.go | 223 ++++++++++++++------- core/state_test.go | 60 +++--- genesis/genesis.go | 2 +- rpc/v8/transaction.go | 1 - sync/sync.go | 4 +- vm/vm_test.go | 39 ++-- 14 files changed, 266 insertions(+), 192 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index ab3cb1c89a..55831b98da 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -267,7 +267,7 @@ func (b *Blockchain) deprecatedStore( } state := core.NewState(txn) - if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil { + if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { return err } if err := core.WriteBlockHeader(txn, block.Header); err != nil { @@ -315,14 +315,15 @@ func (b *Blockchain) store( if err := verifyBlock(b.database, block); err != nil { return err } - state, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil) + batch := b.database.NewBatch() + + state, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil, batch) if err != nil { return err } - if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil { + if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil { return err } - batch := b.database.NewBatch() if err := core.WriteBlockHeader(batch, block.Header); err != nil { return err } @@ -428,7 +429,7 @@ func (b *Blockchain) HeadState() (commonstate.StateReader, StateCloser, error) { return nil, nil, err } - state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn) + state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn, nil) return state, noopStateCloser, err } @@ -585,7 +586,7 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) { if err != nil { return ret, err } - state, err := state.New(stateUpdate.NewRoot, b.stateDB) + state, err := state.New(stateUpdate.NewRoot, b.stateDB, nil) if err != nil { return ret, err } @@ -666,7 +667,9 @@ func (b *Blockchain) revertHead() error { if err != nil { return err } - state, err := state.New(stateUpdate.NewRoot, b.stateDB) + batch := b.database.NewBatch() + + state, err := state.New(stateUpdate.NewRoot, b.stateDB, batch) if err != nil { return err } @@ -680,7 +683,6 @@ func (b *Blockchain) revertHead() error { } genesisBlock := blockNumber == 0 - batch := b.database.NewBatch() for _, key := range [][]byte{ db.BlockHeaderByNumberKey(header.Number), db.BlockHeaderNumbersByHashKey(header.Hash), @@ -729,7 +731,7 @@ func (b *Blockchain) Simulate( // Simulate without commit txn := b.database.NewIndexedBatch() defer txn.Reset() - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, false); err != nil { + if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil { return SimulateResult{}, err } @@ -764,7 +766,7 @@ func (b *Blockchain) Finalise( ) error { if !b.StateFactory.UseNewState { err := b.database.Update(func(txn db.IndexedBatch) error { - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, true); err != nil { + if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil { return err } commitments, err := b.updateBlockHash(block, stateUpdate) @@ -786,7 +788,7 @@ func (b *Blockchain) Finalise( return b.runningFilter.Insert(block.EventsBloom, block.Number) } else { batch := b.database.NewBatch() - if err := b.updateStateRoots(nil, block, stateUpdate, newClasses, true); err != nil { + if err := b.updateStateRoots(nil, batch, block, stateUpdate, newClasses); err != nil { return err } commitments, err := b.updateBlockHash(block, stateUpdate) @@ -812,10 +814,10 @@ func (b *Blockchain) Finalise( // updateStateRoots computes and updates state roots in the block and state update func (b *Blockchain) updateStateRoots( txn db.IndexedBatch, + batch db.Batch, block *core.Block, stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.Class, - flushChanges bool, ) error { var height uint64 var err error @@ -831,7 +833,7 @@ func (b *Blockchain) updateStateRoots( stateRoot = &felt.Zero } - state, err := b.StateFactory.NewState(stateRoot, txn) + state, err := b.StateFactory.NewState(stateRoot, txn, batch) if err != nil { return err } @@ -844,7 +846,7 @@ func (b *Blockchain) updateStateRoots( stateUpdate.OldRoot = &oldStateRoot // Apply state update - if err = state.Update(block.Number, stateUpdate, newClasses, true, flushChanges); err != nil { + if err = state.Update(block.Number, stateUpdate, newClasses, true); err != nil { return err } diff --git a/core/state.go b/core/state.go index ead6d217de..e8972dc3d2 100644 --- a/core/state.go +++ b/core/state.go @@ -237,7 +237,6 @@ func (s *State) Update( update *StateUpdate, declaredClasses map[felt.Felt]Class, skipVerifyNewRoot bool, - flushChanges bool, // TODO(maksym): added to satisfy the interface, but not used ) error { err := s.verifyStateUpdateRoot(update.OldRoot) if err != nil { diff --git a/core/state/commonstate/deprecated_state.go b/core/state/commonstate/deprecated_state.go index 8af3e95ab8..c310ed8eac 100644 --- a/core/state/commonstate/deprecated_state.go +++ b/core/state/commonstate/deprecated_state.go @@ -95,14 +95,12 @@ func (s *DeprecatedStateAdapter) Update( update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool, - flushChanges bool, ) error { return (*core.State)(s).Update( blockNumber, update, declaredClasses, skipVerifyNewRoot, - flushChanges, ) } diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 13580cd7cd..ee100e6548 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -23,7 +23,6 @@ type State interface { update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool, - flushChanges bool, ) error Revert(blockNum uint64, update *core.StateUpdate) error Commitment() (felt.Felt, error) @@ -88,14 +87,12 @@ func (s *StateAdapter) Update( update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool, - flushChanges bool, ) error { return (*state.State)(s).Update( blockNumber, update, declaredClasses, skipVerifyNewRoot, - flushChanges, ) } @@ -201,13 +198,13 @@ func NewStateFactory( }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (State, error) { +func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch, batch db.Batch) (State, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) return NewDeprecatedStateAdapter(deprecatedState), nil } - stateState, err := state.New(stateRoot, sf.stateDB) + stateState, err := state.New(stateRoot, sf.stateDB, batch) if err != nil { return nil, err } @@ -239,7 +236,7 @@ func (sf *StateFactory) EmptyState() (StateReader, error) { emptyState := core.NewState(txn) return NewDeprecatedStateReaderAdapter(emptyState), nil } - state, err := state.New(&felt.Zero, sf.stateDB) + state, err := state.New(&felt.Zero, sf.stateDB, nil) if err != nil { return nil, err } diff --git a/core/state/commonstate/state_test.go b/core/state/commonstate/state_test.go index af6a41116b..0e99dbc066 100644 --- a/core/state/commonstate/state_test.go +++ b/core/state/commonstate/state_test.go @@ -25,7 +25,7 @@ func TestStateAdapter(t *testing.T) { panic(err) } stateDB := state.NewStateDB(memDB, db) - state, err := state.New(&felt.Zero, stateDB) + state, err := state.New(&felt.Zero, stateDB, nil) require.NoError(t, err) stateAdapter := NewStateAdapter(state) @@ -75,10 +75,10 @@ func BenchmarkStateUpdate(b *testing.B) { for i := range samples { declaredClasses := make(map[felt.Felt]core.Class) - if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { + if err := state.Update(uint64(i), suList[i], declaredClasses, false); err != nil { b.Fatalf("Update failed: %v", err) } - state, err = stateFactory.NewState(suList[i].NewRoot, txn) + state, err = stateFactory.NewState(suList[i].NewRoot, txn, nil) require.NoError(b, err) } } @@ -96,7 +96,7 @@ func BenchmarkStateUpdate(b *testing.B) { for i := range samples { declaredClasses := make(map[felt.Felt]core.Class) - if err := state.Update(uint64(i), suList[i], declaredClasses, false, true); err != nil { + if err := state.Update(uint64(i), suList[i], declaredClasses, false); err != nil { b.Fatalf("Update failed: %v", err) } } @@ -113,7 +113,7 @@ func prepareState(b *testing.B, newState bool) (State, *StateFactory, db.Indexed stateFactory, err := NewStateFactory(newState, trieDB, stateDB) require.NoError(b, err) - state, err := stateFactory.NewState(&felt.Zero, txn) + state, err := stateFactory.NewState(&felt.Zero, txn, nil) require.NoError(b, err) return state, stateFactory, txn diff --git a/core/state/history.go b/core/state/history.go index de854271b6..128a54e080 100644 --- a/core/state/history.go +++ b/core/state/history.go @@ -18,7 +18,7 @@ type StateHistory struct { } func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (StateHistory, error) { - state, err := New(stateRoot, db) + state, err := New(stateRoot, db, nil) if err != nil { return StateHistory{}, err } diff --git a/core/state/history_test.go b/core/state/history_test.go index 24a0a5cce9..cdee294e6f 100644 --- a/core/state/history_test.go +++ b/core/state/history_test.go @@ -123,10 +123,11 @@ func TestStateHistoryClassOperations(t *testing.T) { NewRoot: &felt.Zero, StateDiff: &core.StateDiff{}, } - state, err := New(&felt.Zero, stateDB) - require.NoError(t, err) - err = state.Update(0, stateUpdate, classes, false, true) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) + require.NoError(t, state.Update(0, stateUpdate, classes, false)) + require.NoError(t, batch.Write()) stateComm, err := state.Commitment() require.NoError(t, err) @@ -139,10 +140,11 @@ func TestStateHistoryClassOperations(t *testing.T) { class2Hash: class2, } - state, err = New(&stateComm, stateDB) - require.NoError(t, err) - err = state.Update(1, stateUpdate, classes2, false, true) + batch = stateDB.disk.NewBatch() + state, err = New(&stateComm, stateDB, batch) require.NoError(t, err) + require.NoError(t, state.Update(1, stateUpdate, classes2, false)) + require.NoError(t, batch.Write()) historyBlock0, err := NewStateHistory(0, &felt.Zero, stateDB) require.NoError(t, err) diff --git a/core/state/state.go b/core/state/state.go index 906bccb839..be12df0ba1 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -61,9 +61,11 @@ type State struct { classTrie *trie2.Trie stateObjects map[felt.Felt]*stateObject + + batch db.Batch } -func New(stateRoot *felt.Felt, db *StateDB) (*State, error) { +func New(stateRoot *felt.Felt, db *StateDB, batch db.Batch) (*State, error) { contractTrie, err := db.ContractTrie(stateRoot) if err != nil { return nil, err @@ -80,6 +82,7 @@ func New(stateRoot *felt.Felt, db *StateDB) (*State, error) { contractTrie: contractTrie, classTrie: classTrie, stateObjects: make(map[felt.Felt]*stateObject), + batch: batch, }, nil } @@ -168,7 +171,6 @@ func (s *State) Update( update *core.StateUpdate, declaredClasses map[felt.Felt]core.Class, skipVerifyNewRoot bool, - flushChanges bool, ) error { if err := s.verifyComm(update.OldRoot); err != nil { return err @@ -237,10 +239,8 @@ func (s *State) Update( deployedContracts: update.StateDiff.ReplacedClasses, }) - if flushChanges { - if err := s.flush(blockNum, &stateUpdate, dirtyClasses, true); err != nil { - return err - } + if s.batch != nil { + return s.flush(blockNum, &stateUpdate, dirtyClasses, true) } return nil @@ -307,19 +307,19 @@ func (s *State) Revert(blockNum uint64, update *core.StateUpdate) error { if !newComm.Equal(update.OldRoot) { return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.OldRoot, &newComm) } - - if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil { - return err + if s.batch != nil { + if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil { + return err + } + if err := s.deleteHistory(blockNum, update.StateDiff); err != nil { + return err + } } if err := s.db.stateCache.PopLayer(update.NewRoot, update.OldRoot); err != nil { return err } - if err := s.deleteHistory(blockNum, update.StateDiff); err != nil { - return err - } - return nil } @@ -492,45 +492,44 @@ func (s *State) flush( classes map[felt.Felt]core.Class, storeHistory bool, ) error { - batch := s.db.disk.NewBatch() if err := s.db.triedb.Update( &update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes, - batch, + s.batch, ); err != nil { return err } for addr, obj := range s.stateObjects { if obj == nil { // marked as deleted - if err := DeleteContract(batch, &addr); err != nil { + if err := DeleteContract(s.batch, &addr); err != nil { return err } // TODO(weiihann): handle hash-based, and there should be better ways of doing this - if err := trieutils.DeleteStorageNodesByPath(batch, addr); err != nil { + if err := trieutils.DeleteStorageNodesByPath(s.batch, addr); err != nil { return err } } else { // updated - if err := WriteContract(batch, &addr, obj.contract); err != nil { + if err := WriteContract(s.batch, &addr, obj.contract); err != nil { return err } if storeHistory { for key, val := range obj.dirtyStorage { - if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil { + if err := WriteStorageHistory(s.batch, &addr, &key, blockNum, val); err != nil { return err } } - if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil { + if err := WriteNonceHistory(s.batch, &addr, blockNum, &obj.contract.Nonce); err != nil { return err } - if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { + if err := WriteClassHashHistory(s.batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { return err } } @@ -539,17 +538,17 @@ func (s *State) flush( for classHash, class := range classes { if class == nil { // mark as deleted - if err := DeleteClass(batch, &classHash); err != nil { + if err := DeleteClass(s.batch, &classHash); err != nil { return err } } else { - if err := WriteClass(batch, &classHash, class, blockNum); err != nil { + if err := WriteClass(s.batch, &classHash, class, blockNum); err != nil { return err } } } - return batch.Write() + return nil } func (s *State) updateClassTrie(declaredClasses map[felt.Felt]*felt.Felt, classDefs map[felt.Felt]core.Class) error { @@ -736,39 +735,37 @@ func (s *State) valueAt(prefix []byte, blockNum uint64, cb func(val []byte) erro } func (s *State) deleteHistory(blockNum uint64, diff *core.StateDiff) error { - batch := s.db.disk.NewBatch() - for addr, storage := range diff.StorageDiffs { for key := range storage { - if err := DeleteStorageHistory(batch, &addr, &key, blockNum); err != nil { + if err := DeleteStorageHistory(s.batch, &addr, &key, blockNum); err != nil { return err } } } for addr := range diff.Nonces { - if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil { + if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil { return err } } for addr := range diff.ReplacedClasses { - if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil { + if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil { return err } } for addr := range diff.DeployedContracts { - if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil { + if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil { return err } - if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil { + if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil { return err } } - return batch.Write() + return nil } func (s *State) compareContracts(a, b felt.Felt) int { diff --git a/core/state/state_test.go b/core/state/state_test.go index ca399b8ef0..01ec0fc197 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -78,10 +78,9 @@ func TestUpdate(t *testing.T) { OldRoot: oldRoot, } stateDB := setupState(t, stateUpdates, 0) - state, err := New(&felt.Zero, stateDB) + state, err := New(&felt.Zero, stateDB, nil) require.NoError(t, err) - err = state.Update(block0, su, nil, false, true) - require.Error(t, err) + require.Error(t, state.Update(block0, su, nil, false)) }) t.Run("error when state's new root doesn't match state update's new root", func(t *testing.T) { @@ -92,30 +91,29 @@ func TestUpdate(t *testing.T) { StateDiff: new(core.StateDiff), } stateDB := setupState(t, stateUpdates, 0) - state, err := New(&felt.Zero, stateDB) + state, err := New(&felt.Zero, stateDB, nil) require.NoError(t, err) - err = state.Update(block0, su, nil, false, true) - require.Error(t, err) + require.Error(t, state.Update(block0, su, nil, false)) }) t.Run("post v0.11.0 declared classes affect root", func(t *testing.T) { t.Run("without class definition", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 3) - state, err := New(stateUpdates[3].OldRoot, stateDB) + state, err := New(stateUpdates[3].OldRoot, stateDB, nil) require.NoError(t, err) - require.Error(t, state.Update(block3, stateUpdates[3], nil, false, true)) + require.Error(t, state.Update(block3, stateUpdates[3], nil, false)) }) t.Run("with class definition", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 3) - state, err := New(stateUpdates[3].OldRoot, stateDB) + state, err := New(stateUpdates[3].OldRoot, stateDB, nil) require.NoError(t, err) - require.NoError(t, state.Update(block3, su3, su3DeclaredClasses(), false, true)) + require.NoError(t, state.Update(block3, su3, su3DeclaredClasses(), false)) }) }) t.Run("update noClassContracts storage", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 5) - state, err := New(stateUpdates[4].NewRoot, stateDB) + state, err := New(stateUpdates[4].NewRoot, stateDB, nil) require.NoError(t, err) gotValue, err := state.ContractStorage(scAddr, scKey) @@ -141,9 +139,9 @@ func TestUpdate(t *testing.T) { }, } stateDB := setupState(t, stateUpdates, 5) - state, err := New(stateUpdates[4].NewRoot, stateDB) + state, err := New(stateUpdates[4].NewRoot, stateDB, nil) require.NoError(t, err) - require.ErrorIs(t, state.Update(block5, su5, nil, false, true), ErrContractNotDeployed) + require.ErrorIs(t, state.Update(block5, su5, nil, false), ErrContractNotDeployed) }) } @@ -155,7 +153,7 @@ func TestContractClassHash(t *testing.T) { su1 := stateUpdates[1] stateDB := setupState(t, stateUpdates, 2) - state, err := New(su1.NewRoot, stateDB) + state, err := New(su1.NewRoot, stateDB, nil) require.NoError(t, err) allDeployedContracts := make(map[felt.Felt]*felt.Felt) @@ -185,10 +183,11 @@ func TestNonce(t *testing.T) { t.Run("newly deployed contract has zero nonce", func(t *testing.T) { stateDB := setupState(t, nil, 0) - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block0, su0, nil, false, true)) - + require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, batch.Write()) gotNonce, err := state.ContractNonce(addr) require.NoError(t, err) assert.Equal(t, felt.Zero, gotNonce) @@ -196,10 +195,11 @@ func TestNonce(t *testing.T) { t.Run("update contract nonce", func(t *testing.T) { stateDB := setupState(t, nil, 0) - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block0, su0, nil, false, true)) - + require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, batch.Write()) expectedNonce := new(felt.Felt).SetUint64(1) su1 := &core.StateUpdate{ NewRoot: felt.NewUnsafeFromString[felt.Felt]("0x6210642ffd49f64617fc9e5c0bbe53a6a92769e2996eb312a42d2bdb7f2afc1"), @@ -209,9 +209,11 @@ func TestNonce(t *testing.T) { }, } - state1, err := New(su1.OldRoot, stateDB) + batch1 := stateDB.disk.NewBatch() + state1, err := New(su1.OldRoot, stateDB, batch1) require.NoError(t, err) - require.NoError(t, state1.Update(block1, su1, nil, false, true)) + require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, batch1.Write()) gotNonce, err := state1.ContractNonce(addr) require.NoError(t, err) @@ -232,7 +234,8 @@ func TestClass(t *testing.T) { cairo1Class, err := gw.Class(t.Context(), cairo1Hash) require.NoError(t, err) - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) su0, err := gw.StateUpdate(t.Context(), 0) @@ -240,8 +243,8 @@ func TestClass(t *testing.T) { require.NoError(t, state.Update(0, su0, map[felt.Felt]core.Class{ *cairo0Hash: cairo0Class, *cairo1Hash: cairo1Class, - }, false, true)) - + }, false)) + require.NoError(t, batch.Write()) gotCairo1Class, err := state.Class(cairo1Hash) require.NoError(t, err) assert.Zero(t, gotCairo1Class.At) @@ -259,7 +262,7 @@ func TestContractDeployedAt(t *testing.T) { root := *stateUpdates[1].NewRoot t.Run("deployed on genesis", func(t *testing.T) { - state, err := New(&root, stateDB) + state, err := New(&root, stateDB, nil) require.NoError(t, err) d0 := felt.NewUnsafeFromString[felt.Felt]("0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") @@ -273,7 +276,7 @@ func TestContractDeployedAt(t *testing.T) { }) t.Run("deployed after genesis", func(t *testing.T) { - state, err := New(&root, stateDB) + state, err := New(&root, stateDB, nil) require.NoError(t, err) d1 := felt.NewUnsafeFromString[felt.Felt]("0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") @@ -287,7 +290,7 @@ func TestContractDeployedAt(t *testing.T) { }) t.Run("not deployed", func(t *testing.T) { - state, err := New(&root, stateDB) + state, err := New(&root, stateDB, nil) require.NoError(t, err) notDeployed := felt.NewUnsafeFromString[felt.Felt]("0xDEADBEEF") @@ -318,17 +321,21 @@ func TestRevert(t *testing.T) { }, } - state, err := New(su1.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, replaceStateUpdate, nil, false, true)) + require.NoError(t, state.Update(block2, replaceStateUpdate, nil, false)) + require.NoError(t, batch.Write()) gotClassHash, err := state.ContractClassHash(&su1FirstDeployedAddress) require.NoError(t, err) assert.Equal(t, *replacedVal, gotClassHash) - state, err = New(replaceStateUpdate.NewRoot, stateDB) + batch1 := stateDB.disk.NewBatch() + state, err = New(replaceStateUpdate.NewRoot, stateDB, batch1) require.NoError(t, err) require.NoError(t, state.Revert(block2, replaceStateUpdate)) + require.NoError(t, batch1.Write()) gotClassHash, err = state.ContractClassHash(&su1FirstDeployedAddress) require.NoError(t, err) @@ -349,17 +356,21 @@ func TestRevert(t *testing.T) { }, } - state, err := New(su1.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, nonceStateUpdate, nil, false, true)) + require.NoError(t, state.Update(block2, nonceStateUpdate, nil, false)) + require.NoError(t, batch.Write()) gotNonce, err := state.ContractNonce(&su1FirstDeployedAddress) require.NoError(t, err) assert.Equal(t, *replacedVal, gotNonce) - state, err = New(nonceStateUpdate.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(nonceStateUpdate.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block2, nonceStateUpdate)) + require.NoError(t, batch.Write()) gotNonce, err = state.ContractNonce(&su1FirstDeployedAddress) require.NoError(t, err) @@ -382,18 +393,23 @@ func TestRevert(t *testing.T) { }, } - state, err := New(su1.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, storageStateUpdate, nil, false, true)) + require.NoError(t, state.Update(block2, storageStateUpdate, nil, false)) + require.NoError(t, batch.Write()) + gotStorage, err := state.ContractStorage(&su1FirstDeployedAddress, replacedVal) require.NoError(t, err) assert.Equal(t, *replacedVal, gotStorage) - state, err = New(storageStateUpdate.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(storageStateUpdate.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block2, storageStateUpdate)) + require.NoError(t, batch.Write()) storage, sErr := state.ContractStorage(&su1FirstDeployedAddress, replacedVal) require.NoError(t, sErr) assert.Equal(t, felt.Zero, storage) @@ -446,13 +462,17 @@ func TestRevert(t *testing.T) { }, } - state, err := New(su1.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, declaredClassesStateUpdate, classesM, false, true)) + require.NoError(t, state.Update(block2, declaredClassesStateUpdate, classesM, false)) + require.NoError(t, batch.Write()) - state, err = New(declaredClassesStateUpdate.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(declaredClassesStateUpdate.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block2, declaredClassesStateUpdate)) + require.NoError(t, batch.Write()) var decClass *core.DeclaredClass decClass, err = state.Class(cairo0Addr) @@ -467,33 +487,45 @@ func TestRevert(t *testing.T) { t.Run("should be able to update after a revert", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 2) - state, err := New(su1.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, su2, nil, false, true)) + require.NoError(t, state.Update(block2, su2, nil, false)) + require.NoError(t, batch.Write()) - state, err = New(su2.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su2.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block2, su2)) + require.NoError(t, batch.Write()) - state, err = New(su1.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su1.NewRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block2, su2, nil, false, true)) + require.NoError(t, state.Update(block2, su2, nil, false)) + require.NoError(t, batch.Write()) }) t.Run("should be able to revert all the updates", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 3) - state, err := New(su2.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su2.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block2, su2)) + require.NoError(t, batch.Write()) - state, err = New(su1.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su1.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block1, su1)) + require.NoError(t, batch.Write()) - state, err = New(su0.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su0.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block0, su0)) + require.NoError(t, batch.Write()) }) t.Run("revert no class contracts", func(t *testing.T) { @@ -511,13 +543,17 @@ func TestRevert(t *testing.T) { su1.NewRoot = felt.NewUnsafeFromString[felt.Felt]("0x2829ac1aea81c890339e14422fe757d6831744031479cf33a9260d14282c341") su1.StateDiff.StorageDiffs[*scAddr] = map[felt.Felt]*felt.Felt{*scKey: scValue} - state, err := New(su1.OldRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su1.OldRoot, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block1, &su1, nil, false, true)) + require.NoError(t, state.Update(block1, &su1, nil, false)) + require.NoError(t, batch.Write()) - state, err = New(su1.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su1.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block1, &su1)) + require.NoError(t, batch.Write()) }) t.Run("revert declared classes", func(t *testing.T) { @@ -541,10 +577,13 @@ func TestRevert(t *testing.T) { *sierraHash: &core.Cairo1Class{}, } - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block0, declareDiff, newClasses, false, true)) + require.NoError(t, state.Update(block0, declareDiff, newClasses, false)) + require.NoError(t, batch.Write()) + batch = stateDB.disk.NewBatch() declaredClass, err := state.Class(classHash) require.NoError(t, err) assert.Equal(t, uint64(0), declaredClass.At) @@ -552,10 +591,12 @@ func TestRevert(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint64(0), sierraClass.At) - state, err = New(declareDiff.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(declareDiff.NewRoot, stateDB, batch) require.NoError(t, err) declareDiff.OldRoot = declareDiff.NewRoot - require.NoError(t, state.Update(block1, declareDiff, newClasses, false, true)) + require.NoError(t, state.Update(block1, declareDiff, newClasses, false)) + require.NoError(t, batch.Write()) // Redeclaring should not change the declared at block number declaredClass, err = state.Class(classHash) @@ -565,9 +606,11 @@ func TestRevert(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint64(0), sierraClass.At) - state, err = New(declareDiff.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(declareDiff.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block1, declareDiff)) + require.NoError(t, batch.Write()) // Reverting a re-declaration should not change state commitment or remove class definitions declaredClass, err = state.Class(classHash) @@ -577,10 +620,12 @@ func TestRevert(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint64(0), sierraClass.At) - state, err = New(declareDiff.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(declareDiff.NewRoot, stateDB, batch) require.NoError(t, err) declareDiff.OldRoot = &felt.Zero require.NoError(t, state.Revert(block0, declareDiff)) + require.NoError(t, batch.Write()) declaredClass, err = state.Class(classHash) require.ErrorIs(t, err, db.ErrKeyNotFound) @@ -609,22 +654,28 @@ func TestRevert(t *testing.T) { }, } - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) - require.NoError(t, state.Update(block0, su, nil, false, true)) + require.NoError(t, state.Update(block0, su, nil, false)) + require.NoError(t, batch.Write()) - state, err = New(su.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state, err = New(su.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block0, su)) + require.NoError(t, batch.Write()) }) t.Run("db should be empty after block0 revert", func(t *testing.T) { stateDB := setupState(t, stateUpdates, 1) - state, err := New(su0.NewRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su0.NewRoot, stateDB, batch) require.NoError(t, err) require.NoError(t, state.Revert(block0, su0)) + require.NoError(t, batch.Write()) it, err := stateDB.disk.NewIterator(nil, false) require.NoError(t, err) @@ -664,7 +715,7 @@ func TestContractHistory(t *testing.T) { t.Run("empty", func(t *testing.T) { stateDB := newTestStateDB() - state, err := New(&felt.Zero, stateDB) + state, err := New(&felt.Zero, stateDB, nil) require.NoError(t, err) nonce, err := state.ContractNonceAt(addr, block0) @@ -682,7 +733,8 @@ func TestContractHistory(t *testing.T) { t.Run("retrieve block height is the same as update", func(t *testing.T) { stateDB := newTestStateDB() - state, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) su0 := &core.StateUpdate{ @@ -695,7 +747,8 @@ func TestContractHistory(t *testing.T) { }, } - require.NoError(t, state.Update(block0, su0, nil, false, true)) + require.NoError(t, state.Update(block0, su0, nil, false)) + require.NoError(t, batch.Write()) gotNonce, err := state.ContractNonceAt(addr, block0) require.NoError(t, err) @@ -712,15 +765,19 @@ func TestContractHistory(t *testing.T) { t.Run("retrieve block height before update", func(t *testing.T) { stateDB := newTestStateDB() - state0, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state0, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) su0 := emptyStateUpdate - require.NoError(t, state0.Update(block0, su0, nil, false, true)) + require.NoError(t, state0.Update(block0, su0, nil, false)) + require.NoError(t, batch.Write()) - state1, err := New(su0.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state1, err := New(su0.NewRoot, stateDB, batch) require.NoError(t, err) su1 := su - require.NoError(t, state1.Update(block1, su1, nil, false, true)) + require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, batch.Write()) gotNonce, err := state1.ContractNonceAt(addr, block0) require.NoError(t, err) @@ -737,21 +794,26 @@ func TestContractHistory(t *testing.T) { t.Run("retrieve block height in between updates", func(t *testing.T) { stateDB := newTestStateDB() - state0, err := New(&felt.Zero, stateDB) + batch := stateDB.disk.NewBatch() + state0, err := New(&felt.Zero, stateDB, batch) require.NoError(t, err) su0 := su - require.NoError(t, state0.Update(block0, su0, nil, false, true)) + require.NoError(t, state0.Update(block0, su0, nil, false)) + require.NoError(t, batch.Write()) - state1, err := New(su0.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state1, err := New(su0.NewRoot, stateDB, batch) require.NoError(t, err) su1 := &core.StateUpdate{ OldRoot: su0.NewRoot, NewRoot: su0.NewRoot, StateDiff: &core.StateDiff{}, } - require.NoError(t, state1.Update(block1, su1, nil, false, true)) + require.NoError(t, state1.Update(block1, su1, nil, false)) + require.NoError(t, batch.Write()) - state2, err := New(su1.NewRoot, stateDB) + batch = stateDB.disk.NewBatch() + state2, err := New(su1.NewRoot, stateDB, batch) require.NoError(t, err) su2 := &core.StateUpdate{ OldRoot: su1.NewRoot, @@ -766,7 +828,8 @@ func TestContractHistory(t *testing.T) { }, }, } - require.NoError(t, state2.Update(block2, su2, nil, false, true)) + require.NoError(t, state2.Update(block2, su2, nil, false)) + require.NoError(t, batch.Write()) gotNonce, err := state2.ContractNonceAt(addr, block1) require.NoError(t, err) @@ -800,12 +863,14 @@ func BenchmarkStateUpdate(b *testing.B) { for b.Loop() { stateDB := newTestStateDB() for i, su := range stateUpdates { - state, err := New(su.OldRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su.OldRoot, stateDB, batch) require.NoError(b, err) - err = state.Update(uint64(i), su, nil, false, true) + err = state.Update(uint64(i), su, nil, false) if err != nil { b.Fatalf("Error updating state: %v", err) } + require.NoError(b, batch.Write()) } } } @@ -834,7 +899,8 @@ func getStateUpdates(t *testing.T) []*core.StateUpdate { func setupState(t *testing.T, stateUpdates []*core.StateUpdate, blocks uint64) *StateDB { stateDB := newTestStateDB() for i, su := range stateUpdates[:blocks] { - state, err := New(su.OldRoot, stateDB) + batch := stateDB.disk.NewBatch() + state, err := New(su.OldRoot, stateDB, batch) require.NoError(t, err) var declaredClasses map[felt.Felt]core.Class if i == 3 { @@ -842,10 +908,11 @@ func setupState(t *testing.T, stateUpdates []*core.StateUpdate, blocks uint64) * } require.NoError( t, - state.Update(uint64(i), su, declaredClasses, false, true), + state.Update(uint64(i), su, declaredClasses, false), "failed to update state for block %d", i, ) + require.NoError(t, batch.Write()) newComm, err := state.Commitment() require.NoError(t, err) assert.Equal(t, *su.NewRoot, newComm) diff --git a/core/state_test.go b/core/state_test.go index 5968cf39fc..33187fde71 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -41,7 +41,7 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) t.Run("empty state updated with mainnet block 0 state update", func(t *testing.T) { - require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) gotNewRoot, rerr := state.Root() require.NoError(t, rerr) assert.Equal(t, su0.NewRoot, &gotNewRoot) @@ -53,7 +53,7 @@ func TestUpdate(t *testing.T) { OldRoot: oldRoot, } expectedErr := fmt.Sprintf("state's current root: %s does not match the expected root: %s", su0.NewRoot, oldRoot) - require.EqualError(t, state.Update(1, su, nil, false, true), expectedErr) + require.EqualError(t, state.Update(1, su, nil, false), expectedErr) }) t.Run("error when state new root doesn't match state update's new root", func(t *testing.T) { @@ -64,16 +64,16 @@ func TestUpdate(t *testing.T) { StateDiff: new(core.StateDiff), } expectedErr := fmt.Sprintf("state's current root: %s does not match the expected root: %s", su0.NewRoot, newRoot) - require.EqualError(t, state.Update(1, su, nil, false, true), expectedErr) + require.EqualError(t, state.Update(1, su, nil, false), expectedErr) }) t.Run("non-empty state updated multiple times", func(t *testing.T) { - require.NoError(t, state.Update(1, su1, nil, false, true)) + require.NoError(t, state.Update(1, su1, nil, false)) gotNewRoot, rerr := state.Root() require.NoError(t, rerr) assert.Equal(t, su1.NewRoot, &gotNewRoot) - require.NoError(t, state.Update(2, su2, nil, false, true)) + require.NoError(t, state.Update(2, su2, nil, false)) gotNewRoot, err = state.Root() require.NoError(t, err) assert.Equal(t, su2.NewRoot, &gotNewRoot) @@ -91,11 +91,11 @@ func TestUpdate(t *testing.T) { t.Run("post v0.11.0 declared classes affect root", func(t *testing.T) { t.Run("without class definition", func(t *testing.T) { - require.Error(t, state.Update(3, su3, nil, false, true)) + require.Error(t, state.Update(3, su3, nil, false)) }) require.NoError(t, state.Update(3, su3, map[felt.Felt]core.Class{ *felt.NewUnsafeFromString[felt.Felt]("0xDEADBEEF"): &core.Cairo1Class{}, - }, false, true)) + }, false)) assert.NotEqual(t, su3.NewRoot, su3.OldRoot) }) @@ -114,7 +114,7 @@ func TestUpdate(t *testing.T) { } t.Run("update systemContracts storage", func(t *testing.T) { - require.NoError(t, state.Update(4, su4, nil, false, true)) + require.NoError(t, state.Update(4, su4, nil, false)) gotValue, err := state.ContractStorage(scAddr, scKey) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestUpdate(t *testing.T) { StorageDiffs: map[felt.Felt]map[felt.Felt]*felt.Felt{*scAddr2: {*scKey: scValue}}, }, } - assert.ErrorIs(t, state.Update(5, su5, nil, false, true), core.ErrContractNotDeployed) + assert.ErrorIs(t, state.Update(5, su5, nil, false), core.ErrContractNotDeployed) }) } @@ -160,8 +160,8 @@ func TestContractClassHash(t *testing.T) { su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false, true)) - require.NoError(t, state.Update(1, su1, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(1, su1, nil, false)) allDeployedContracts := make(map[felt.Felt]*felt.Felt) @@ -187,7 +187,7 @@ func TestContractClassHash(t *testing.T) { }, } - require.NoError(t, state.Update(2, replaceUpdate, nil, false, true)) + require.NoError(t, state.Update(2, replaceUpdate, nil, false)) gotClassHash, err := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, err) @@ -214,7 +214,7 @@ func TestNonce(t *testing.T) { }, } - require.NoError(t, state.Update(0, su, nil, false, true)) + require.NoError(t, state.Update(0, su, nil, false)) t.Run("newly deployed contract has zero nonce", func(t *testing.T) { nonce, err := state.ContractNonce(addr) @@ -232,7 +232,7 @@ func TestNonce(t *testing.T) { }, } - require.NoError(t, state.Update(1, su, nil, false, true)) + require.NoError(t, state.Update(1, su, nil, false)) gotNonce, err := state.ContractNonce(addr) require.NoError(t, err) @@ -249,7 +249,7 @@ func TestStateHistory(t *testing.T) { state := core.NewState(txn) su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) contractAddr := felt.NewUnsafeFromString[felt.Felt]("0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") changedLoc := felt.NewUnsafeFromString[felt.Felt]("0x5") @@ -281,7 +281,7 @@ func TestStateHistory(t *testing.T) { }, }, } - require.NoError(t, state.Update(1, su, nil, false, true)) + require.NoError(t, state.Update(1, su, nil, false)) t.Run("should give old value for a location that changed after the given height", func(t *testing.T) { oldValue, err := state.ContractStorageAt(contractAddr, changedLoc, 0) @@ -305,8 +305,8 @@ func TestContractIsDeployedAt(t *testing.T) { su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false, true)) - require.NoError(t, state.Update(1, su1, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) + require.NoError(t, state.Update(1, su1, nil, false)) t.Run("deployed on genesis", func(t *testing.T) { deployedOn0 := felt.NewUnsafeFromString[felt.Felt]("0x20cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") @@ -357,7 +357,7 @@ func TestClass(t *testing.T) { require.NoError(t, state.Update(0, su0, map[felt.Felt]core.Class{ *cairo0Hash: cairo0Class, *cairo1Hash: cairo1Class, - }, false, true)) + }, false)) gotCairo1Class, err := state.Class(cairo1Hash) require.NoError(t, err) @@ -379,10 +379,10 @@ func TestRevert(t *testing.T) { state := core.NewState(txn) su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - require.NoError(t, state.Update(1, su1, nil, false, true)) + require.NoError(t, state.Update(1, su1, nil, false)) t.Run("revert a replaced class", func(t *testing.T) { replaceStateUpdate := &core.StateUpdate{ @@ -395,7 +395,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, replaceStateUpdate, nil, false, true)) + require.NoError(t, state.Update(2, replaceStateUpdate, nil, false)) require.NoError(t, state.Revert(2, replaceStateUpdate)) classHash, sErr := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -413,7 +413,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, nonceStateUpdate, nil, false, true)) + require.NoError(t, state.Update(2, nonceStateUpdate, nil, false)) require.NoError(t, state.Revert(2, nonceStateUpdate)) nonce, sErr := state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -465,7 +465,7 @@ func TestRevert(t *testing.T) { }, } - require.NoError(t, state.Update(2, declaredClassesStateUpdate, classesM, false, true)) + require.NoError(t, state.Update(2, declaredClassesStateUpdate, classesM, false)) require.NoError(t, state.Revert(2, declaredClassesStateUpdate)) var decClass *core.DeclaredClass @@ -481,7 +481,7 @@ func TestRevert(t *testing.T) { su2, err := gw.StateUpdate(t.Context(), 2) require.NoError(t, err) t.Run("should be able to apply new update after a Revert", func(t *testing.T) { - require.NoError(t, state.Update(2, su2, nil, false, true)) + require.NoError(t, state.Update(2, su2, nil, false)) }) t.Run("should be able to revert all the state", func(t *testing.T) { @@ -528,7 +528,7 @@ func TestRevertGenesisStateDiff(t *testing.T) { }, }, } - require.NoError(t, state.Update(0, su, nil, false, true)) + require.NoError(t, state.Update(0, su, nil, false)) require.NoError(t, state.Revert(0, su)) } @@ -544,7 +544,7 @@ func TestRevertSystemContracts(t *testing.T) { su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, nil, false, true)) + require.NoError(t, state.Update(0, su0, nil, false)) su1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) @@ -560,7 +560,7 @@ func TestRevertSystemContracts(t *testing.T) { su1.StateDiff.StorageDiffs[*scAddr] = map[felt.Felt]*felt.Felt{*scKey: scValue} - require.NoError(t, state.Update(1, su1, nil, false, true)) + require.NoError(t, state.Update(1, su1, nil, false)) require.NoError(t, state.Revert(1, su1)) @@ -593,7 +593,7 @@ func TestRevertDeclaredClasses(t *testing.T) { *sierraHash: &core.Cairo1Class{}, } - require.NoError(t, state.Update(0, declareDiff, newClasses, false, true)) + require.NoError(t, state.Update(0, declareDiff, newClasses, false)) declaredClass, err := state.Class(classHash) require.NoError(t, err) assert.Equal(t, uint64(0), declaredClass.At) @@ -602,7 +602,7 @@ func TestRevertDeclaredClasses(t *testing.T) { assert.Equal(t, uint64(0), sierraClass.At) declareDiff.OldRoot = declareDiff.NewRoot - require.NoError(t, state.Update(1, declareDiff, newClasses, false, true)) + require.NoError(t, state.Update(1, declareDiff, newClasses, false)) t.Run("re-declaring a class shouldnt change it's DeclaredAt attribute", func(t *testing.T) { declaredClass, err = state.Class(classHash) diff --git a/genesis/genesis.go b/genesis/genesis.go index af6b6d1ea5..3487074d6a 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -119,7 +119,7 @@ func GenesisStateDiff( if err != nil { return core.StateDiff{}, nil, err } - state, err := stateFactory.NewState(&felt.Zero, memDB.NewIndexedBatch()) + state, err := stateFactory.NewState(&felt.Zero, memDB.NewIndexedBatch(), nil) if err != nil { return core.StateDiff{}, nil, err } diff --git a/rpc/v8/transaction.go b/rpc/v8/transaction.go index b16f92dfbd..fac6118225 100644 --- a/rpc/v8/transaction.go +++ b/rpc/v8/transaction.go @@ -610,7 +610,6 @@ func (h *Handler) AddTransaction(ctx context.Context, tx *BroadcastedTransaction } else { res, err = h.pushToFeederGateway(ctx, tx) } - if err != nil { return AddTxResponse{}, err } diff --git a/sync/sync.go b/sync/sync.go index 37f4c6f4e0..a8ab1ad25f 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -672,7 +672,7 @@ func (s *Synchronizer) PendingState() (commonstate.StateReader, func() error, er } pendingStateUpdate := pending.GetStateUpdate() - state, err := s.blockchain.StateFactory.NewState(pendingStateUpdate.OldRoot, txn) + state, err := s.blockchain.StateFactory.NewState(pendingStateUpdate.OldRoot, txn, nil) if err != nil { return nil, nil, err } @@ -729,7 +729,7 @@ func (s *Synchronizer) PendingStateBeforeIndex( } pendingStateUpdate := pending.GetStateUpdate() - state, err := s.blockchain.StateFactory.NewState(pendingStateUpdate.OldRoot, txn) + state, err := s.blockchain.StateFactory.NewState(pendingStateUpdate.OldRoot, txn, nil) if err != nil { return nil, nil, err } diff --git a/vm/vm_test.go b/vm/vm_test.go index be96b8190e..9bf64856b8 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -36,7 +36,8 @@ func TestCallDeprecatedCairo(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) + batch := testDB.NewBatch() + testState, err := stateFactory.NewState(&felt.Zero, txn, batch) require.NoError(t, err) newRoot := felt.NewUnsafeFromString[felt.Felt]( "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8", @@ -51,7 +52,8 @@ func TestCallDeprecatedCairo(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false, true)) + }, false)) + require.NoError(t, batch.Write()) entryPoint := felt.NewUnsafeFromString[felt.Felt]("0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") @@ -80,7 +82,7 @@ func TestCallDeprecatedCairo(t *testing.T) { // if new state, we need to create a new state with the new root if statetestutils.UseNewState() { - testState, err = stateFactory.NewState(newRoot, txn) + testState, err = stateFactory.NewState(newRoot, txn, nil) require.NoError(t, err) } @@ -94,7 +96,7 @@ func TestCallDeprecatedCairo(t *testing.T) { }, }, }, - }, nil, false, true)) + }, nil, false)) ret, err = New(&chainInfo, false, nil).Call( &CallInfo{ @@ -132,7 +134,8 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) + batch := testDB.NewBatch() + testState, err := stateFactory.NewState(&felt.Zero, txn, batch) require.NoError(t, err) require.NoError(t, testState.Update(0, &core.StateUpdate{ @@ -145,7 +148,8 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false, true)) + }, false)) + require.NoError(t, batch.Write()) entryPoint := felt.NewUnsafeFromString[felt.Felt]("0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") feeTokens := utils.DefaultFeeTokenAddresses @@ -186,7 +190,8 @@ func TestCallCairo(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) + batch := testDB.NewBatch() + testState, err := stateFactory.NewState(&felt.Zero, txn, batch) require.NoError(t, err) newRoot := felt.NewUnsafeFromString[felt.Felt]( "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84", @@ -201,7 +206,8 @@ func TestCallCairo(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false, true)) + }, false)) + require.NoError(t, batch.Write()) logLevel := utils.NewLogLevel(utils.ERROR) log, err := utils.NewZapLogger(logLevel, false) @@ -235,8 +241,9 @@ func TestCallCairo(t *testing.T) { assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) // if new state, we need to create a new state with the new root + batch = testDB.NewBatch() if statetestutils.UseNewState() { - testState, err = stateFactory.NewState(newRoot, txn) + testState, err = stateFactory.NewState(newRoot, txn, batch) require.NoError(t, err) } @@ -250,7 +257,11 @@ func TestCallCairo(t *testing.T) { }, }, }, - }, nil, false, true)) + }, nil, false)) + + if statetestutils.UseNewState() { + require.NoError(t, batch.Write()) + } ret, err = New(&chainInfo, false, log).Call( &CallInfo{ @@ -287,7 +298,8 @@ func TestCallInfoErrorHandling(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - testState, err := stateFactory.NewState(&felt.Zero, txn) + batch := testDB.NewBatch() + testState, err := stateFactory.NewState(&felt.Zero, txn, batch) require.NoError(t, err) require.NoError(t, testState.Update(0, &core.StateUpdate{ OldRoot: &felt.Zero, @@ -299,7 +311,8 @@ func TestCallInfoErrorHandling(t *testing.T) { }, }, map[felt.Felt]core.Class{ *classHash: simpleClass, - }, false, true)) + }, false)) + require.NoError(t, batch.Write()) logLevel := utils.NewLogLevel(utils.ERROR) log, err := utils.NewZapLogger(logLevel, false) @@ -363,7 +376,7 @@ func TestExecute(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - state, err := stateFactory.NewState(&felt.Zero, txn) + state, err := stateFactory.NewState(&felt.Zero, txn, nil) require.NoError(t, err) t.Run("empty transaction list", func(t *testing.T) { From 1c0622c728a044e59d15f99e8a7fbe51b440c390 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 21 Oct 2025 23:44:02 +0200 Subject: [PATCH 40/47] lint --- core/state/state_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/core/state/state_test.go b/core/state/state_test.go index 01ec0fc197..24f5fa6fe2 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -583,7 +583,6 @@ func TestRevert(t *testing.T) { require.NoError(t, state.Update(block0, declareDiff, newClasses, false)) require.NoError(t, batch.Write()) - batch = stateDB.disk.NewBatch() declaredClass, err := state.Class(classHash) require.NoError(t, err) assert.Equal(t, uint64(0), declaredClass.At) From c62e9a571d6accd7217a6d5182242bc9b70add56 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Fri, 24 Oct 2025 11:05:38 +0200 Subject: [PATCH 41/47] address comments --- blockchain/blockchain_test.go | 53 ------------------------ consensus/p2p/validator/fixtures_test.go | 3 +- 2 files changed, 2 insertions(+), 54 deletions(-) diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index dca1d77b46..1d21d07cd6 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -1,7 +1,6 @@ package blockchain_test import ( - "context" "fmt" "testing" @@ -724,55 +723,3 @@ func TestSubscribeL1Head(t *testing.T) { require.True(t, ok) assert.Equal(t, l1Head, got) } - -func fetchStateUpdatesAndBlocks(samples int) ([]*core.StateUpdate, []*core.Block, error) { - client := feeder.NewClient(utils.Mainnet.FeederURL) - gw := adaptfeeder.New(client) - - suList := make([]*core.StateUpdate, samples) - blocks := make([]*core.Block, samples) - for i := range samples { - fmt.Println("fetching", i) - su, err := gw.StateUpdate(context.Background(), uint64(i)) - if err != nil { - return nil, nil, err - } - suList[i] = su - block, err := gw.BlockByNumber(context.Background(), uint64(i)) - if err != nil { - return nil, nil, err - } - blocks[i] = block - } - return suList, blocks, nil -} - -func BenchmarkBlockchainStore(b *testing.B) { - samples := 100 - stateUpdates, blocks, err := fetchStateUpdatesAndBlocks(samples) - require.NoError(b, err) - - b.Run("new", func(b *testing.B) { - for b.Loop() { - b.StopTimer() - chain := blockchain.New(memory.New(), &utils.Mainnet, true) - b.StartTimer() - - for j := range samples { - require.NoError(b, chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil)) - } - } - }) - - b.Run("old", func(b *testing.B) { - for b.Loop() { - b.StopTimer() - chain := blockchain.New(memory.New(), &utils.Mainnet, false) - b.StartTimer() - - for j := range samples { - require.NoError(b, chain.Store(blocks[j], &emptyCommitments, stateUpdates[j], nil)) - } - } - }) -} diff --git a/consensus/p2p/validator/fixtures_test.go b/consensus/p2p/validator/fixtures_test.go index efe7ffff68..93f4a5c0ff 100644 --- a/consensus/p2p/validator/fixtures_test.go +++ b/consensus/p2p/validator/fixtures_test.go @@ -110,7 +110,8 @@ func BuildTestFixture( executor.RegisterBuildResult(&buildResult) builder := builder.New( - blockchain.New(database, testCase.Network, statetestutils.UseNewState()), executor, + blockchain.New(database, testCase.Network, statetestutils.UseNewState()), + executor, ) return TestFixture{ From 16b7da1cd5011f13d6eb9979812770792aebc6f5 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 13 Nov 2025 10:34:54 +0100 Subject: [PATCH 42/47] new pebble writes handling --- blockchain/blockchain.go | 7 +++-- core/trie2/triedb/rawdb/database.go | 27 ++++++++++++++----- core/trie2/trienode/nodeset.go | 41 ++++++++++++++++++++++++----- db/pebble/db.go | 16 +++++++---- 4 files changed, 72 insertions(+), 19 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 55831b98da..a626063a9e 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -349,11 +349,14 @@ func (b *Blockchain) store( return err } - if err := b.runningFilter.Insert(block.EventsBloom, block.Number); err != nil { + if err = batch.Write(); err != nil { return err } - return batch.Write() + return b.runningFilter.Insert( + block.EventsBloom, + block.Number, + ) } // VerifyBlock assumes the block has already been sanity-checked. diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index 9141572d0c..f3b1a84220 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -1,6 +1,7 @@ package rawdb import ( + "sort" "sync" "github.com/NethermindEth/juno/core/felt" @@ -90,30 +91,44 @@ func (d *Database) Update( defer d.lock.Unlock() var classNodes classNodesMap + var classOrderedPaths []trieutils.Path var contractNodes contractNodesMap + var contractOrderedPaths []trieutils.Path var contractStorageNodes contractStorageNodesMap + var contractStorageOrderedPaths map[felt.Felt][]trieutils.Path if mergedClassNodes != nil { - classNodes, _ = mergedClassNodes.Flatten() + classNodes, classOrderedPaths, _, _ = mergedClassNodes.FlattenWithOrder() } if mergedContractNodes != nil { - contractNodes, contractStorageNodes = mergedContractNodes.Flatten() + contractNodes, contractOrderedPaths, contractStorageNodes, contractStorageOrderedPaths = mergedContractNodes.FlattenWithOrder() } - for path, n := range classNodes { + for _, path := range classOrderedPaths { + n := classNodes[path] if err := d.updateNode(batch, db.ClassTrie, &felt.Zero, &path, n, true); err != nil { return err } } - for path, n := range contractNodes { + for _, path := range contractOrderedPaths { + n := contractNodes[path] if err := d.updateNode(batch, db.ContractTrieContract, &felt.Zero, &path, n, false); err != nil { return err } } - for owner, nodes := range contractStorageNodes { - for path, n := range nodes { + owners := make([]felt.Felt, 0, len(contractStorageNodes)) + for owner := range contractStorageNodes { + owners = append(owners, owner) + } + sort.Slice(owners, func(i, j int) bool { + return owners[i].Cmp(&owners[j]) < 0 + }) + for _, owner := range owners { + orderedPaths := contractStorageOrderedPaths[owner] + for _, path := range orderedPaths { + n := contractStorageNodes[owner][path] if err := d.updateNode(batch, db.ContractTrieStorage, &owner, &path, n, false); err != nil { return err } diff --git a/core/trie2/trienode/nodeset.go b/core/trie2/trienode/nodeset.go index 7640aeaf14..9a9405dccd 100644 --- a/core/trie2/trienode/nodeset.go +++ b/core/trie2/trienode/nodeset.go @@ -12,17 +12,26 @@ import ( // Contains a set of nodes, which are indexed by their path in the trie. // It is not thread safe. type NodeSet struct { - Owner felt.Felt // The owner (i.e. contract address) - Nodes map[trieutils.Path]TrieNode - updates int // the count of updated and inserted nodes - deletes int // the count of deleted nodes + Owner felt.Felt // The owner (i.e. contract address) + Nodes map[trieutils.Path]TrieNode + orderedPaths []trieutils.Path // Paths in the order they were added + updates int // the count of updated and inserted nodes + deletes int // the count of deleted nodes } func NewNodeSet(owner felt.Felt) NodeSet { - return NodeSet{Owner: owner, Nodes: make(map[trieutils.Path]TrieNode)} + return NodeSet{ + Owner: owner, + Nodes: make(map[trieutils.Path]TrieNode), + orderedPaths: make([]trieutils.Path, 0), + } } func (ns *NodeSet) Add(key *trieutils.Path, node TrieNode) { + _, exists := ns.Nodes[*key] + if !exists { + ns.orderedPaths = append(ns.orderedPaths, *key) + } if _, ok := node.(*DeletedNode); ok { ns.deletes += 1 } else { @@ -63,6 +72,12 @@ func (ns *NodeSet) MergeSet(other *NodeSet) error { if ns.Owner != other.Owner { return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, other.Owner) } + // Add paths from other in order, but only if they don't already exist + for _, path := range other.orderedPaths { + if _, exists := ns.Nodes[path]; !exists { + ns.orderedPaths = append(ns.orderedPaths, path) + } + } maps.Copy(ns.Nodes, other.Nodes) ns.updates += other.updates ns.deletes += other.deletes @@ -83,6 +98,9 @@ func (ns *NodeSet) Merge(owner felt.Felt, other map[trieutils.Path]TrieNode) err } else { ns.updates -= 1 } + } else { + // Only add to orderedPaths if it's a new path + ns.orderedPaths = append(ns.orderedPaths, path) } // overwrite the existing node (if it exists) if _, ok := node.(*DeletedNode); ok { @@ -103,7 +121,7 @@ type MergeNodeSet struct { func NewMergeNodeSet(nodes *NodeSet) *MergeNodeSet { ns := &MergeNodeSet{ - OwnerSet: &NodeSet{Nodes: make(map[trieutils.Path]TrieNode)}, + OwnerSet: &NodeSet{Nodes: make(map[trieutils.Path]TrieNode), orderedPaths: make([]trieutils.Path, 0)}, ChildSets: make(map[felt.Felt]*NodeSet), } if nodes == nil { @@ -138,3 +156,14 @@ func (m *MergeNodeSet) Flatten() (map[trieutils.Path]TrieNode, map[felt.Felt]map } return m.OwnerSet.Nodes, childFlat } + +// FlattenWithOrder returns the nodes along with their ordered paths +func (m *MergeNodeSet) FlattenWithOrder() (map[trieutils.Path]TrieNode, []trieutils.Path, map[felt.Felt]map[trieutils.Path]TrieNode, map[felt.Felt][]trieutils.Path) { + childFlat := make(map[felt.Felt]map[trieutils.Path]TrieNode, len(m.ChildSets)) + childOrderedPaths := make(map[felt.Felt][]trieutils.Path, len(m.ChildSets)) + for owner, set := range m.ChildSets { + childFlat[owner] = set.Nodes + childOrderedPaths[owner] = set.orderedPaths + } + return m.OwnerSet.Nodes, m.OwnerSet.orderedPaths, childFlat, childOrderedPaths +} diff --git a/db/pebble/db.go b/db/pebble/db.go index 3981c647e9..b2dbc7331b 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -47,11 +47,17 @@ func NewWithOptions(path string, cacheSizeMB uint, maxOpenFiles int, colouredLog return nil, fmt.Errorf("create DB logger: %w", err) } - return newPebble(path, &pebble.Options{ - Logger: dbLog, - Cache: pebble.NewCache(int64(cacheSizeMB * utils.Megabyte)), - MaxOpenFiles: maxOpenFiles, - }) + opts := &pebble.Options{ + Logger: dbLog, + Cache: pebble.NewCache(int64(cacheSizeMB * utils.Megabyte)), + MaxOpenFiles: maxOpenFiles, + L0CompactionFileThreshold: 8, + L0StopWritesThreshold: 24, + MemTableSize: 8 * utils.Megabyte, + MaxConcurrentCompactions: func() int { return 2 }, + } + + return newPebble(path, opts) } func newPebble(path string, options *pebble.Options) (*DB, error) { From 268bf197b92a95ab3ae536cc1af7d9891fbc402d Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 17 Nov 2025 16:14:52 +0100 Subject: [PATCH 43/47] updates --- adapters/p2p2core/state.go | 3 +- blockchain/blockchain.go | 27 ++++++------ builder/builder.go | 4 +- cmd/juno/dbcmd_test.go | 2 +- consensus/consensus_test.go | 2 +- .../p2p/validator/empty_fixtures_test.go | 2 +- consensus/p2p/validator/fixtures_test.go | 2 +- .../validator/proposal_stream_demux_test.go | 2 +- consensus/p2p/validator/transition_test.go | 2 +- consensus/proposer/proposer_test.go | 2 +- core/common_state.go | 35 +++++++++++++++ .../commontrie/trie.go => common_trie.go} | 4 +- core/mocks/mock_commonstate_reader.go | 13 +++--- core/pending.go | 14 +++--- core/pending_state.go | 14 +++--- core/running_event_filter_test.go | 2 +- core/state.go | 13 +++--- core/state/history.go | 7 ++- core/state/history_test.go | 6 --- core/state/state.go | 15 +++---- .../state_factory.go} | 43 +++---------------- core/state_snapshot.go | 7 ++- core/state_snapshot_test.go | 2 +- core/state_test.go | 8 ++-- core/trie2/trie.go | 1 + genesis/genesis.go | 5 ++- l1/l1_pkg_test.go | 2 +- l1/l1_test.go | 2 +- mempool/mempool_test.go | 2 +- migration/migration_pkg_test.go | 2 +- mocks/mock_blockchain.go | 13 +++--- mocks/mock_commonstate_reader.go | 13 +++--- mocks/mock_state.go | 13 +++--- mocks/mock_vm.go | 5 +-- node/node_test.go | 2 +- node/throttled_vm.go | 6 +-- plugin/plugin_test.go | 2 +- rpc/v10/events_test.go | 3 +- rpc/v10/helpers.go | 4 +- rpc/v10/pending_data_wrapper.go | 2 +- rpc/v10/state_update_test.go | 3 +- rpc/v10/subscriptions_test.go | 9 ++-- rpc/v10/trace.go | 8 ++-- rpc/v10/trace_test.go | 3 +- rpc/v6/block_test.go | 2 +- rpc/v6/estimate_fee_test.go | 4 +- rpc/v6/events_test.go | 2 +- rpc/v6/helpers.go | 6 +-- rpc/v6/pending_data_wrapper.go | 4 +- rpc/v6/state_update_test.go | 2 +- rpc/v6/trace.go | 4 +- rpc/v6/trace_test.go | 2 +- rpc/v7/block_test.go | 2 +- rpc/v7/helpers.go | 6 +-- rpc/v7/pending_data_wrapper.go | 4 +- rpc/v7/trace.go | 4 +- rpc/v7/trace_test.go | 2 +- rpc/v8/block_test.go | 2 +- rpc/v8/helpers.go | 6 +-- rpc/v8/pending_data_wrapper.go | 4 +- rpc/v8/storage.go | 37 ++++++++-------- rpc/v8/storage_test.go | 21 +++++---- rpc/v8/subscriptions_test.go | 8 ++-- rpc/v8/trace.go | 4 +- rpc/v8/trace_test.go | 2 +- rpc/v9/block_test.go | 2 +- rpc/v9/compiled_casm_test.go | 4 +- rpc/v9/events_test.go | 2 +- rpc/v9/helpers.go | 6 +-- rpc/v9/pending_data_wrapper.go | 4 +- rpc/v9/state_update_test.go | 2 +- rpc/v9/storage.go | 35 ++++++++------- rpc/v9/storage_test.go | 22 +++++----- rpc/v9/subscriptions_test.go | 12 +++--- rpc/v9/trace.go | 8 ++-- rpc/v9/trace_test.go | 2 +- sequencer/sequencer.go | 4 +- sequencer/sequencer_test.go | 2 +- sync/pending_polling_test.go | 4 +- sync/pendingdata/helpers.go | 10 ++--- sync/reorg_test.go | 2 +- sync/sync.go | 4 +- sync/sync_test.go | 2 +- vm/vm.go | 12 +++--- vm/vm_test.go | 13 +++--- 85 files changed, 306 insertions(+), 305 deletions(-) create mode 100644 core/common_state.go rename core/{state/commontrie/trie.go => common_trie.go} (84%) rename core/state/{commonstate/state.go => statefactory/state_factory.go} (51%) diff --git a/adapters/p2p2core/state.go b/adapters/p2p2core/state.go index 5d63fd3fef..d4c0405823 100644 --- a/adapters/p2p2core/state.go +++ b/adapters/p2p2core/state.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" "github.com/starknet-io/starknet-p2pspecs/p2p/proto/class" @@ -15,7 +14,7 @@ import ( ) func AdaptStateDiff( - reader commonstate.StateReader, + reader core.CommonStateReader, contractDiffs []*state.ContractDiff, classes []*class.Class, ) (*core.StateDiff, error) { diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 25dc3d0bf5..f04c070b52 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -7,7 +7,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/state/statefactory" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" @@ -53,9 +54,9 @@ type Reader interface { } type StateProvider interface { - HeadState() (commonstate.StateReader, StateCloser, error) - StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, StateCloser, error) - StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, StateCloser, error) + HeadState() (core.CommonStateReader, StateCloser, error) + StateAtBlockHash(blockHash *felt.Felt) (core.CommonStateReader, StateCloser, error) + StateAtBlockNumber(blockNumber uint64) (core.CommonStateReader, StateCloser, error) } var ErrParentDoesNotMatchHead = errors.New("block's parent hash does not match head block hash") @@ -72,7 +73,7 @@ type Blockchain struct { l1HeadFeed *feed.Feed[*core.L1Head] cachedFilters *AggregatedBloomFilterCache runningFilter *core.RunningEventFilter - StateFactory *commonstate.StateFactory + StateFactory *statefactory.StateFactory } func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) *Blockchain { @@ -82,7 +83,7 @@ func New(database db.KeyValueStore, network *utils.Network, stateVersion bool) * } stateDB := state.NewStateDB(database, trieDB) - stateFactory, err := commonstate.NewStateFactory(stateVersion, trieDB, stateDB) + stateFactory, err := statefactory.NewStateFactory(stateVersion, trieDB, stateDB) if err != nil { panic(err) } @@ -422,7 +423,7 @@ type StateCloser = func() error var noopStateCloser = func() error { return nil } // TODO: remove this once we refactor the state // HeadState returns a StateReader that provides a stable view to the latest state -func (b *Blockchain) HeadState() (commonstate.StateReader, StateCloser, error) { +func (b *Blockchain) HeadState() (core.CommonStateReader, StateCloser, error) { b.listener.OnRead("HeadState") txn := b.database.NewIndexedBatch() @@ -444,7 +445,7 @@ func (b *Blockchain) HeadState() (commonstate.StateReader, StateCloser, error) { // StateAtBlockNumber returns a StateReader that provides a stable view to the state at the given block number func (b *Blockchain) StateAtBlockNumber( blockNumber uint64, -) (commonstate.StateReader, StateCloser, error) { +) (core.CommonStateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") txn := b.database.NewIndexedBatch() @@ -456,7 +457,7 @@ func (b *Blockchain) StateAtBlockNumber( if !b.StateFactory.UseNewState { deprecatedState := core.NewState(txn) snapshot := core.NewStateSnapshot(deprecatedState, blockNumber) - return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil + return snapshot, noopStateCloser, nil } height, err := core.GetChainHeight(txn) @@ -473,13 +474,13 @@ func (b *Blockchain) StateAtBlockNumber( if err != nil { return nil, nil, err } - return commonstate.NewStateReaderAdapter(&history), noopStateCloser, nil + return &history, noopStateCloser, nil } // StateAtBlockHash returns a StateReader that provides a stable view to the state at the given block hash func (b *Blockchain) StateAtBlockHash( blockHash *felt.Felt, -) (commonstate.StateReader, StateCloser, error) { +) (core.CommonStateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockHash") if blockHash.IsZero() { emptyState, err := b.StateFactory.EmptyState() @@ -494,7 +495,7 @@ func (b *Blockchain) StateAtBlockHash( if !b.StateFactory.UseNewState { deprecatedState := core.NewState(txn) snapshot := core.NewStateSnapshot(deprecatedState, header.Number) - return commonstate.NewDeprecatedStateReaderAdapter(snapshot), noopStateCloser, nil + return snapshot, noopStateCloser, nil } height, err := core.GetChainHeight(txn) @@ -511,7 +512,7 @@ func (b *Blockchain) StateAtBlockHash( if err != nil { return nil, nil, err } - return commonstate.NewStateReaderAdapter(&history), noopStateCloser, nil + return &history, noopStateCloser, nil } // EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain diff --git a/builder/builder.go b/builder/builder.go index b7b4f1c2ea..7f7f799fcc 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/utils" "github.com/consensys/gnark-crypto/ecc/stark-curve/ecdsa" @@ -131,7 +131,7 @@ func (b *Builder) getRevealedBlockHash(blockHeight uint64) (*felt.Felt, error) { func (b *Builder) PendingState( buildState *BuildState, -) (commonstate.StateReader, func() error, error) { +) (core.CommonStateReader, func() error, error) { if buildState.Preconfirmed == nil { return nil, nil, core.ErrPendingDataNotFound } diff --git a/cmd/juno/dbcmd_test.go b/cmd/juno/dbcmd_test.go index fbef9038b4..18887ef1bf 100644 --- a/cmd/juno/dbcmd_test.go +++ b/cmd/juno/dbcmd_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" juno "github.com/NethermindEth/juno/cmd/juno" "github.com/NethermindEth/juno/core" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/pebble" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index f1f3e334a1..2550b29adb 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/p2p/pubsub/testutils" diff --git a/consensus/p2p/validator/empty_fixtures_test.go b/consensus/p2p/validator/empty_fixtures_test.go index 871c5f107d..682931de56 100644 --- a/consensus/p2p/validator/empty_fixtures_test.go +++ b/consensus/p2p/validator/empty_fixtures_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" "github.com/starknet-io/starknet-p2pspecs/p2p/proto/common" diff --git a/consensus/p2p/validator/fixtures_test.go b/consensus/p2p/validator/fixtures_test.go index 5ff92282ef..100dfbb27c 100644 --- a/consensus/p2p/validator/fixtures_test.go +++ b/consensus/p2p/validator/fixtures_test.go @@ -14,7 +14,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/starknet" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" diff --git a/consensus/p2p/validator/proposal_stream_demux_test.go b/consensus/p2p/validator/proposal_stream_demux_test.go index 1294368942..19ed3d9838 100644 --- a/consensus/p2p/validator/proposal_stream_demux_test.go +++ b/consensus/p2p/validator/proposal_stream_demux_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/p2p/pubsub/testutils" diff --git a/consensus/p2p/validator/transition_test.go b/consensus/p2p/validator/transition_test.go index 7e68fa9ac8..d745175ddf 100644 --- a/consensus/p2p/validator/transition_test.go +++ b/consensus/p2p/validator/transition_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" diff --git a/consensus/proposer/proposer_test.go b/consensus/proposer/proposer_test.go index da5d397cd4..d9b88fafd1 100644 --- a/consensus/proposer/proposer_test.go +++ b/consensus/proposer/proposer_test.go @@ -15,7 +15,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" diff --git a/core/common_state.go b/core/common_state.go new file mode 100644 index 0000000000..e30351b725 --- /dev/null +++ b/core/common_state.go @@ -0,0 +1,35 @@ +package core + +import ( + "github.com/NethermindEth/juno/core/felt" +) + +//go:generate mockgen -destination=../../mocks/mock_commonstate_reader.go -package=mocks github.com/NethermindEth/juno/core CommonStateReader +type CommonState interface { + CommonStateReader + + ContractStorageAt(addr, key *felt.Felt, blockNumber uint64) (felt.Felt, error) + ContractNonceAt(addr *felt.Felt, blockNumber uint64) (felt.Felt, error) + ContractClassHashAt(addr *felt.Felt, blockNumber uint64) (felt.Felt, error) + ContractDeployedAt(addr *felt.Felt, blockNumber uint64) (bool, error) + + Update(blockNum uint64, + update *StateUpdate, + declaredClasses map[felt.Felt]ClassDefinition, + skipVerifyNewRoot bool, + flushChanges bool, + ) error + Revert(blockNum uint64, update *StateUpdate) error + Commitment() (felt.Felt, error) +} + +type CommonStateReader interface { + ContractClassHash(addr *felt.Felt) (felt.Felt, error) + ContractNonce(addr *felt.Felt) (felt.Felt, error) + ContractStorage(addr, key *felt.Felt) (felt.Felt, error) + Class(classHash *felt.Felt) (*DeclaredClassDefinition, error) + + ClassTrie() (CommonTrie, error) + ContractTrie() (CommonTrie, error) + ContractStorageTrie(addr *felt.Felt) (CommonTrie, error) +} diff --git a/core/state/commontrie/trie.go b/core/common_trie.go similarity index 84% rename from core/state/commontrie/trie.go rename to core/common_trie.go index f62db6265e..64f80d6a21 100644 --- a/core/state/commontrie/trie.go +++ b/core/common_trie.go @@ -1,11 +1,11 @@ -package commontrie +package core import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" ) -type Trie interface { +type CommonTrie interface { Update(key, value *felt.Felt) error Get(key *felt.Felt) (felt.Felt, error) Hash() (felt.Felt, error) diff --git a/core/mocks/mock_commonstate_reader.go b/core/mocks/mock_commonstate_reader.go index 6de99c57f3..f28dd351e5 100644 --- a/core/mocks/mock_commonstate_reader.go +++ b/core/mocks/mock_commonstate_reader.go @@ -14,7 +14,6 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - commontrie "github.com/NethermindEth/juno/core/state/commontrie" gomock "go.uber.org/mock/gomock" ) @@ -58,10 +57,10 @@ func (mr *MockStateReaderMockRecorder) Class(classHash any) *gomock.Call { } // ClassTrie mocks base method. -func (m *MockStateReader) ClassTrie() (commontrie.Trie, error) { +func (m *MockStateReader) ClassTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClassTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -118,10 +117,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorage(addr, key any) *gomock.Ca } // ContractStorageTrie mocks base method. -func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractStorageTrie", addr) - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -133,10 +132,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorageTrie(addr any) *gomock.Cal } // ContractTrie mocks base method. -func (m *MockStateReader) ContractTrie() (commontrie.Trie, error) { +func (m *MockStateReader) ContractTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/core/pending.go b/core/pending.go index b950c6a58a..86a92a140d 100644 --- a/core/pending.go +++ b/core/pending.go @@ -47,9 +47,9 @@ type PendingData interface { ReceiptByHash(hash *felt.Felt) (*TransactionReceipt, *felt.Felt, uint64, error) // PendingStateBeforeIndex returns the state obtained by applying all transaction state diffs // up to given index in the pre-confirmed block. - PendingStateBeforeIndex(baseState StateReader, index uint) (StateReader, error) + PendingStateBeforeIndex(baseState CommonStateReader, index uint) (CommonStateReader, error) // PendingState returns the state resulting from execution of the pending data - PendingState(baseState StateReader) StateReader + PendingState(baseState CommonStateReader) CommonStateReader } type Pending struct { @@ -136,11 +136,11 @@ func (p *Pending) ReceiptByHash( return nil, nil, 0, ErrTransactionReceiptNotFound } -func (p *Pending) PendingStateBeforeIndex(baseState StateReader, index uint) (StateReader, error) { +func (p *Pending) PendingStateBeforeIndex(baseState CommonStateReader, index uint) (CommonStateReader, error) { return nil, ErrPendingStateBeforeIndexNotSupported } -func (p *Pending) PendingState(baseState StateReader) StateReader { +func (p *Pending) PendingState(baseState CommonStateReader) CommonStateReader { return NewPendingState( p.StateUpdate.StateDiff, p.NewClasses, @@ -290,9 +290,9 @@ func (p *PreConfirmed) ReceiptByHash( } func (p *PreConfirmed) PendingStateBeforeIndex( - baseState StateReader, + baseState CommonStateReader, index uint, -) (StateReader, error) { +) (CommonStateReader, error) { if index > uint(len(p.Block.Transactions)) { return nil, ErrTransactionIndexOutOfBounds } @@ -316,7 +316,7 @@ func (p *PreConfirmed) PendingStateBeforeIndex( return NewPendingState(&stateDiff, newClasses, baseState), nil } -func (p *PreConfirmed) PendingState(baseState StateReader) StateReader { +func (p *PreConfirmed) PendingState(baseState CommonStateReader) CommonStateReader { stateDiff := EmptyStateDiff() newClasses := make(map[felt.Felt]ClassDefinition) diff --git a/core/pending_state.go b/core/pending_state.go index d5cc1b03c7..7252f0dd4c 100644 --- a/core/pending_state.go +++ b/core/pending_state.go @@ -5,8 +5,6 @@ import ( "fmt" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/db" ) @@ -15,13 +13,13 @@ var feltOne = new(felt.Felt).SetUint64(1) type PendingState struct { stateDiff *StateDiff newClasses map[felt.Felt]ClassDefinition - head commonstate.StateReader + head CommonStateReader } func NewPendingState( stateDiff *StateDiff, newClasses map[felt.Felt]ClassDefinition, - head commonstate.StateReader, + head CommonStateReader, ) *PendingState { return &PendingState{ stateDiff: stateDiff, @@ -78,15 +76,15 @@ func (p *PendingState) Class(classHash *felt.Felt) (*DeclaredClassDefinition, er return p.head.Class(classHash) } -func (p *PendingState) ClassTrie() (commontrie.Trie, error) { +func (p *PendingState) ClassTrie() (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (p *PendingState) ContractTrie() (commontrie.Trie, error) { +func (p *PendingState) ContractTrie() (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (p *PendingState) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (p *PendingState) ContractStorageTrie(addr *felt.Felt) (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } @@ -97,7 +95,7 @@ type PendingStateWriter struct { func NewPendingStateWriter( stateDiff *StateDiff, newClasses map[felt.Felt]ClassDefinition, - head commonstate.StateReader, + head CommonStateReader, ) PendingStateWriter { return PendingStateWriter{ PendingState: &PendingState{ diff --git a/core/running_event_filter_test.go b/core/running_event_filter_test.go index 66d2956c6e..b5889ee223 100644 --- a/core/running_event_filter_test.go +++ b/core/running_event_filter_test.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/encoder" diff --git a/core/state.go b/core/state.go index 4417d541b8..996aa4940b 100644 --- a/core/state.go +++ b/core/state.go @@ -12,7 +12,6 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" @@ -45,9 +44,9 @@ type StateReader interface { ContractStorage(addr, key *felt.Felt) (felt.Felt, error) Class(classHash *felt.Felt) (*DeclaredClassDefinition, error) - ClassTrie() (commontrie.Trie, error) - ContractTrie() (commontrie.Trie, error) - ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) + ClassTrie() (CommonTrie, error) + ContractTrie() (CommonTrie, error) + ContractStorageTrie(addr *felt.Felt) (CommonTrie, error) } type State struct { @@ -135,18 +134,18 @@ func (s *State) Commitment() (felt.Felt, error) { return root, nil } -func (s *State) ClassTrie() (commontrie.Trie, error) { +func (s *State) ClassTrie() (CommonTrie, error) { // We don't need to call the closer function here because we are only reading the trie tr, _, err := s.classesTrie() return tr, err } -func (s *State) ContractTrie() (commontrie.Trie, error) { +func (s *State) ContractTrie() (CommonTrie, error) { tr, _, err := s.storage() return tr, err } -func (s *State) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (s *State) ContractStorageTrie(addr *felt.Felt) (CommonTrie, error) { return storage(addr, s.txn) } diff --git a/core/state/history.go b/core/state/history.go index 2c8b655983..a128dff1bd 100644 --- a/core/state/history.go +++ b/core/state/history.go @@ -5,7 +5,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/db" ) @@ -99,14 +98,14 @@ func (s *StateHistory) Class(classHash *felt.Felt) (*core.DeclaredClassDefinitio return declaredClass, nil } -func (s *StateHistory) ClassTrie() (commontrie.Trie, error) { +func (s *StateHistory) ClassTrie() (core.CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (s *StateHistory) ContractTrie() (commontrie.Trie, error) { +func (s *StateHistory) ContractTrie() (core.CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (s *StateHistory) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (s *StateHistory) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } diff --git a/core/state/history_test.go b/core/state/history_test.go index 7e363b6fda..7b12aad4ac 100644 --- a/core/state/history_test.go +++ b/core/state/history_test.go @@ -19,12 +19,6 @@ func TestNewStateHistory(t *testing.T) { assert.Equal(t, uint64(0), history.blockNum) assert.NotNil(t, history.state) }) - - t.Run("invalid state root", func(t *testing.T) { - invalidRoot := felt.NewUnsafeFromString[felt.Felt]("0x999") - _, err := NewStateHistory(1, invalidRoot, stateDB) - assert.Error(t, err) - }) } func TestStateHistoryContractOperations(t *testing.T) { diff --git a/core/state/state.go b/core/state/state.go index fd2df6bc4a..7791203880 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -12,7 +12,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/core/trie2/trieutils" @@ -55,9 +54,9 @@ type ClassReader interface { } type TrieProvider interface { - ClassTrie() (commontrie.Trie, error) - ContractTrie() (commontrie.Trie, error) - ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) + ClassTrie() (core.CommonTrie, error) + ContractTrie() (core.CommonTrie, error) + ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) } type State struct { @@ -74,7 +73,7 @@ func New(stateRoot *felt.Felt, db *StateDB) (*State, error) { if err != nil { return nil, err } - + fmt.Println("contractTrie", contractTrie, err) classTrie, err := db.ClassTrie(stateRoot) if err != nil { return nil, err @@ -147,15 +146,15 @@ func (s *State) Class(classHash *felt.Felt) (*core.DeclaredClassDefinition, erro return GetClass(s.db.disk, classHash) } -func (s *State) ClassTrie() (commontrie.Trie, error) { +func (s *State) ClassTrie() (core.CommonTrie, error) { return s.classTrie, nil } -func (s *State) ContractTrie() (commontrie.Trie, error) { +func (s *State) ContractTrie() (core.CommonTrie, error) { return s.contractTrie, nil } -func (s *State) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (s *State) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) { return s.db.ContractStorageTrie(&s.initRoot, addr) } diff --git a/core/state/commonstate/state.go b/core/state/statefactory/state_factory.go similarity index 51% rename from core/state/commonstate/state.go rename to core/state/statefactory/state_factory.go index 093747d956..9193bdb1ee 100644 --- a/core/state/commonstate/state.go +++ b/core/state/statefactory/state_factory.go @@ -1,45 +1,14 @@ -package commonstate +package statefactory import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" ) -//go:generate mockgen -destination=../../mocks/mock_commonstate_reader.go -package=mocks github.com/NethermindEth/juno/core/state/commonstate StateReader -type State interface { - StateReader - - ContractStorageAt(addr, key *felt.Felt, blockNumber uint64) (felt.Felt, error) - ContractNonceAt(addr *felt.Felt, blockNumber uint64) (felt.Felt, error) - ContractClassHashAt(addr *felt.Felt, blockNumber uint64) (felt.Felt, error) - ContractDeployedAt(addr *felt.Felt, blockNumber uint64) (bool, error) - - Update(blockNum uint64, - update *core.StateUpdate, - declaredClasses map[felt.Felt]core.ClassDefinition, - skipVerifyNewRoot bool, - flushChanges bool, - ) error - Revert(blockNum uint64, update *core.StateUpdate) error - Commitment() (felt.Felt, error) -} - -type StateReader interface { - ContractClassHash(addr *felt.Felt) (felt.Felt, error) - ContractNonce(addr *felt.Felt) (felt.Felt, error) - ContractStorage(addr, key *felt.Felt) (felt.Felt, error) - Class(classHash *felt.Felt) (*core.DeclaredClassDefinition, error) - - ClassTrie() (commontrie.Trie, error) - ContractTrie() (commontrie.Trie, error) - ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) -} - type StateFactory struct { UseNewState bool triedb *triedb.Database @@ -62,24 +31,24 @@ func NewStateFactory( }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (State, error) { +func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (core.CommonState, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) return deprecatedState, nil } - stateState, err := state.New(stateRoot, sf.stateDB) + state, err := state.New(stateRoot, sf.stateDB) if err != nil { return nil, err } - return stateState, nil + return state, nil } func (sf *StateFactory) NewStateReader( stateRoot *felt.Felt, txn db.IndexedBatch, blockNumber uint64, -) (StateReader, error) { +) (core.CommonStateReader, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) snapshot := core.NewStateSnapshot(deprecatedState, blockNumber) @@ -93,7 +62,7 @@ func (sf *StateFactory) NewStateReader( return &history, nil } -func (sf *StateFactory) EmptyState() (StateReader, error) { +func (sf *StateFactory) EmptyState() (core.CommonStateReader, error) { if !sf.UseNewState { memDB := memory.New() txn := memDB.NewIndexedBatch() diff --git a/core/state_snapshot.go b/core/state_snapshot.go index f6c8eace7e..cf9b04f70c 100644 --- a/core/state_snapshot.go +++ b/core/state_snapshot.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commontrie" "github.com/NethermindEth/juno/db" ) @@ -95,14 +94,14 @@ func (s *stateSnapshot) Class(classHash *felt.Felt) (*DeclaredClassDefinition, e return declaredClass, nil } -func (s *stateSnapshot) ClassTrie() (commontrie.Trie, error) { +func (s *stateSnapshot) ClassTrie() (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (s *stateSnapshot) ContractTrie() (commontrie.Trie, error) { +func (s *stateSnapshot) ContractTrie() (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } -func (s *stateSnapshot) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (s *stateSnapshot) ContractStorageTrie(addr *felt.Felt) (CommonTrie, error) { return nil, ErrHistoricalTrieNotSupported } diff --git a/core/state_snapshot_test.go b/core/state_snapshot_test.go index 842335c3e6..3a43fbe9fc 100644 --- a/core/state_snapshot_test.go +++ b/core/state_snapshot_test.go @@ -68,7 +68,7 @@ func TestStateSnapshot(t *testing.T) { require.NoError(t, err) for desc, test := range map[string]struct { - snapshot core.StateReader + snapshot core.CommonStateReader checker func(*testing.T, felt.Felt, error) }{ "contract is not deployed": { diff --git a/core/state_test.go b/core/state_test.go index 8d1d8185c1..4c3866a75a 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -51,7 +51,7 @@ func TestUpdate(t *testing.T) { t.Run("error when state current root doesn't match state update's old root", func(t *testing.T) { - oldRoot := felt.NewUnsafeFromString[felt.Felt]("some old root") + oldRoot := felt.NewFromBytes[felt.Felt]([]byte("some old root")) su := &core.StateUpdate{ OldRoot: oldRoot, } @@ -64,7 +64,7 @@ func TestUpdate(t *testing.T) { }) t.Run("error when state new root doesn't match state update's new root", func(t *testing.T) { - newRoot := new(felt.Felt).SetBytes([]byte("some new root")) + newRoot := felt.NewFromBytes[felt.Felt]([]byte("some old root")) su := &core.StateUpdate{ NewRoot: newRoot, OldRoot: su0.NewRoot, @@ -115,7 +115,7 @@ func TestUpdate(t *testing.T) { scValue := felt.NewUnsafeFromString[felt.Felt]( "0x10979c6b0b36b03be36739a21cc43a51076545ce6d3397f1b45c7e286474ad5", ) - scAddr := new(felt.Felt).SetUint64(1) + scAddr := felt.NewFromUint64[felt.Felt](1) su4 := &core.StateUpdate{ OldRoot: su3.NewRoot, @@ -247,7 +247,7 @@ func TestNonce(t *testing.T) { }) t.Run("update contract nonce", func(t *testing.T) { - expectedNonce := new(felt.Felt).SetUint64(1) + expectedNonce := felt.NewFromUint64[felt.Felt](1) su = &core.StateUpdate{ NewRoot: felt.NewUnsafeFromString[felt.Felt]( "0x6210642ffd49f64617fc9e5c0bbe53a6a92769e2996eb312a42d2bdb7f2afc1", diff --git a/core/trie2/trie.go b/core/trie2/trie.go index aa67aa2624..cc58adf368 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -76,6 +76,7 @@ func New( } root, err := tr.resolveNode(nil, Path{}) + fmt.Println("root", root, err) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return nil, err } diff --git a/genesis/genesis.go b/genesis/genesis.go index 0cbb5e7d4e..8456d95492 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -10,7 +10,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/core/state/statefactory" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" rpc "github.com/NethermindEth/juno/rpc/v8" @@ -114,7 +115,7 @@ func GenesisStateDiff( stateDB := state.NewStateDB(memDB, triedb) // TODO(maksymmalick): remove this after integration done - stateFactory, err := commonstate.NewStateFactory(false, triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(false, triedb, stateDB) if err != nil { return core.StateDiff{}, nil, err } diff --git a/l1/l1_pkg_test.go b/l1/l1_pkg_test.go index 1a45c7e1a7..091e4b548d 100644 --- a/l1/l1_pkg_test.go +++ b/l1/l1_pkg_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1/contract" "github.com/NethermindEth/juno/mocks" diff --git a/l1/l1_test.go b/l1/l1_test.go index 64ca78a53f..09dc8ecb85 100644 --- a/l1/l1_test.go +++ b/l1/l1_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1" "github.com/NethermindEth/juno/l1/contract" diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 7bf9ec95c3..2f230c09b1 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" _ "github.com/NethermindEth/juno/encoder/registry" diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index b6af8c685b..7ddf3ecb4e 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" diff --git a/mocks/mock_blockchain.go b/mocks/mock_blockchain.go index d44beb870a..598267d151 100644 --- a/mocks/mock_blockchain.go +++ b/mocks/mock_blockchain.go @@ -15,7 +15,6 @@ import ( blockchain "github.com/NethermindEth/juno/blockchain" core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - commonstate "github.com/NethermindEth/juno/core/state/commonstate" utils "github.com/NethermindEth/juno/utils" common "github.com/ethereum/go-ethereum/common" gomock "go.uber.org/mock/gomock" @@ -151,10 +150,10 @@ func (mr *MockReaderMockRecorder) Head() *gomock.Call { } // HeadState mocks base method. -func (m *MockReader) HeadState() (commonstate.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) HeadState() (core.CommonStateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadState") - ret0, _ := ret[0].(commonstate.StateReader) + ret0, _ := ret[0].(core.CommonStateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -258,10 +257,10 @@ func (mr *MockReaderMockRecorder) Receipt(hash any) *gomock.Call { } // StateAtBlockHash mocks base method. -func (m *MockReader) StateAtBlockHash(blockHash *felt.Felt) (commonstate.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) StateAtBlockHash(blockHash *felt.Felt) (core.CommonStateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StateAtBlockHash", blockHash) - ret0, _ := ret[0].(commonstate.StateReader) + ret0, _ := ret[0].(core.CommonStateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 @@ -274,10 +273,10 @@ func (mr *MockReaderMockRecorder) StateAtBlockHash(blockHash any) *gomock.Call { } // StateAtBlockNumber mocks base method. -func (m *MockReader) StateAtBlockNumber(blockNumber uint64) (commonstate.StateReader, blockchain.StateCloser, error) { +func (m *MockReader) StateAtBlockNumber(blockNumber uint64) (core.CommonStateReader, blockchain.StateCloser, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StateAtBlockNumber", blockNumber) - ret0, _ := ret[0].(commonstate.StateReader) + ret0, _ := ret[0].(core.CommonStateReader) ret1, _ := ret[1].(blockchain.StateCloser) ret2, _ := ret[2].(error) return ret0, ret1, ret2 diff --git a/mocks/mock_commonstate_reader.go b/mocks/mock_commonstate_reader.go index a0efad3ceb..f77bad40dc 100644 --- a/mocks/mock_commonstate_reader.go +++ b/mocks/mock_commonstate_reader.go @@ -14,7 +14,6 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - commontrie "github.com/NethermindEth/juno/core/state/commontrie" gomock "go.uber.org/mock/gomock" ) @@ -58,10 +57,10 @@ func (mr *MockStateReaderMockRecorder) Class(classHash any) *gomock.Call { } // ClassTrie mocks base method. -func (m *MockStateReader) ClassTrie() (commontrie.Trie, error) { +func (m *MockStateReader) ClassTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClassTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -118,10 +117,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorage(addr, key any) *gomock.Ca } // ContractStorageTrie mocks base method. -func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (m *MockStateReader) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractStorageTrie", addr) - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -133,10 +132,10 @@ func (mr *MockStateReaderMockRecorder) ContractStorageTrie(addr any) *gomock.Cal } // ContractTrie mocks base method. -func (m *MockStateReader) ContractTrie() (commontrie.Trie, error) { +func (m *MockStateReader) ContractTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/mocks/mock_state.go b/mocks/mock_state.go index 5ea505f1f9..3cfe40a490 100644 --- a/mocks/mock_state.go +++ b/mocks/mock_state.go @@ -14,7 +14,6 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - commontrie "github.com/NethermindEth/juno/core/state/commontrie" gomock "go.uber.org/mock/gomock" ) @@ -73,10 +72,10 @@ func (mr *MockStateHistoryReaderMockRecorder) Class(classHash any) *gomock.Call } // ClassTrie mocks base method. -func (m *MockStateHistoryReader) ClassTrie() (commontrie.Trie, error) { +func (m *MockStateHistoryReader) ClassTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClassTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -193,10 +192,10 @@ func (mr *MockStateHistoryReaderMockRecorder) ContractStorageAt(addr, key, block } // ContractStorageTrie mocks base method. -func (m *MockStateHistoryReader) ContractStorageTrie(addr *felt.Felt) (commontrie.Trie, error) { +func (m *MockStateHistoryReader) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractStorageTrie", addr) - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -208,10 +207,10 @@ func (mr *MockStateHistoryReaderMockRecorder) ContractStorageTrie(addr any) *gom } // ContractTrie mocks base method. -func (m *MockStateHistoryReader) ContractTrie() (commontrie.Trie, error) { +func (m *MockStateHistoryReader) ContractTrie() (core.CommonTrie, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ContractTrie") - ret0, _ := ret[0].(commontrie.Trie) + ret0, _ := ret[0].(core.CommonTrie) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/mocks/mock_vm.go b/mocks/mock_vm.go index 9caf33aab3..ce830e3505 100644 --- a/mocks/mock_vm.go +++ b/mocks/mock_vm.go @@ -14,7 +14,6 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" - commonstate "github.com/NethermindEth/juno/core/state/commonstate" vm "github.com/NethermindEth/juno/vm" gomock "go.uber.org/mock/gomock" ) @@ -44,7 +43,7 @@ func (m *MockVM) EXPECT() *MockVMMockRecorder { } // Call mocks base method. -func (m *MockVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state commonstate.StateReader, maxSteps, maxGas uint64, structuredErrStack, returnStateDiff bool) (vm.CallResult, error) { +func (m *MockVM) Call(callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, state core.CommonStateReader, maxSteps, maxGas uint64, structuredErrStack, returnStateDiff bool) (vm.CallResult, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Call", callInfo, blockInfo, state, maxSteps, maxGas, structuredErrStack, returnStateDiff) ret0, _ := ret[0].(vm.CallResult) @@ -59,7 +58,7 @@ func (mr *MockVMMockRecorder) Call(callInfo, blockInfo, state, maxSteps, maxGas, } // Execute mocks base method. -func (m *MockVM) Execute(txns []core.Transaction, declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state commonstate.StateReader, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch, isEstimateFee bool) (vm.ExecutionResults, error) { +func (m *MockVM) Execute(txns []core.Transaction, declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, state core.CommonStateReader, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch, isEstimateFee bool) (vm.ExecutionResults, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Execute", txns, declaredClasses, paidFeesOnL1, blockInfo, state, skipChargeFee, skipValidate, errOnRevert, errStack, allowBinarySearch, isEstimateFee) ret0, _ := ret[0].(vm.ExecutionResults) diff --git a/node/node_test.go b/node/node_test.go index 0f67158a3a..158b22000f 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/node" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" diff --git a/node/throttled_vm.go b/node/throttled_vm.go index ff1c52ae6d..a08b51a550 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -3,7 +3,7 @@ package node import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/utils" "github.com/NethermindEth/juno/vm" ) @@ -23,7 +23,7 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott func (tvm *ThrottledVM) Call( callInfo *vm.CallInfo, blockInfo *vm.BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, maxSteps uint64, maxGas uint64, errStack, returnStateDiff bool, @@ -49,7 +49,7 @@ func (tvm *ThrottledVM) Execute( declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, skipChargeFee, skipValidate, errOnRevert, diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index f129f073de..b4d302b356 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" junoplugin "github.com/NethermindEth/juno/plugin" diff --git a/rpc/v10/events_test.go b/rpc/v10/events_test.go index 00d44fac6a..c36eec0ff6 100644 --- a/rpc/v10/events_test.go +++ b/rpc/v10/events_test.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/mocks" @@ -214,7 +215,7 @@ func setupTestChain( ) (*blockchain.Blockchain, *adaptfeeder.Feeder) { t.Helper() testDB := memory.New() - chain := blockchain.New(testDB, network) + chain := blockchain.New(testDB, network, statetestutils.UseNewState()) client := feeder.NewTestClient(t, network) gw := adaptfeeder.New(client) diff --git a/rpc/v10/helpers.go b/rpc/v10/helpers.go index ffafcdade0..df08b6ce13 100644 --- a/rpc/v10/helpers.go +++ b/rpc/v10/helpers.go @@ -126,8 +126,8 @@ func (h *Handler) callAndLogErr(f func() error, msg string) { func (h *Handler) stateByBlockID( blockID *rpcv9.BlockID, -) (core.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader core.StateReader +) (core.CommonStateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader core.CommonStateReader var closer blockchain.StateCloser var err error switch { diff --git a/rpc/v10/pending_data_wrapper.go b/rpc/v10/pending_data_wrapper.go index d7a3f1da90..edfcd08edd 100644 --- a/rpc/v10/pending_data_wrapper.go +++ b/rpc/v10/pending_data_wrapper.go @@ -33,7 +33,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (core.StateReader, func() error, error) { +func (h *Handler) PendingState() (core.CommonStateReader, func() error, error) { pendingData, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, core.ErrPendingDataNotFound) { diff --git a/rpc/v10/state_update_test.go b/rpc/v10/state_update_test.go index d1146b6e55..c693aff515 100644 --- a/rpc/v10/state_update_test.go +++ b/rpc/v10/state_update_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -34,7 +35,7 @@ func TestStateUpdate_ErrorCases(t *testing.T) { n := &utils.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PendingData().Return(nil, core.ErrPendingDataNotFound) diff --git a/rpc/v10/subscriptions_test.go b/rpc/v10/subscriptions_test.go index 9781e8e030..c7ea30d08f 100644 --- a/rpc/v10/subscriptions_test.go +++ b/rpc/v10/subscriptions_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -93,13 +94,13 @@ func (fs *fakeSyncer) PendingData() (core.PendingData, error) { return nil, core.ErrPendingDataNotFound } func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingState() (core.CommonStateReader, func() error, error) { return nil, nil, nil } func (fs *fakeSyncer) PendingStateBeforeIndex( index int, -) (core.StateReader, func() error, error) { +) (core.CommonStateReader, func() error, error) { return nil, nil, nil } @@ -1369,10 +1370,10 @@ func TestSubscribeNewHeadsHistorical(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &utils.Mainnet) + chain := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &utils.Mainnet) + chain = blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) syncer := newFakeSyncer() ctx, cancel := context.WithCancel(t.Context()) diff --git a/rpc/v10/trace.go b/rpc/v10/trace.go index 2c2f24d0a8..89ef41f5ca 100644 --- a/rpc/v10/trace.go +++ b/rpc/v10/trace.go @@ -101,8 +101,8 @@ func (h *Handler) TraceBlockTransactions( func traceTransactionsWithState( vm vm.VM, transactions []core.Transaction, - executionState core.StateReader, - classLookupState core.StateReader, + executionState core.CommonStateReader, + classLookupState core.CommonStateReader, blockInfo *vm.BlockInfo, ) ([]TracedBlockTransaction, http.Header, *jsonrpc.Error) { httpHeader := defaultExecutionHeader() @@ -169,7 +169,7 @@ func traceTransactionsWithState( // // Returns the list of declared classes, L1 handler fees, and an error if any. func fetchDeclaredClassesAndL1Fees( - transactions []core.Transaction, state core.StateReader, + transactions []core.Transaction, state core.CommonStateReader, ) ([]core.ClassDefinition, []*felt.Felt, *jsonrpc.Error) { var declaredClasses []core.ClassDefinition l1HandlerFees := []*felt.Felt{} @@ -402,7 +402,7 @@ func (h *Handler) traceBlockWithVM(block *core.Block) ( // Get state to read class definitions for declare transactions var ( - headState core.StateReader + headState core.CommonStateReader headStateCloser blockchain.StateCloser ) // TODO: remove pending variant when it is no longer supported diff --git a/rpc/v10/trace_test.go b/rpc/v10/trace_test.go index 55c719623d..9d1515df85 100644 --- a/rpc/v10/trace_test.go +++ b/rpc/v10/trace_test.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -735,7 +736,7 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { log := utils.NewNopZapLogger() n := &utils.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New(memory.New(), n, statetestutils.UseNewState()) handler := rpcv10.New(chain, nil, nil, log) if description == "pre_confirmed" { diff --git a/rpc/v6/block_test.go b/rpc/v6/block_test.go index f348844d1c..3fb6ae2857 100644 --- a/rpc/v6/block_test.go +++ b/rpc/v6/block_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v6/estimate_fee_test.go b/rpc/v6/estimate_fee_test.go index e11d7b8618..70c14ed170 100644 --- a/rpc/v6/estimate_fee_test.go +++ b/rpc/v6/estimate_fee_test.go @@ -6,7 +6,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -61,7 +61,7 @@ func TestEstimateMessageFee(t *testing.T) { declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *vm.BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, skipChargeFee, skipValidate, errOnRevert, diff --git a/rpc/v6/events_test.go b/rpc/v6/events_test.go index 7114afd10b..225fa4e751 100644 --- a/rpc/v6/events_test.go +++ b/rpc/v6/events_test.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index f1ddc3f9ed..edc2342759 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" @@ -151,8 +151,8 @@ func feeUnit(txn core.Transaction) FeeUnit { func (h *Handler) stateByBlockID( id *BlockID, -) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader commonstate.StateReader +) (core.CommonStateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader core.CommonStateReader var closer blockchain.StateCloser var err error switch { diff --git a/rpc/v6/pending_data_wrapper.go b/rpc/v6/pending_data_wrapper.go index 0d86b41370..b6bf86418a 100644 --- a/rpc/v6/pending_data_wrapper.go +++ b/rpc/v6/pending_data_wrapper.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/sync/pendingdata" ) @@ -39,7 +39,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { +func (h *Handler) PendingState() (core.CommonStateReader, func() error, error) { pendingData, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, core.ErrPendingDataNotFound) { diff --git a/rpc/v6/state_update_test.go b/rpc/v6/state_update_test.go index 7cb5ebd8c5..56657adf4e 100644 --- a/rpc/v6/state_update_test.go +++ b/rpc/v6/state_update_test.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v6/trace.go b/rpc/v6/trace.go index 426d8d90ba..ca975514b3 100644 --- a/rpc/v6/trace.go +++ b/rpc/v6/trace.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" @@ -159,7 +159,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block, defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState commonstate.StateReader + headState core.CommonStateReader headStateCloser blockchain.StateCloser ) if isPending { diff --git a/rpc/v6/trace_test.go b/rpc/v6/trace_test.go index ca1648776b..9d97fe5b7b 100644 --- a/rpc/v6/trace_test.go +++ b/rpc/v6/trace_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v7/block_test.go b/rpc/v7/block_test.go index 8cb8978070..e0f5d9df9b 100644 --- a/rpc/v7/block_test.go +++ b/rpc/v7/block_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v7/helpers.go b/rpc/v7/helpers.go index e63e0c7814..9992885311 100644 --- a/rpc/v7/helpers.go +++ b/rpc/v7/helpers.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -161,8 +161,8 @@ func feeUnit(txn core.Transaction) FeeUnit { func (h *Handler) stateByBlockID( id *BlockID, -) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader commonstate.StateReader +) (core.CommonStateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader core.CommonStateReader var closer blockchain.StateCloser var err error switch { diff --git a/rpc/v7/pending_data_wrapper.go b/rpc/v7/pending_data_wrapper.go index f742daa3a3..a5221ebae6 100644 --- a/rpc/v7/pending_data_wrapper.go +++ b/rpc/v7/pending_data_wrapper.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/sync/pendingdata" ) @@ -39,7 +39,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { +func (h *Handler) PendingState() (core.CommonStateReader, func() error, error) { pendingData, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, core.ErrPendingDataNotFound) { diff --git a/rpc/v7/trace.go b/rpc/v7/trace.go index 7864d1bd8b..0c04b82b5f 100644 --- a/rpc/v7/trace.go +++ b/rpc/v7/trace.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -197,7 +197,7 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block) defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState commonstate.StateReader + headState core.CommonStateReader headStateCloser blockchain.StateCloser ) if isPending { diff --git a/rpc/v7/trace_test.go b/rpc/v7/trace_test.go index 88c9aaef95..4c82b595e5 100644 --- a/rpc/v7/trace_test.go +++ b/rpc/v7/trace_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v8/block_test.go b/rpc/v8/block_test.go index 4dd721b5f2..b6ef1f7ea8 100644 --- a/rpc/v8/block_test.go +++ b/rpc/v8/block_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v8/helpers.go b/rpc/v8/helpers.go index 18f8d002b3..48e3f4ee75 100644 --- a/rpc/v8/helpers.go +++ b/rpc/v8/helpers.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -139,8 +139,8 @@ func feeUnit(txn core.Transaction) FeeUnit { func (h *Handler) stateByBlockID( blockID *BlockID, -) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader commonstate.StateReader +) (core.CommonStateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader core.CommonStateReader var closer blockchain.StateCloser var err error switch blockID.Type() { diff --git a/rpc/v8/pending_data_wrapper.go b/rpc/v8/pending_data_wrapper.go index fa5fe8fc1c..18e661b793 100644 --- a/rpc/v8/pending_data_wrapper.go +++ b/rpc/v8/pending_data_wrapper.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/sync/pendingdata" ) @@ -39,7 +39,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { +func (h *Handler) PendingState() (core.CommonStateReader, func() error, error) { pendingData, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, core.ErrPendingDataNotFound) { diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index b3e6e28587..228767fe87 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/core/state/commontrie" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" @@ -212,9 +212,9 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp return nil } -func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { +func getClassProof(tr core.CommonTrie, classes []felt.Felt) ([]*HashToNode, error) { switch t := tr.(type) { - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: classProof := trie.NewProofNodeSet() for _, class := range classes { if err := (*trie.Trie)(t).Prove(&class, classProof); err != nil { @@ -222,7 +222,7 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } } return adaptDeprecatedTrieProofNodes(classProof), nil - case *commontrie.TrieAdapter: + case *trie2.Trie: classProof := trie2.NewProofNodeSet() for _, class := range classes { if err := (*trie2.Trie)(t).Prove(&class, classProof); err != nil { @@ -236,14 +236,14 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } func getContractProof( - tr commontrie.Trie, - state commonstate.StateReader, + tr core.CommonTrie, + state core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { switch t := tr.(type) { - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) - case *commontrie.TrieAdapter: + case *trie2.Trie: return getContractProofWithTrie((*trie2.Trie)(t), state, contracts) default: return nil, fmt.Errorf("unknown trie type: %T", tr) @@ -252,18 +252,18 @@ func getContractProof( func getContractProofWithDeprecatedTrie( tr *trie.Trie, - state commonstate.StateReader, + state core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) for i, contract := range contracts { - if err := t.Prove(&contract, contractProof); err != nil { + if err := tr.Prove(&contract, contractProof); err != nil { return nil, err } - root, err := t.Hash() + root, err := tr.Hash() if err != nil { return nil, err } @@ -297,7 +297,7 @@ func getContractProofWithDeprecatedTrie( func getContractProofWithTrie( tr *trie2.Trie, - st commonstate.StateReader, + st core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() @@ -307,7 +307,10 @@ func getContractProofWithTrie( return nil, err } - root := tr.Hash() + root, err := tr.Hash() + if err != nil { + return nil, err + } nonce, err := st.ContractNonce(&contract) if err != nil { @@ -337,7 +340,7 @@ func getContractProofWithTrie( } func getContractStorageProof( - state commonstate.StateReader, + state core.CommonStateReader, storageKeys []StorageKeys, ) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) @@ -348,7 +351,7 @@ func getContractStorageProof( } switch t := contractStorageTrie.(type) { - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: contractStorageProof := trie.NewProofNodeSet() for _, key := range storageKey.Keys { if err := (*trie.Trie)(t).Prove(&key, contractStorageProof); err != nil { @@ -356,7 +359,7 @@ func getContractStorageProof( } } contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) - case *commontrie.TrieAdapter: + case *trie2.Trie: contractStorageProof := trie2.NewProofNodeSet() for _, key := range storageKey.Keys { if err := (*trie2.Trie)(t).Prove(&key, contractStorageProof); err != nil { diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index fecfd55c4c..0ea6803a18 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -12,8 +12,7 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commontrie" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" @@ -231,7 +230,7 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - var classTrie, contractTrie commontrie.Trie + var classTrie, contractTrie core.CommonTrie trieRoot := felt.Zero if !statetestutils.UseNewState() { @@ -240,8 +239,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) - contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) + classTrie = tempTrie + contractTrie = tempTrie } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func( @@ -263,7 +262,7 @@ func TestStorageProof(t *testing.T) { trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) createTrie(t, trieutils.NewClassTrieID(felt.Zero), &trieDB) contractTrie2 := createTrie(t, trieutils.NewContractTrieID(felt.Zero), &trieDB) - tmpTrieRoot := contractTrie2.Hash() + tmpTrieRoot, _ := contractTrie2.Hash() trieRoot = tmpTrieRoot // recreate because the previous ones are committed @@ -281,8 +280,8 @@ func TestStorageProof(t *testing.T) { &trieDB, ) require.NoError(t, err) - classTrie = (*commontrie.TrieAdapter)(classTrie2) - contractTrie = (*commontrie.TrieAdapter)(contractTrie2) + classTrie = classTrie2 + contractTrie = contractTrie2 } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -918,13 +917,13 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } -func emptyCommonTrie(t *testing.T) commontrie.Trie { +func emptyCommonTrie(t *testing.T) core.CommonTrie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return (*commontrie.TrieAdapter)(tempTrie) + return tempTrie } else { - return (*commontrie.DeprecatedTrieAdapter)(emptyTrie(t)) + return emptyTrie(t) } } diff --git a/rpc/v8/subscriptions_test.go b/rpc/v8/subscriptions_test.go index e91327d943..d0219b5cdc 100644 --- a/rpc/v8/subscriptions_test.go +++ b/rpc/v8/subscriptions_test.go @@ -15,8 +15,8 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -417,13 +417,13 @@ func (fs *fakeSyncer) PendingData() (core.PendingData, error) { return nil, core.ErrPendingDataNotFound } func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (commonstate.StateReader, func() error, error) { +func (fs *fakeSyncer) PendingState() (core.CommonStateReader, func() error, error) { return nil, nil, nil } func (fs *fakeSyncer) PendingStateBeforeIndex( index int, -) (commonstate.StateReader, func() error, error) { +) (core.CommonStateReader, func() error, error) { return nil, nil, nil } diff --git a/rpc/v8/trace.go b/rpc/v8/trace.go index e7be81f9e3..df092cbfae 100644 --- a/rpc/v8/trace.go +++ b/rpc/v8/trace.go @@ -11,7 +11,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -176,7 +176,7 @@ func (h *Handler) traceBlockTransactionWithVM(block *core.Block) ( defer h.callAndLogErr(closer, "Failed to close state in traceBlockTransactions") var ( - headState commonstate.StateReader + headState core.CommonStateReader headStateCloser blockchain.StateCloser ) diff --git a/rpc/v8/trace_test.go b/rpc/v8/trace_test.go index 77c37c6251..be40f37b27 100644 --- a/rpc/v8/trace_test.go +++ b/rpc/v8/trace_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v9/block_test.go b/rpc/v9/block_test.go index 0ad6d06c61..8faf9cb5fb 100644 --- a/rpc/v9/block_test.go +++ b/rpc/v9/block_test.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v9/compiled_casm_test.go b/rpc/v9/compiled_casm_test.go index 331d6ea1ae..72274bdd81 100644 --- a/rpc/v9/compiled_casm_test.go +++ b/rpc/v9/compiled_casm_test.go @@ -67,7 +67,7 @@ func TestCompiledCasm(t *testing.T) { require.NoError(t, err) mockState := mocks.NewMockStateHistoryReader(mockCtrl) - mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: class}, nil) + mockState.EXPECT().Class(classHash).Return(&core.DeclaredClassDefinition{Class: class}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) resp, rpcErr := handler.CompiledCasm(classHash) @@ -109,7 +109,7 @@ func TestCompiledCasm(t *testing.T) { } mockState := mocks.NewMockStateHistoryReader(mockCtrl) - mockState.EXPECT().Class(classHash).Return(&core.DeclaredClass{Class: cairoClass}, nil) + mockState.EXPECT().Class(classHash).Return(&core.DeclaredClassDefinition{Class: cairoClass}, nil) rd.EXPECT().HeadState().Return(mockState, nopCloser, nil) resp, rpcErr := handler.CompiledCasm(classHash) diff --git a/rpc/v9/events_test.go b/rpc/v9/events_test.go index 6fcd32fb8a..2d27ea5d98 100644 --- a/rpc/v9/events_test.go +++ b/rpc/v9/events_test.go @@ -8,7 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/mocks" diff --git a/rpc/v9/helpers.go b/rpc/v9/helpers.go index 4739748dfa..6f381010df 100644 --- a/rpc/v9/helpers.go +++ b/rpc/v9/helpers.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -153,8 +153,8 @@ func feeUnit(txn core.Transaction) FeeUnit { func (h *Handler) stateByBlockID( blockID *BlockID, -) (commonstate.StateReader, blockchain.StateCloser, *jsonrpc.Error) { - var reader commonstate.StateReader +) (core.CommonStateReader, blockchain.StateCloser, *jsonrpc.Error) { + var reader core.CommonStateReader var closer blockchain.StateCloser var err error switch blockID.Type() { diff --git a/rpc/v9/pending_data_wrapper.go b/rpc/v9/pending_data_wrapper.go index 9f96bf7dca..3de26e4617 100644 --- a/rpc/v9/pending_data_wrapper.go +++ b/rpc/v9/pending_data_wrapper.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/sync/pendingdata" ) @@ -34,7 +34,7 @@ func (h *Handler) PendingBlock() *core.Block { return pending.GetBlock() } -func (h *Handler) PendingState() (commonstate.StateReader, func() error, error) { +func (h *Handler) PendingState() (core.CommonStateReader, func() error, error) { pendingData, err := h.syncReader.PendingData() if err != nil { if errors.Is(err, core.ErrPendingDataNotFound) { diff --git a/rpc/v9/state_update_test.go b/rpc/v9/state_update_test.go index 3354bf34d7..b7e9cda800 100644 --- a/rpc/v9/state_update_test.go +++ b/rpc/v9/state_update_test.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index a7163bc3db..7aecb5c2d7 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -4,10 +4,10 @@ import ( "errors" "fmt" + "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commonstate" - "github.com/NethermindEth/juno/core/state/commontrie" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" @@ -212,11 +212,11 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp return nil } -func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, error) { +func getClassProof(tr core.CommonTrie, classes []felt.Felt) ([]*HashToNode, error) { switch t := tr.(type) { // TODO(maksym): remove after trie2 integration. RPC packages shouldn't // care about which trie implementation is being used and the output format should be the same - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: classProof := trie.NewProofNodeSet() for _, class := range classes { if err := (*trie.Trie)(t).Prove(&class, classProof); err != nil { @@ -224,7 +224,7 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } } return adaptDeprecatedTrieProofNodes(classProof), nil - case *commontrie.TrieAdapter: + case *trie2.Trie: classProof := trie2.NewProofNodeSet() for _, class := range classes { if err := (*trie2.Trie)(t).Prove(&class, classProof); err != nil { @@ -238,16 +238,16 @@ func getClassProof(tr commontrie.Trie, classes []felt.Felt) ([]*HashToNode, erro } func getContractProof( - tr commontrie.Trie, - state commonstate.StateReader, + tr core.CommonTrie, + state core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { // TODO(maksym): remove after trie2 integration. RPC packages shouldn't // care about which trie implementation is being used and the output format should be the same switch t := tr.(type) { - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) - case *commontrie.TrieAdapter: + case *trie2.Trie: return getContractProofWithTrie((*trie2.Trie)(t), state, contracts) default: return nil, fmt.Errorf("unknown trie type: %T", tr) @@ -256,14 +256,14 @@ func getContractProof( func getContractProofWithDeprecatedTrie( tr *trie.Trie, - state commonstate.StateReader, + state core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() contractLeavesData := make([]*LeafData, len(contracts)) for i, contract := range contracts { - if err := t.Prove(&contract, contractProof); err != nil { + if err := tr.Prove(&contract, contractProof); err != nil { return nil, err } @@ -300,7 +300,7 @@ func getContractProofWithDeprecatedTrie( func getContractProofWithTrie( tr *trie2.Trie, - st commonstate.StateReader, + st core.CommonStateReader, contracts []felt.Felt, ) (*ContractProof, error) { contractProof := trie2.NewProofNodeSet() @@ -311,7 +311,10 @@ func getContractProofWithTrie( return nil, err } - root := tr.Hash() + root, err := tr.Hash() + if err != nil { + return nil, err + } nonce, err := st.ContractNonce(&contract) if err != nil { @@ -341,7 +344,7 @@ func getContractProofWithTrie( } func getContractStorageProof( - state commonstate.StateReader, + state core.CommonStateReader, storageKeys []StorageKeys, ) ([][]*HashToNode, error) { contractStorageRes := make([][]*HashToNode, len(storageKeys)) @@ -352,7 +355,7 @@ func getContractStorageProof( } switch t := contractStorageTrie.(type) { - case *commontrie.DeprecatedTrieAdapter: + case *trie.Trie: contractStorageProof := trie.NewProofNodeSet() for _, key := range storageKey.Keys { if err := (*trie.Trie)(t).Prove(&key, contractStorageProof); err != nil { @@ -360,7 +363,7 @@ func getContractStorageProof( } } contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) - case *commontrie.TrieAdapter: + case *trie2.Trie: contractStorageProof := trie2.NewProofNodeSet() for _, key := range storageKey.Keys { if err := (*trie2.Trie)(t).Prove(&key, contractStorageProof); err != nil { diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 0fdbde4251..d297a231f8 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -12,8 +12,7 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commontrie" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" @@ -241,7 +240,7 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - var classTrie, contractTrie commontrie.Trie + var classTrie, contractTrie core.CommonTrie trieRoot := felt.Zero if !statetestutils.UseNewState() { @@ -250,8 +249,8 @@ func TestStorageProof(t *testing.T) { _, _ = tempTrie.Put(key2, value2) _ = tempTrie.Commit() trieRoot, _ = tempTrie.Root() - classTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) - contractTrie = (*commontrie.DeprecatedTrieAdapter)(tempTrie) + classTrie = tempTrie + contractTrie = tempTrie } else { newComm := new(felt.Felt).SetUint64(1) createTrie := func( @@ -273,7 +272,8 @@ func TestStorageProof(t *testing.T) { trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) createTrie(t, trieutils.NewClassTrieID(felt.Zero), &trieDB) contractTrie2 := createTrie(t, trieutils.NewContractTrieID(felt.Zero), &trieDB) - tmpTrieRoot := contractTrie2.Hash() + tmpTrieRoot, err := contractTrie2.Hash() + require.NoError(t, err) trieRoot = tmpTrieRoot // recreate because the previous ones are committed @@ -291,8 +291,8 @@ func TestStorageProof(t *testing.T) { &trieDB, ) require.NoError(t, err) - classTrie = (*commontrie.TrieAdapter)(classTrie2) - contractTrie = (*commontrie.TrieAdapter)(contractTrie2) + classTrie = classTrie2 + contractTrie = contractTrie2 } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -929,13 +929,13 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } -func emptyCommonTrie(t *testing.T) commontrie.Trie { +func emptyCommonTrie(t *testing.T) core.CommonTrie { if statetestutils.UseNewState() { tempTrie, err := trie2.NewEmptyPedersen() require.NoError(t, err) - return (*commontrie.TrieAdapter)(tempTrie) + return tempTrie } else { - return (*commontrie.DeprecatedTrieAdapter)(emptyTrie(t)) + return emptyTrie(t) } } diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index 492be776e5..adefa0d06a 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -14,8 +14,8 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -93,12 +93,14 @@ func (fs *fakeSyncer) HighestBlockHeader() *core.Header { func (fs *fakeSyncer) PendingData() (core.PendingData, error) { return nil, core.ErrPendingDataNotFound } -func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } -func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { return nil, nil, nil } +func (fs *fakeSyncer) PendingBlock() *core.Block { return nil } +func (fs *fakeSyncer) PendingState() (core.CommonStateReader, func() error, error) { + return nil, nil, nil +} func (fs *fakeSyncer) PendingStateBeforeIndex( index int, -) (commonstate.StateReader, func() error, error) { +) (core.CommonStateReader, func() error, error) { return nil, nil, nil } diff --git a/rpc/v9/trace.go b/rpc/v9/trace.go index ed8e814b50..e66cc040b3 100644 --- a/rpc/v9/trace.go +++ b/rpc/v9/trace.go @@ -189,8 +189,8 @@ func (h *Handler) Call(funcCall *FunctionCall, id *BlockID) ([]*felt.Felt, *json func traceTransactionsWithState( vm vm.VM, transactions []core.Transaction, - executionState core.StateReader, - classLookupState core.StateReader, + executionState core.CommonStateReader, + classLookupState core.CommonStateReader, blockInfo *vm.BlockInfo, ) ([]TracedBlockTransaction, http.Header, *jsonrpc.Error) { httpHeader := defaultExecutionHeader() @@ -257,7 +257,7 @@ func traceTransactionsWithState( // // Returns the list of declared classes, L1 handler fees, and an error if any. func fetchDeclaredClassesAndL1Fees( - transactions []core.Transaction, state core.StateReader, + transactions []core.Transaction, state core.CommonStateReader, ) ([]core.ClassDefinition, []*felt.Felt, *jsonrpc.Error) { var declaredClasses []core.ClassDefinition l1HandlerFees := []*felt.Felt{} @@ -490,7 +490,7 @@ func (h *Handler) traceBlockWithVM(block *core.Block) ( // Get state to read class definitions for declare transactions var ( - headState core.StateReader + headState core.CommonStateReader headStateCloser blockchain.StateCloser ) // TODO: remove pending variant when it is no longer supported diff --git a/rpc/v9/trace_test.go b/rpc/v9/trace_test.go index acd06793ac..52a2621d8b 100644 --- a/rpc/v9/trace_test.go +++ b/rpc/v9/trace_test.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" diff --git a/sequencer/sequencer.go b/sequencer/sequencer.go index 5756911dd2..9c4ca7f1b4 100644 --- a/sequencer/sequencer.go +++ b/sequencer/sequencer.go @@ -9,7 +9,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/feed" "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/plugin" @@ -212,7 +212,7 @@ func (s *Sequencer) PendingBlock() *core.Block { return s.buildState.PendingBlock() } -func (s *Sequencer) PendingState() (commonstate.StateReader, func() error, error) { +func (s *Sequencer) PendingState() (core.CommonStateReader, func() error, error) { return s.builder.PendingState(s.buildState) } diff --git a/sequencer/sequencer_test.go b/sequencer/sequencer_test.go index 4456e87971..3a1a315b3a 100644 --- a/sequencer/sequencer_test.go +++ b/sequencer/sequencer_test.go @@ -10,7 +10,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" diff --git a/sync/pending_polling_test.go b/sync/pending_polling_test.go index f0cbfbef24..e792f76e63 100644 --- a/sync/pending_polling_test.go +++ b/sync/pending_polling_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" @@ -760,7 +760,7 @@ func makeTestPreConfirmed(num uint64) core.PreConfirmed { func TestStorePending(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &utils.Mainnet) + bc := blockchain.New(testDB, &utils.Mainnet, statetestutils.UseNewState()) log := utils.NewNopZapLogger() client := feeder.NewTestClient(t, &utils.Mainnet) gw := adaptfeeder.New(client) diff --git a/sync/pendingdata/helpers.go b/sync/pendingdata/helpers.go index 1f288cbdb3..ab4524796a 100644 --- a/sync/pendingdata/helpers.go +++ b/sync/pendingdata/helpers.go @@ -161,7 +161,7 @@ func MakeEmptyPendingDataForParent( func ResolvePendingBaseState( pending *core.Pending, stateReader blockchain.Reader, -) (core.StateReader, blockchain.StateCloser, error) { +) (core.CommonStateReader, blockchain.StateCloser, error) { return stateReader.StateAtBlockHash(pending.Block.ParentHash) } @@ -169,7 +169,7 @@ func ResolvePendingBaseState( func ResolvePreConfirmedBaseState( preConfirmed *core.PreConfirmed, stateReader blockchain.Reader, -) (core.StateReader, blockchain.StateCloser, error) { +) (core.CommonStateReader, blockchain.StateCloser, error) { preLatest := preConfirmed.PreLatest // If pre-latest exists, use its parent hash as the base state if preLatest != nil { @@ -191,7 +191,7 @@ func ResolvePreConfirmedBaseState( func ResolvePendingDataBaseState( pending core.PendingData, stateReader blockchain.Reader, -) (core.StateReader, blockchain.StateCloser, error) { +) (core.CommonStateReader, blockchain.StateCloser, error) { switch p := pending.(type) { case *core.PreConfirmed: return ResolvePreConfirmedBaseState(p, stateReader) @@ -207,7 +207,7 @@ func ResolvePendingDataBaseState( func PendingState( pending core.PendingData, stateReader blockchain.Reader, -) (core.StateReader, blockchain.StateCloser, error) { +) (core.CommonStateReader, blockchain.StateCloser, error) { baseState, baseStateCloser, err := ResolvePendingDataBaseState(pending, stateReader) if err != nil { return nil, nil, err @@ -222,7 +222,7 @@ func PendingStateBeforeIndex( pending core.PendingData, stateReader blockchain.Reader, index uint, -) (core.StateReader, blockchain.StateCloser, error) { +) (core.CommonStateReader, blockchain.StateCloser, error) { baseState, baseStateCloser, err := ResolvePendingDataBaseState(pending, stateReader) if err != nil { return nil, nil, err diff --git a/sync/reorg_test.go b/sync/reorg_test.go index 36c7ff3917..4fc36fb506 100644 --- a/sync/reorg_test.go +++ b/sync/reorg_test.go @@ -14,7 +14,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/sync" diff --git a/sync/sync.go b/sync/sync.go index 166cad9093..c6f701bc2b 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -11,7 +11,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" junoplugin "github.com/NethermindEth/juno/plugin" @@ -118,7 +118,7 @@ func (n *NoopSynchronizer) PendingData() (core.PendingData, error) { return nil, errors.New("PendingData() is not implemented") } -func (n *NoopSynchronizer) PendingState() (commonstate.StateReader, func() error, error) { +func (n *NoopSynchronizer) PendingState() (core.CommonStateReader, func() error, error) { return nil, nil, errors.New("PendingState() not implemented") } diff --git a/sync/sync_test.go b/sync/sync_test.go index d8a72846bd..ffaae6fdb1 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -11,7 +11,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - statetestutils "github.com/NethermindEth/juno/core/state/state_test_utils" + statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" diff --git a/vm/vm.go b/vm/vm.go index 600b8c50c3..32aff3d9f2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -20,7 +20,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/state/commonstate" + "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/utils" ) @@ -50,7 +50,7 @@ type VM interface { Call( callInfo *CallInfo, blockInfo *BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, maxSteps uint64, maxGas uint64, structuredErrStack, @@ -61,7 +61,7 @@ type VM interface { declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, skipChargeFee, skipValidate, errOnRevert, @@ -88,7 +88,7 @@ func New(chainInfo *ChainInfo, concurrencyMode bool, log utils.SimpleLogger) VM // callContext manages the context that a Call instance executes on type callContext struct { // state that the call is running on - state commonstate.StateReader + state core.CommonStateReader log utils.SimpleLogger // err field to be possibly populated in case of an error in execution err string @@ -282,7 +282,7 @@ func makeCBlockInfo(blockInfo *BlockInfo) C.BlockInfo { func (v *vm) Call( callInfo *CallInfo, blockInfo *BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, maxSteps uint64, maxGas uint64, structuredErrStack, @@ -338,7 +338,7 @@ func (v *vm) Execute( declaredClasses []core.ClassDefinition, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo, - state commonstate.StateReader, + state core.CommonStateReader, skipChargeFee, skipValidate, errOnRevert, diff --git a/vm/vm_test.go b/vm/vm_test.go index 9a2af50b5d..aa112dea25 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -8,7 +8,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/commonstate" + + "github.com/NethermindEth/juno/core/state/statefactory" statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" @@ -34,7 +35,7 @@ func TestCallDeprecatedCairo(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) @@ -130,7 +131,7 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) @@ -189,7 +190,7 @@ func TestCallCairo(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) state, err := stateFactory.NewState(&felt.Zero, txn) newRoot := felt.NewUnsafeFromString[felt.Felt]( @@ -295,7 +296,7 @@ func TestCallInfoErrorHandling(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) testState, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) @@ -371,7 +372,7 @@ func TestExecute(t *testing.T) { triedb, err := triedb.New(testDB, nil) require.NoError(t, err) stateDB := state.NewStateDB(testDB, triedb) - stateFactory, err := commonstate.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) + stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) state, err := stateFactory.NewState(&felt.Zero, txn) require.NoError(t, err) From 259e9a7c1e731587d453548d5e504490c733c5a6 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 17 Nov 2025 17:32:04 +0100 Subject: [PATCH 44/47] fixes after update --- core/common_state.go | 1 - core/state/state_test.go | 12 ++++++------ core/state/statefactory/state_factory.go | 6 +++--- db/pebble/db.go | 19 +++++++------------ vm/vm_test.go | 8 +++++++- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/core/common_state.go b/core/common_state.go index e30351b725..9163057304 100644 --- a/core/common_state.go +++ b/core/common_state.go @@ -17,7 +17,6 @@ type CommonState interface { update *StateUpdate, declaredClasses map[felt.Felt]ClassDefinition, skipVerifyNewRoot bool, - flushChanges bool, ) error Revert(blockNum uint64, update *StateUpdate) error Commitment() (felt.Felt, error) diff --git a/core/state/state_test.go b/core/state/state_test.go index 901cb26800..f1e2bb93b4 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -262,12 +262,12 @@ func TestClass(t *testing.T) { su0, err := gw.StateUpdate(t.Context(), 0) require.NoError(t, err) - require.NoError(t, state.Update(0, su0, map[felt.Felt]core.Class{ - *cairo0Hash: cairo0Class, - *cairo1Hash: cairo1Class, - }, false, true)) - - gotCairo1Class, err := state.Class(cairo1Hash) + require.NoError(t, state.Update(0, su0, map[felt.Felt]core.ClassDefinition{ + *deprecatedCairoHash: deprecatedCairoClass, + *sierraHash: sierraClass, + }, false)) + require.NoError(t, batch.Write()) + gotSierraClass, err := state.Class(sierraHash) require.NoError(t, err) assert.Zero(t, gotSierraClass.At) assert.Equal(t, sierraClass, gotSierraClass.Class) diff --git a/core/state/statefactory/state_factory.go b/core/state/statefactory/state_factory.go index 9193bdb1ee..d05a31bd43 100644 --- a/core/state/statefactory/state_factory.go +++ b/core/state/statefactory/state_factory.go @@ -31,13 +31,13 @@ func NewStateFactory( }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (core.CommonState, error) { +func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch, batch db.Batch) (core.CommonState, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) return deprecatedState, nil } - state, err := state.New(stateRoot, sf.stateDB) + state, err := state.New(stateRoot, sf.stateDB, batch) if err != nil { return nil, err } @@ -69,7 +69,7 @@ func (sf *StateFactory) EmptyState() (core.CommonStateReader, error) { emptyState := core.NewState(txn) return emptyState, nil } - state, err := state.New(&felt.Zero, sf.stateDB) + state, err := state.New(&felt.Zero, sf.stateDB, nil) if err != nil { return nil, err } diff --git a/db/pebble/db.go b/db/pebble/db.go index 7471659ac4..57731d2206 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -28,24 +28,19 @@ type DB struct { // New opens a new database at the given path with default options func New(path string, options ...Option) (db.KeyValueStore, error) { - opts := pebble.Options{} - for _, option := range options { - if err := option(&opts); err != nil { - return nil, err - } - } - - opts := &pebble.Options{ - Logger: dbLog, - Cache: pebble.NewCache(int64(cacheSizeMB * utils.Megabyte)), - MaxOpenFiles: maxOpenFiles, + opts := pebble.Options{ L0CompactionFileThreshold: 8, L0StopWritesThreshold: 24, MemTableSize: 8 * utils.Megabyte, MaxConcurrentCompactions: func() int { return 2 }, } + for _, option := range options { + if err := option(&opts); err != nil { + return nil, err + } + } - return newPebble(path, opts) + return newPebble(path, &opts) } func newPebble(path string, options *pebble.Options) (*DB, error) { diff --git a/vm/vm_test.go b/vm/vm_test.go index a4f913e74c..ffa14b15fd 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -196,7 +196,8 @@ func TestCallCairo(t *testing.T) { stateDB := state.NewStateDB(testDB, triedb) stateFactory, err := statefactory.NewStateFactory(statetestutils.UseNewState(), triedb, stateDB) require.NoError(t, err) - state, err := stateFactory.NewState(&felt.Zero, txn) + batch := testDB.NewBatch() + state, err := stateFactory.NewState(&felt.Zero, txn, batch) newRoot := felt.NewUnsafeFromString[felt.Felt]( "0x2650cef46c190ec6bb7dc21a5a36781132e7c883b27175e625031149d4f1a84", ) @@ -211,6 +212,7 @@ func TestCallCairo(t *testing.T) { } declaredClass := map[felt.Felt]core.ClassDefinition{*classHash: simpleClass} require.NoError(t, state.Update(0, &firstStateUpdate, declaredClass, false)) + require.NoError(t, batch.Write()) logLevel := utils.NewLogLevel(utils.ERROR) log, err := utils.NewZapLogger(logLevel, false) @@ -274,6 +276,10 @@ func TestCallCairo(t *testing.T) { } require.NoError(t, state.Update(1, &secondStateUpdate, nil, false)) + if statetestutils.UseNewState() { + require.NoError(t, batch.Write()) + } + ret, err = vm.Call( &callInfo, &BlockInfo{Header: &core.Header{Number: 1}}, From 39e3e989f292a15d8289d8ff8cd06b7793c4b883 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 17 Nov 2025 17:49:51 +0100 Subject: [PATCH 45/47] lint --- blockchain/blockchain.go | 1 - builder/builder.go | 1 - core/pending.go | 5 ++++- core/state/statefactory/state_factory.go | 5 ++++- rpc/v6/helpers.go | 1 - rpc/v9/storage.go | 8 ++++---- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index f04c070b52..b7319f7fd2 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -8,7 +8,6 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" "github.com/NethermindEth/juno/core/state/statefactory" - "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" diff --git a/builder/builder.go b/builder/builder.go index 7f7f799fcc..720fef1bd6 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/utils" "github.com/consensys/gnark-crypto/ecc/stark-curve/ecdsa" diff --git a/core/pending.go b/core/pending.go index 86a92a140d..8f0c620edf 100644 --- a/core/pending.go +++ b/core/pending.go @@ -136,7 +136,10 @@ func (p *Pending) ReceiptByHash( return nil, nil, 0, ErrTransactionReceiptNotFound } -func (p *Pending) PendingStateBeforeIndex(baseState CommonStateReader, index uint) (CommonStateReader, error) { +func (p *Pending) PendingStateBeforeIndex( + baseState CommonStateReader, + index uint, +) (CommonStateReader, error) { return nil, ErrPendingStateBeforeIndexNotSupported } diff --git a/core/state/statefactory/state_factory.go b/core/state/statefactory/state_factory.go index 9193bdb1ee..f9a9b8add2 100644 --- a/core/state/statefactory/state_factory.go +++ b/core/state/statefactory/state_factory.go @@ -31,7 +31,10 @@ func NewStateFactory( }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (core.CommonState, error) { +func (sf *StateFactory) NewState( + stateRoot *felt.Felt, + txn db.IndexedBatch, +) (core.CommonState, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) return deprecatedState, nil diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index edc2342759..90452aed5a 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 7aecb5c2d7..4b5ca61b8d 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -219,7 +219,7 @@ func getClassProof(tr core.CommonTrie, classes []felt.Felt) ([]*HashToNode, erro case *trie.Trie: classProof := trie.NewProofNodeSet() for _, class := range classes { - if err := (*trie.Trie)(t).Prove(&class, classProof); err != nil { + if err := t.Prove(&class, classProof); err != nil { return nil, err } } @@ -227,7 +227,7 @@ func getClassProof(tr core.CommonTrie, classes []felt.Felt) ([]*HashToNode, erro case *trie2.Trie: classProof := trie2.NewProofNodeSet() for _, class := range classes { - if err := (*trie2.Trie)(t).Prove(&class, classProof); err != nil { + if err := t.Prove(&class, classProof); err != nil { return nil, err } } @@ -246,9 +246,9 @@ func getContractProof( // care about which trie implementation is being used and the output format should be the same switch t := tr.(type) { case *trie.Trie: - return getContractProofWithDeprecatedTrie((*trie.Trie)(t), state, contracts) + return getContractProofWithDeprecatedTrie(t, state, contracts) case *trie2.Trie: - return getContractProofWithTrie((*trie2.Trie)(t), state, contracts) + return getContractProofWithTrie(t, state, contracts) default: return nil, fmt.Errorf("unknown trie type: %T", tr) } From 5f2a1da865e68983e3a637f319a3379f67375e3b Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 18 Nov 2025 18:10:48 +0100 Subject: [PATCH 46/47] cleanup the code to reduce noise --- consensus/consensus_test.go | 1 - core/state/object.go | 1 - core/state/state.go | 6 +++--- core/trie2/trie.go | 1 - core/trie2/triedb/pathdb/disklayer.go | 6 ++---- core/trie2/triedb/pathdb/journal.go | 8 ++------ db/buckets.go | 2 +- rpc/v10/helpers.go | 2 +- rpc/v6/estimate_fee_test.go | 1 - rpc/v6/helpers.go | 2 +- rpc/v6/pending_data_wrapper.go | 1 - rpc/v6/trace.go | 1 - rpc/v7/helpers.go | 3 +-- rpc/v7/pending_data_wrapper.go | 1 - rpc/v7/trace.go | 1 - rpc/v8/helpers.go | 3 +-- rpc/v8/pending_data_wrapper.go | 1 - rpc/v8/storage.go | 2 +- rpc/v8/trace.go | 1 - rpc/v9/class_test.go | 12 ++++++------ rpc/v9/helpers.go | 1 - rpc/v9/pending_data_wrapper.go | 1 - rpc/v9/state_update_test.go | 8 ++++---- rpc/v9/storage.go | 1 - sync/sync.go | 1 - vm/vm.go | 1 - vm/vm_test.go | 1 - 27 files changed, 23 insertions(+), 47 deletions(-) diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index 2550b29adb..3d4dd22398 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -61,7 +61,6 @@ func getBlockchain( t.Helper() testDB := memory.New() network := &utils.Mainnet - bc := blockchain.New(testDB, network, statetestutils.UseNewState()) require.NoError(t, bc.StoreGenesis(&genesisDiff, genesisClasses)) return bc diff --git a/core/state/object.go b/core/state/object.go index 2afc6b9c1e..1cfb2b5ddd 100644 --- a/core/state/object.go +++ b/core/state/object.go @@ -49,7 +49,6 @@ func (s *stateObject) getStorage(key *felt.Felt) (felt.Felt, error) { return felt.Zero, err } - // TODO(maksym): test if this works instead of reading from disk path := tr.FeltToPath(key) reader, err := s.state.db.triedb.NodeReader( trieutils.NewContractStorageTrieID(s.state.initRoot, s.addr), diff --git a/core/state/state.go b/core/state/state.go index 7791203880..31ea690d2f 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -73,7 +73,7 @@ func New(stateRoot *felt.Felt, db *StateDB) (*State, error) { if err != nil { return nil, err } - fmt.Println("contractTrie", contractTrie, err) + classTrie, err := db.ClassTrie(stateRoot) if err != nil { return nil, err @@ -409,7 +409,6 @@ func (s *State) commit() (felt.Felt, stateUpdate, error) { for i, addr := range keys { obj := s.stateObjects[addr] - idx := i p.Go(func() error { // Object is marked as delete if obj == nil { @@ -425,7 +424,7 @@ func (s *State) commit() (felt.Felt, stateUpdate, error) { return err } - comms[idx] = obj.commitment() + comms[i] = obj.commitment() return nil }) } @@ -586,6 +585,7 @@ func (s *State) verifyComm(comm *felt.Felt) error { if err != nil { return err } + if !curComm.Equal(comm) { return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", comm, &curComm) } diff --git a/core/trie2/trie.go b/core/trie2/trie.go index cc58adf368..aa67aa2624 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -76,7 +76,6 @@ func New( } root, err := tr.resolveNode(nil, Path{}) - fmt.Println("root", root, err) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return nil, err } diff --git a/core/trie2/triedb/pathdb/disklayer.go b/core/trie2/triedb/pathdb/disklayer.go index dcfb30b847..029dfc366b 100644 --- a/core/trie2/triedb/pathdb/disklayer.go +++ b/core/trie2/triedb/pathdb/disklayer.go @@ -86,11 +86,9 @@ func (dl *diskLayer) node(id trieutils.TrieID, owner *felt.Felt, path *trieutils if err != nil { return nil, err } - blobCopy := make([]byte, len(blob)) - copy(blobCopy, blob) - dl.cleans.putNode(owner, path, isClass, blobCopy) - return blobCopy, nil + dl.cleans.putNode(owner, path, isClass, blob) + return blob, nil } func (dl *diskLayer) update(root *felt.Felt, id, block uint64, nodes *nodeSet) diffLayer { diff --git a/core/trie2/triedb/pathdb/journal.go b/core/trie2/triedb/pathdb/journal.go index 3a721b69c2..99d83d21bd 100644 --- a/core/trie2/triedb/pathdb/journal.go +++ b/core/trie2/triedb/pathdb/journal.go @@ -272,7 +272,7 @@ func (d *Database) loadLayers(enc []byte) (layer, error) { } func (d *Database) getStateRoot() felt.Felt { - encContractRootRaw, err := trieutils.GetNodeByPath( + encContractRoot, err := trieutils.GetNodeByPath( d.disk, db.ContractTrieContract, &felt.Zero, @@ -282,10 +282,8 @@ func (d *Database) getStateRoot() felt.Felt { if err != nil { return felt.Zero } - encContractRoot := make([]byte, len(encContractRootRaw)) - copy(encContractRoot, encContractRootRaw) - encStorageRootRaw, err := trieutils.GetNodeByPath( + encStorageRoot, err := trieutils.GetNodeByPath( d.disk, db.ClassTrie, @@ -296,8 +294,6 @@ func (d *Database) getStateRoot() felt.Felt { if err != nil { return felt.Zero } - encStorageRoot := make([]byte, len(encStorageRootRaw)) - copy(encStorageRoot, encStorageRootRaw) contractRootNode, err := trienode.DecodeNode(encContractRoot, &felt.Zero, 0, contractClassTrieHeight) if err != nil { diff --git a/db/buckets.go b/db/buckets.go index 51cf37088c..2bce176510 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -40,7 +40,7 @@ const ( MempoolNode ClassTrie // ClassTrie + nodetype + path + pathlength -> Trie Node ContractTrieContract // ContractTrieContract + nodetype + path + pathlength -> Trie Node - ContractTrieStorage // ContractTrieStorage + nodetype + owner + path + pathlength -> Trie Node + ContractTrieStorage // ContractTrieStorage + owner + nodetype + owner + path + pathlength -> Trie Node Contract // Contract + ContractAddr -> Contract StateHashToTrieRoots // StateHash -> ClassRootHash + ContractRootHash StateID // StateID + root hash -> state id diff --git a/rpc/v10/helpers.go b/rpc/v10/helpers.go index df08b6ce13..5c1b1a5007 100644 --- a/rpc/v10/helpers.go +++ b/rpc/v10/helpers.go @@ -17,7 +17,7 @@ func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // empty L1Head is returned if l1 head doesn't exist + // empty l1Head is returned if l1 head doesn't exist return l1Head, nil } diff --git a/rpc/v6/estimate_fee_test.go b/rpc/v6/estimate_fee_test.go index 70c14ed170..0d496dfdc7 100644 --- a/rpc/v6/estimate_fee_test.go +++ b/rpc/v6/estimate_fee_test.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v6/helpers.go b/rpc/v6/helpers.go index 90452aed5a..0269dea13d 100644 --- a/rpc/v6/helpers.go +++ b/rpc/v6/helpers.go @@ -19,7 +19,7 @@ func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // empty l1Head is returned if l1 head doesn't exist + // empty L1Head is returned if l1 head doesn't exist return l1Head, nil } diff --git a/rpc/v6/pending_data_wrapper.go b/rpc/v6/pending_data_wrapper.go index b6bf86418a..196a950899 100644 --- a/rpc/v6/pending_data_wrapper.go +++ b/rpc/v6/pending_data_wrapper.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/sync/pendingdata" ) diff --git a/rpc/v6/trace.go b/rpc/v6/trace.go index ca975514b3..4fcff3b313 100644 --- a/rpc/v6/trace.go +++ b/rpc/v6/trace.go @@ -12,7 +12,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/jsonrpc" rpccore "github.com/NethermindEth/juno/rpc/rpccore" "github.com/NethermindEth/juno/utils" diff --git a/rpc/v7/helpers.go b/rpc/v7/helpers.go index 9992885311..cf07f14318 100644 --- a/rpc/v7/helpers.go +++ b/rpc/v7/helpers.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -23,7 +22,7 @@ func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { if errors.Is(err, db.ErrKeyNotFound) { return core.L1Head{}, nil } - // empty l1Head is returned if l1 head doesn't exist + // empty L1Head is returned if l1 head doesn't exist return l1Head, nil } diff --git a/rpc/v7/pending_data_wrapper.go b/rpc/v7/pending_data_wrapper.go index a5221ebae6..eda21537ce 100644 --- a/rpc/v7/pending_data_wrapper.go +++ b/rpc/v7/pending_data_wrapper.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/sync/pendingdata" ) diff --git a/rpc/v7/trace.go b/rpc/v7/trace.go index 0c04b82b5f..b755da13d8 100644 --- a/rpc/v7/trace.go +++ b/rpc/v7/trace.go @@ -13,7 +13,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v8/helpers.go b/rpc/v8/helpers.go index 48e3f4ee75..742badf1b3 100644 --- a/rpc/v8/helpers.go +++ b/rpc/v8/helpers.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -20,7 +19,7 @@ func (h *Handler) l1Head() (core.L1Head, *jsonrpc.Error) { if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return core.L1Head{}, jsonrpc.Err(jsonrpc.InternalError, err.Error()) } - // empty l1Head is returned if l1 head doesn't exist + // empty L1Head is returned if l1 head doesn't exist return l1Head, nil } diff --git a/rpc/v8/pending_data_wrapper.go b/rpc/v8/pending_data_wrapper.go index 18e661b793..4be358c6dd 100644 --- a/rpc/v8/pending_data_wrapper.go +++ b/rpc/v8/pending_data_wrapper.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/sync/pendingdata" ) diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 228767fe87..7ec38314c2 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" @@ -498,6 +497,7 @@ type ContractProof struct { Nodes []*HashToNode `json:"nodes"` LeavesData []*LeafData `json:"contract_leaves_data"` } + type GlobalRoots struct { ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` diff --git a/rpc/v8/trace.go b/rpc/v8/trace.go index df092cbfae..0d683637f2 100644 --- a/rpc/v8/trace.go +++ b/rpc/v8/trace.go @@ -11,7 +11,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v9/class_test.go b/rpc/v9/class_test.go index 0e63e3fad6..7ed4b775bb 100644 --- a/rpc/v9/class_test.go +++ b/rpc/v9/class_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" rpcv6 "github.com/NethermindEth/juno/rpc/v6" - rpc "github.com/NethermindEth/juno/rpc/v9" + rpcv9 "github.com/NethermindEth/juno/rpc/v9" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" @@ -40,7 +40,7 @@ func TestClass(t *testing.T) { return nil }, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil).AnyTimes() - handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) latest := blockIDLatest(t) @@ -75,7 +75,7 @@ func TestClass(t *testing.T) { t.Run("state by id error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) - handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) @@ -87,7 +87,7 @@ func TestClass(t *testing.T) { t.Run("class hash not found error", func(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockState := mocks.NewMockStateReader(mockCtrl) - handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) mockReader.EXPECT().HeadState().Return(mockState, func() error { return nil @@ -134,7 +134,7 @@ func TestClassAt(t *testing.T) { return nil }, nil).AnyTimes() mockReader.EXPECT().HeadsHeader().Return(new(core.Header), nil).AnyTimes() - handler := rpc.New(mockReader, nil, nil, utils.NewNopZapLogger()) + handler := rpcv9.New(mockReader, nil, nil, utils.NewNopZapLogger()) latest := blockIDLatest(t) @@ -167,7 +167,7 @@ func TestClassHashAt(t *testing.T) { mockReader := mocks.NewMockReader(mockCtrl) mockSyncReader := mocks.NewMockSyncReader(mockCtrl) log := utils.NewNopZapLogger() - handler := rpc.New(mockReader, mockSyncReader, nil, log) + handler := rpcv9.New(mockReader, mockSyncReader, nil, log) targetAddress := felt.FromUint64[felt.Felt](1234) t.Run("empty blockchain", func(t *testing.T) { diff --git a/rpc/v9/helpers.go b/rpc/v9/helpers.go index 6f381010df..e5ae90a7e1 100644 --- a/rpc/v9/helpers.go +++ b/rpc/v9/helpers.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" diff --git a/rpc/v9/pending_data_wrapper.go b/rpc/v9/pending_data_wrapper.go index 3de26e4617..8b401582f5 100644 --- a/rpc/v9/pending_data_wrapper.go +++ b/rpc/v9/pending_data_wrapper.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/NethermindEth/juno/core" - "github.com/NethermindEth/juno/sync/pendingdata" ) diff --git a/rpc/v9/state_update_test.go b/rpc/v9/state_update_test.go index b7e9cda800..88b04acd9a 100644 --- a/rpc/v9/state_update_test.go +++ b/rpc/v9/state_update_test.go @@ -12,7 +12,7 @@ import ( "github.com/NethermindEth/juno/mocks" rpccore "github.com/NethermindEth/juno/rpc/rpccore" rpcv6 "github.com/NethermindEth/juno/rpc/v6" - rpc "github.com/NethermindEth/juno/rpc/v9" + rpcv9 "github.com/NethermindEth/juno/rpc/v9" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" @@ -21,7 +21,7 @@ import ( ) func TestStateUpdate(t *testing.T) { - errTests := map[string]rpc.BlockID{ + errTests := map[string]rpcv9.BlockID{ "latest": blockIDLatest(t), "pre_confirmed": blockIDPreConfirmed(t), "hash": blockIDHash(t, &felt.One), @@ -41,7 +41,7 @@ func TestStateUpdate(t *testing.T) { mockSyncReader.EXPECT().PendingData().Return(nil, core.ErrPendingDataNotFound) } log := utils.NewNopZapLogger() - handler := rpc.New(chain, mockSyncReader, nil, log) + handler := rpcv9.New(chain, mockSyncReader, nil, log) update, rpcErr := handler.StateUpdate(&id) assert.Empty(t, update) @@ -51,7 +51,7 @@ func TestStateUpdate(t *testing.T) { log := utils.NewNopZapLogger() mockReader := mocks.NewMockReader(mockCtrl) - handler := rpc.New(mockReader, mockSyncReader, nil, log) + handler := rpcv9.New(mockReader, mockSyncReader, nil, log) client := feeder.NewTestClient(t, n) mainnetGw := adaptfeeder.New(client) diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 4b5ca61b8d..fdbb5e04f5 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -7,7 +7,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/core/trie2" "github.com/NethermindEth/juno/core/trie2/trienode" diff --git a/sync/sync.go b/sync/sync.go index c6f701bc2b..c3956df096 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -11,7 +11,6 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/feed" junoplugin "github.com/NethermindEth/juno/plugin" diff --git a/vm/vm.go b/vm/vm.go index 32aff3d9f2..bb31df9775 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -20,7 +20,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/utils" ) diff --git a/vm/vm_test.go b/vm/vm_test.go index aa112dea25..0b04e86b30 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -8,7 +8,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/state" - "github.com/NethermindEth/juno/core/state/statefactory" statetestutils "github.com/NethermindEth/juno/core/state/statetestutils" "github.com/NethermindEth/juno/core/trie2/triedb" From f87bd2e7df914903feedc7a630c3f6f0a224c0c2 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 19 Nov 2025 13:27:26 +0100 Subject: [PATCH 47/47] make the state `flush` concurrent, atomicity will be handled in another PR --- core/state/state.go | 83 +++++++++++++++-------------- core/trie2/triedb/database.go | 3 +- core/trie2/triedb/rawdb/database.go | 4 +- 3 files changed, 46 insertions(+), 44 deletions(-) diff --git a/core/state/state.go b/core/state/state.go index 31ea690d2f..71a1dff47e 100644 --- a/core/state/state.go +++ b/core/state/state.go @@ -502,61 +502,64 @@ func (s *State) flush( classes map[felt.Felt]core.ClassDefinition, storeHistory bool, ) error { + p := pool.New().WithMaxGoroutines(runtime.GOMAXPROCS(0)).WithErrors() + + p.Go(func() error { + return s.db.triedb.Update(&update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes) + }) + batch := s.db.disk.NewBatch() - if err := s.db.triedb.Update( - &update.curComm, - &update.prevComm, - blockNum, - update.classNodes, - update.contractNodes, - batch, - ); err != nil { - return err - } + p.Go(func() error { + for addr, obj := range s.stateObjects { + if obj == nil { // marked as deleted + if err := DeleteContract(batch, &addr); err != nil { + return err + } - for addr, obj := range s.stateObjects { - if obj == nil { // marked as deleted - if err := DeleteContract(batch, &addr); err != nil { - return err - } + // TODO(weiihann): handle hash-based, and there should be better ways of doing this + if err := trieutils.DeleteStorageNodesByPath(batch, addr); err != nil { + return err + } + } else { // updated + if err := WriteContract(batch, &addr, obj.contract); err != nil { + return err + } - // TODO(weiihann): handle hash-based, and there should be better ways of doing this - if err := trieutils.DeleteStorageNodesByPath(batch, addr); err != nil { - return err - } - } else { // updated - if err := WriteContract(batch, &addr, obj.contract); err != nil { - return err - } + if storeHistory { + for key, val := range obj.dirtyStorage { + if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil { + return err + } + } - if storeHistory { - for key, val := range obj.dirtyStorage { - if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil { + if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil { + return err + } + + if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { return err } } + } + } - if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil { + for classHash, class := range classes { + if class == nil { // mark as deleted + if err := DeleteClass(batch, &classHash); err != nil { return err } - - if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil { + } else { + if err := WriteClass(batch, &classHash, class, blockNum); err != nil { return err } } } - } - for classHash, class := range classes { - if class == nil { // mark as deleted - if err := DeleteClass(batch, &classHash); err != nil { - return err - } - } else { - if err := WriteClass(batch, &classHash, class, blockNum); err != nil { - return err - } - } + return nil + }) + + if err := p.Wait(); err != nil { + return err } return batch.Write() diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go index 1bec9289f5..414037a8d0 100644 --- a/core/trie2/triedb/database.go +++ b/core/trie2/triedb/database.go @@ -59,7 +59,6 @@ func (d *Database) Update( blockNum uint64, mergeClassNodes, mergeContractNodes *trienode.MergeNodeSet, - batch db.KeyValueWriter, ) error { switch td := d.triedb.(type) { case *pathdb.Database: @@ -67,7 +66,7 @@ func (d *Database) Update( case *hashdb.Database: return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes) case *rawdb.Database: - return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes, batch) + return td.Update(root, parent, blockNum, mergeClassNodes, mergeContractNodes) default: return fmt.Errorf("unsupported trie db type: %T", td) } diff --git a/core/trie2/triedb/rawdb/database.go b/core/trie2/triedb/rawdb/database.go index 6ace8dce01..7a8faadb26 100644 --- a/core/trie2/triedb/rawdb/database.go +++ b/core/trie2/triedb/rawdb/database.go @@ -63,11 +63,11 @@ func (d *Database) Update( blockNum uint64, mergedClassNodes *trienode.MergeNodeSet, mergedContractNodes *trienode.MergeNodeSet, - batch db.KeyValueWriter, ) error { d.lock.Lock() defer d.lock.Unlock() + batch := d.disk.NewBatch() var classNodes classNodesMap var contractNodes contractNodesMap var contractStorageNodes contractStorageNodesMap @@ -99,7 +99,7 @@ func (d *Database) Update( } } - return nil + return batch.Write() } func (d *Database) updateNode(