diff --git a/.github/workflows/juno-test.yml b/.github/workflows/juno-test.yml index 1dac659341..4d047b5c19 100644 --- a/.github/workflows/juno-test.yml +++ b/.github/workflows/juno-test.yml @@ -67,4 +67,4 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true - files: coverage.out + files: coverage/coverage.old.out,coverage/coverage.new.out diff --git a/Makefile b/Makefile index 9d9000782b..d41a9dc26c 100644 --- a/Makefile +++ b/Makefile @@ -73,6 +73,9 @@ clean-testcache: ## Clean Go test cache test: clean-testcache rustdeps ## Run tests go test $(GO_TAGS) ./... +test-new-state: clean-testcache rustdeps ## Run tests with new state + JUNO_NEW_STATE=true go test $(GO_TAGS) ./... + test-cached: rustdeps ## Run cached tests go test $(GO_TAGS) ./... @@ -82,10 +85,11 @@ test-race: clean-testcache rustdeps ## Run tests with race detection benchmarks: rustdeps ## Run benchmarks go test $(GO_TAGS) ./... -run=^# -bench=. -benchmem -test-cover: clean-testcache rustdeps ## Run tests with coverage +test-cover: clean-testcache rustdeps ## Run tests with coverage in both old- and new-state modes mkdir -p coverage - go test $(GO_TAGS) -coverpkg=$(PKG) -coverprofile=coverage/coverage.out -covermode=atomic $(PKG) - go tool cover -html=coverage/coverage.out -o coverage/coverage.html + go test $(GO_TAGS) -coverpkg=$(PKG) -coverprofile=coverage/coverage.old.out -covermode=atomic $(PKG) + JUNO_NEW_STATE=true go test $(GO_TAGS) -coverpkg=$(PKG) -coverprofile=coverage/coverage.new.out -covermode=atomic $(PKG) + go tool cover -html=coverage/coverage.old.out -o coverage/coverage.html install-deps: install-gofumpt install-mockgen install-golangci-lint check-rust ## Install dependencies diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go index e165d178b4..840128658f 100644 --- a/blockchain/blockchain_test.go +++ b/blockchain/blockchain_test.go @@ -11,6 +11,9 @@ import ( "github.com/NethermindEth/juno/core/deprecatedstate" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + "github.com/NethermindEth/juno/core/state" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -25,7 +28,11 @@ func TestNew(t *testing.T) { client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) t.Run("empty blockchain's head is nil", func(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) assert.Equal(t, &networks.Mainnet, chain.Network()) b, err := chain.Head() assert.Nil(t, b) @@ -39,10 +46,18 @@ func TestNew(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &networks.Mainnet) + chain = blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) b, err := chain.Head() require.NoError(t, err) assert.Equal(t, block0, b) @@ -53,7 +68,11 @@ func TestHeight(t *testing.T) { client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) t.Run("return nil if blockchain is empty", func(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Sepolia) + chain := blockchain.New( + memory.New(), + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) _, err := chain.Height() assert.Error(t, err) }) @@ -65,10 +84,18 @@ func TestHeight(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &networks.Mainnet) + chain = blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) height, err := chain.Height() require.NoError(t, err) assert.Equal(t, block0.Number, height) @@ -76,7 +103,11 @@ func TestHeight(t *testing.T) { } func TestBlockByNumberAndHash(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Sepolia) + chain := blockchain.New( + memory.New(), + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) t.Run("same block is returned for both core.GetBlockByNumber and GetBlockByHash", func(t *testing.T) { client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -110,7 +141,11 @@ func TestBlockByNumberAndHash(t *testing.T) { func TestSanityCheckNewHeight(t *testing.T) { h1 := felt.NewRandom[felt.Felt]() - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) @@ -157,7 +192,11 @@ func TestStore(t *testing.T) { t.Run("add block to empty blockchain", func(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) require.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) headBlock, err := chain.Head() @@ -185,7 +224,11 @@ func TestStore(t *testing.T) { stateUpdate1, err := gw.StateUpdate(t.Context(), 1) require.NoError(t, err) - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) require.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) require.NoError(t, chain.Store(block1, &emptyCommitments, stateUpdate1, nil)) @@ -209,7 +252,11 @@ func TestStore(t *testing.T) { func TestStoreL1HandlerTxnHash(t *testing.T) { client := feeder.NewTestClient(t, &networks.Sepolia) gw := adaptfeeder.New(client) - chain := blockchain.New(memory.New(), &networks.Sepolia) + chain := blockchain.New( + memory.New(), + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) var stateUpdate *core.StateUpdate for i := range uint64(7) { block, err := gw.BlockByNumber(t.Context(), i) @@ -228,7 +275,11 @@ func TestStoreL1HandlerTxnHash(t *testing.T) { } func TestBlockCommitments(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -253,7 +304,11 @@ func TestBlockCommitments(t *testing.T) { } func TestTransactionAndReceipt(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -391,7 +446,11 @@ func TestTransactionAndReceipt(t *testing.T) { func TestState(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -460,7 +519,11 @@ func TestEvents(t *testing.T) { } testDB := memory.New() - chain := blockchain.New(testDB, &networks.Goerli2) + chain := blockchain.New( + testDB, + &networks.Goerli2, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Goerli2) gw := adaptfeeder.New(client) @@ -668,7 +731,11 @@ func TestEvents(t *testing.T) { func TestRevert(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -763,7 +830,11 @@ func TestRevertHeadMigratedCasmClasses(t *testing.T) { } testDB := memory.New() - chain := blockchain.New(testDB, &networks.Integration) + chain := blockchain.New( + testDB, + &networks.Integration, + blockchain.WithNewState(statetestutils.UseNewState()), + ) receipts0 := make([]*core.TransactionReceipt, 0) //nolint:dupl // Similar to block0 in `TestRevertHeadDeclaredV2CasmClasses` @@ -816,6 +887,7 @@ func TestRevertHeadMigratedCasmClasses(t *testing.T) { } stateUpdate1 := &core.StateUpdate{ + OldRoot: stateUpdate0.NewRoot, StateDiff: &core.StateDiff{ MigratedClasses: map[felt.SierraClassHash]felt.CasmClassHash{ sierraHash: v2CasmHash, @@ -875,8 +947,11 @@ func TestRevertHeadDeclaredV2CasmClasses(t *testing.T) { } testDB := memory.New() - chain := blockchain.New(testDB, network) - + chain := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) // Block 0: empty genesis block so we have a head after reverting block 1 receipts0 := make([]*core.TransactionReceipt, 0) //nolint:dupl // Similar to block0 in `TestRevertHeadMigratedClass` @@ -924,6 +999,7 @@ func TestRevertHeadDeclaredV2CasmClasses(t *testing.T) { } stateUpdate1 := &core.StateUpdate{ + OldRoot: stateUpdate0.NewRoot, StateDiff: &core.StateDiff{ DeclaredV1Classes: map[felt.Felt]*felt.Felt{ sierraClassHashFelt: (*felt.Felt)(&v2CasmHash), @@ -968,7 +1044,11 @@ func TestL1Update(t *testing.T) { for _, head := range heads { t.Run(fmt.Sprintf("update L1 head to block %d", head.BlockNumber), func(t *testing.T) { - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) require.NoError(t, chain.SetL1Head(head)) got, err := chain.L1Head() require.NoError(t, err) @@ -983,7 +1063,11 @@ func TestSubscribeL1Head(t *testing.T) { StateRoot: new(felt.Felt).SetUint64(2), } - chain := blockchain.New(memory.New(), &networks.Mainnet) + chain := blockchain.New( + memory.New(), + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) sub := chain.SubscribeL1Head() t.Cleanup(sub.Unsubscribe) @@ -1000,9 +1084,19 @@ func chainStateCommitment(t *testing.T, database db.KeyValueStore) felt.Felt { header, err := core.GetBlockHeaderByNumber(database, height) require.NoError(t, err) - //nolint:staticcheck,nolintlint // used by old state - txn := database.NewIndexedBatch() - commitment, err := deprecatedstate.New(txn).Commitment(header.ProtocolVersion) + if !statetestutils.UseNewState() { + //nolint:staticcheck,nolintlint // used by old state + txn := database.NewIndexedBatch() + commitment, err := deprecatedstate.New(txn).Commitment(header.ProtocolVersion) + require.NoError(t, err) + return commitment + } + + trieDB := triedb.New(database, nil) + stateDB := state.NewStateDB(database, trieDB) + state, err := state.NewStateReader(header.GlobalStateRoot, stateDB) + require.NoError(t, err) + commitment, err := state.Commitment(header.ProtocolVersion) require.NoError(t, err) return commitment } diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index b28834d1ca..56419a8817 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -48,13 +48,12 @@ func DBCmd(defaultDBPath string) *cobra.Command { } func DBInfoCmd() *cobra.Command { - cmd := &cobra.Command{ + return &cobra.Command{ Use: "info", Short: "Retrieve database information", Long: `This subcommand retrieves and displays blockchain information stored in the database.`, RunE: dbInfo, } - return cmd } func DBSizeCmd() *cobra.Command { @@ -74,6 +73,8 @@ func DBRevertCmd() *cobra.Command { RunE: dbRevert, } cmd.Flags().Uint64(dbRevertToBlockF, 0, "New head (this block won't be reverted)") + cmd.Flags().Bool(newStateF, defaultNewState, newStateUsage) + return cmd } @@ -116,7 +117,7 @@ func dbInfo(cmd *cobra.Command, args []string) error { database, ) if err != nil { - return fmt.Errorf("failed to get deprecatedschema metadata: %v", err) + return fmt.Errorf("failed to get deprecated schema metadata: %v", err) } schemaVersion = deprecatedMetadata.Version } else { @@ -164,6 +165,11 @@ func dbRevert(cmd *cobra.Command, args []string) error { return fmt.Errorf("--%v cannot be 0", dbRevertToBlockF) } + newState, err := cmd.Flags().GetBool(newStateF) + if err != nil { + return err + } + database, err := openDB(dbPath) if err != nil { return err @@ -171,7 +177,11 @@ func dbRevert(cmd *cobra.Command, args []string) error { defer database.Close() for { - chain := blockchain.New(database, nil) + chain := blockchain.New( + database, + nil, + blockchain.WithNewState(newState), + ) head, err := chain.Head() if err != nil { return fmt.Errorf("failed to get the latest block information: %v", err) diff --git a/cmd/juno/dbcmd_test.go b/cmd/juno/dbcmd_test.go index b34cba5514..7bc51e1c3e 100644 --- a/cmd/juno/dbcmd_test.go +++ b/cmd/juno/dbcmd_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" juno "github.com/NethermindEth/juno/cmd/juno" "github.com/NethermindEth/juno/core" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/pebblev2" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/spf13/cobra" @@ -44,6 +45,9 @@ func TestDBCmd(t *testing.T) { require.NoError(t, cmd.Flags().Set("db-path", dbPath)) require.NoError(t, cmd.Flags().Set("to-block", strconv.Itoa(int(revertToBlock)))) + if statetestutils.UseNewState() { + require.NoError(t, cmd.Flags().Set("new-state", "true")) + } require.NoError(t, cmd.Execute()) // unfortunately we cannot use blockchain from prepareDB because @@ -55,7 +59,11 @@ func TestDBCmd(t *testing.T) { require.NoError(t, db.Close()) }) - chain := blockchain.New(db, &network) + chain := blockchain.New( + db, + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) block, err := chain.Head() require.NoError(t, err) assert.Equal(t, revertToBlock, block.Number) @@ -79,8 +87,11 @@ func prepareDB(t *testing.T, network *networks.Network, syncToBlock uint64) stri testDB, err := pebblev2.New(dbPath) require.NoError(t, err) - chain := blockchain.New(testDB, network) - + chain := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) for blockNumber := uint64(0); blockNumber <= syncToBlock; blockNumber++ { block, err := gw.BlockByNumber(t.Context(), blockNumber) require.NoError(t, err) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index d667866a87..497141b876 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -157,6 +157,7 @@ const ( defaultHTTPUpdatePort = 0 defaultSubmittedTransactionsCacheSize = 10_000 defaultSubmittedTransactionsCacheEntryTTL = 5 * time.Minute + defaultNewState = false defaultDisableRPCBatchRequests = false defaultDBCompactionConcurrency = "" defaultDBMemtableSize = 256 @@ -165,6 +166,7 @@ const ( defaultRPCRequestTimeout = 1 * time.Minute defaultMaxConcurrentCompilations = 8 defaultDisableReceivedTxnStream = false + newStateF = "new-state" configFlagUsage = "The YAML configuration file." logLevelFlagUsage = "Options: trace, debug, info, warn, error." @@ -257,6 +259,7 @@ const ( "submitted through this node — these transactions are local to the node " + "and are not sourced from the network. When this flag is enabled, the " + "node will no longer notify subscribers about transactions submitted through it." + newStateUsage = "EXPERIMENTAL: Use the new state package implementation" ) var Version string @@ -475,6 +478,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr defaultSubmittedTransactionsCacheEntryTTL, submittedTransactionsCacheEntryTTL, ) + junoCmd.Flags().Bool(newStateF, defaultNewState, newStateUsage) junoCmd.Flags().Bool( disableRPCBatchRequestsF, defaultDisableRPCBatchRequests, disableRPCBatchRequestsUsage, ) diff --git a/consensus/consensus_test.go b/consensus/consensus_test.go index 68fbee86d2..82fd48048a 100644 --- a/consensus/consensus_test.go +++ b/consensus/consensus_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/p2p/dht" @@ -67,7 +68,11 @@ func getBlockchain( t.Helper() testDB := memory.New() network := &networks.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) require.NoError(t, bc.StoreGenesis(&genesisDiff, genesisClasses)) return bc } @@ -97,6 +102,7 @@ func loadGenesis( &network, vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) diff --git a/consensus/p2p/validator/empty_fixtures_test.go b/consensus/p2p/validator/empty_fixtures_test.go index 89e01c3887..b2d915d1ff 100644 --- a/consensus/p2p/validator/empty_fixtures_test.go +++ b/consensus/p2p/validator/empty_fixtures_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" "github.com/starknet-io/starknet-p2p-specs/p2p/proto/consensus/consensus" @@ -41,7 +42,11 @@ func NewEmptyTestFixture( executor.RegisterBuildResult(&buildResult) - b := builder.New(blockchain.New(database, testCase.Network), executor) + b := builder.New(blockchain.New( + database, + testCase.Network, + blockchain.WithNewState(statetestutils.UseNewState()), + ), executor) proposalCommitment := EmptyProposalCommitment(headBlock, proposer, timestamp) diff --git a/consensus/p2p/validator/fixtures_test.go b/consensus/p2p/validator/fixtures_test.go index 229d1adec8..d4ee32f239 100644 --- a/consensus/p2p/validator/fixtures_test.go +++ b/consensus/p2p/validator/fixtures_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/starknet" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -114,7 +115,14 @@ func BuildTestFixture( executor.RegisterBuildResult(&buildResult) - builder := builder.New(blockchain.New(database, testCase.Network), executor) + builder := builder.New( + blockchain.New( + database, + testCase.Network, + blockchain.WithNewState(statetestutils.UseNewState()), + ), + executor, + ) return TestFixture{ ProposalInit: &proposalInit, diff --git a/consensus/p2p/validator/proposal_stream_demux_test.go b/consensus/p2p/validator/proposal_stream_demux_test.go index 999b1ae211..458477521f 100644 --- a/consensus/p2p/validator/proposal_stream_demux_test.go +++ b/consensus/p2p/validator/proposal_stream_demux_test.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/consensus/starknet" "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/p2p/pubsub/testutils" @@ -53,7 +54,11 @@ func TestProposalStreamDemux(t *testing.T) { executor := NewMockExecutor(t, network) database := memory.New() - bc := blockchain.New(database, network) + bc := blockchain.New( + database, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) builder := builder.New(bc, executor) transition := NewTransition(&builder, nil) proposalStore := proposal.ProposalStore[starknet.Hash]{} diff --git a/consensus/p2p/validator/transition_test.go b/consensus/p2p/validator/transition_test.go index 72423f0d9d..04ba434414 100644 --- a/consensus/p2p/validator/transition_test.go +++ b/consensus/p2p/validator/transition_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -31,7 +32,11 @@ func getBuilder(t *testing.T, seqAddr *felt.Felt) (*builder.Builder, *core.Heade t.Helper() testDB := memory.New() network := &networks.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) logger := log.NewNopZapLogger() privKey, err := ecdsa.GenerateKey(rand.Reader) @@ -57,6 +62,7 @@ func getBuilder(t *testing.T, seqAddr *felt.Felt) (*builder.Builder, *core.Heade bc.Network(), vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) diff --git a/consensus/proposer/proposer_test.go b/consensus/proposer/proposer_test.go index 8ba468b5b4..1ab6278d28 100644 --- a/consensus/proposer/proposer_test.go +++ b/consensus/proposer/proposer_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/consensus/types" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -215,7 +216,11 @@ func getBlockchain(t *testing.T) *blockchain.Blockchain { t.Helper() testDB := memory.New() network := &networks.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) return bc } @@ -241,6 +246,7 @@ func getBuilder(t *testing.T, logger log.Logger, bc *blockchain.Blockchain) *bui bc.Network(), vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) diff --git a/core/running_event_filter_test.go b/core/running_event_filter_test.go index ad056c9d27..1f1413c667 100644 --- a/core/running_event_filter_test.go +++ b/core/running_event_filter_test.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/encoder" @@ -66,7 +67,11 @@ func TestRunningEventFilter_LazyInitialization_EmptyDB(t *testing.T) { func TestRunningEventFilter_LazyInitialization_Preload(t *testing.T) { testDB := memory.New() n := &networks.Sepolia - chain := blockchain.New(testDB, n) + chain := blockchain.New( + testDB, + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, n) gw := adaptfeeder.New(client) diff --git a/genesis/genesis.go b/genesis/genesis.go index e1921ab026..e478745568 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -13,6 +13,8 @@ import ( "github.com/NethermindEth/juno/core/deprecatedstate" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + st "github.com/NethermindEth/juno/core/state" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/db/memory" rpc "github.com/NethermindEth/juno/rpc/v8" "github.com/NethermindEth/juno/starknet" @@ -100,14 +102,29 @@ func GenesisStateDiff( network *networks.Network, maxSteps uint64, maxGas uint64, + useNewState bool, compiler compiler.Compiler, ) (core.StateDiff, map[felt.Felt]core.ClassDefinition, error) { initialStateDiff := core.EmptyStateDiff() memDB := memory.New() + + var state core.StateReader + var err error + if !useNewState { + state = deprecatedstate.New(memDB.NewIndexedBatch()) + } else { + triedb := triedb.New(memDB, nil) + stateDB := st.NewStateDB(memDB, triedb) + state, err = st.NewStateReader(&felt.Zero, stateDB) + if err != nil { + return core.StateDiff{}, nil, err + } + } + genesisState := pending.NewStateWriter( &initialStateDiff, make(map[felt.Felt]core.ClassDefinition, len(config.Classes)), - deprecatedstate.New(memDB.NewIndexedBatch()), + state, ) if err := declareClasses(ctx, config, &genesisState, compiler); err != nil { diff --git a/genesis/genesis_test.go b/genesis/genesis_test.go index 2e299cf497..99fa0ddb55 100644 --- a/genesis/genesis_test.go +++ b/genesis/genesis_test.go @@ -5,6 +5,7 @@ import ( "github.com/NethermindEth/juno/blockchain/networks" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/starknet/compiler" "github.com/NethermindEth/juno/utils/log" @@ -30,6 +31,7 @@ func TestGenesisStateDiff(t *testing.T) { network, vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), nil, ) require.NoError(t, err) @@ -54,6 +56,7 @@ func TestGenesisStateDiff(t *testing.T) { network, vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) diff --git a/l1/l1_pkg_test.go b/l1/l1_pkg_test.go index 72414d5c02..4f45ad65e1 100644 --- a/l1/l1_pkg_test.go +++ b/l1/l1_pkg_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/blockchain/networks" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1/contract" "github.com/NethermindEth/juno/mocks" @@ -337,7 +338,11 @@ func TestClient(t *testing.T) { ctrl := gomock.NewController(t) nopLog := log.NewNopZapLogger() network := networks.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New( + memory.New(), + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := NewClient(nil, chain, nopLog).WithResubscribeDelay(0).WithPollFinalisedInterval(time.Nanosecond) @@ -398,7 +403,11 @@ func TestUnreliableSubscription(t *testing.T) { ctrl := gomock.NewController(t) nopLog := log.NewNopZapLogger() network := networks.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New( + memory.New(), + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := NewClient(nil, chain, nopLog).WithResubscribeDelay(0).WithPollFinalisedInterval(time.Nanosecond) err := errors.New("test err") diff --git a/l1/l1_test.go b/l1/l1_test.go index 5d68586092..430c0557f0 100644 --- a/l1/l1_test.go +++ b/l1/l1_test.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/blockchain/networks" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/l1" "github.com/NethermindEth/juno/l1/contract" @@ -55,7 +56,11 @@ func TestFailToCreateSubscription(t *testing.T) { network := networks.Mainnet ctrl := gomock.NewController(t) nopLog := log.NewNopZapLogger() - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New( + memory.New(), + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) subscriber := mocks.NewMockSubscriber(ctrl) @@ -86,7 +91,11 @@ func TestMismatchedChainID(t *testing.T) { network := networks.Mainnet ctrl := gomock.NewController(t) nopLog := log.NewNopZapLogger() - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New( + memory.New(), + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) subscriber := mocks.NewMockSubscriber(ctrl) @@ -111,7 +120,11 @@ func TestEventListener(t *testing.T) { ctrl := gomock.NewController(t) nopLog := log.NewNopZapLogger() network := networks.Mainnet - chain := blockchain.New(memory.New(), &network) + chain := blockchain.New( + memory.New(), + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) subscriber := mocks.NewMockSubscriber(ctrl) subscriber. diff --git a/mempool/mempool.go b/mempool/mempool.go index 8bfe93bc47..ea06e0c440 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -294,7 +294,7 @@ func (p *SequencerMempool) validate(userTxn *BroadcastedTransaction) error { return fmt.Errorf("deploy transactions are not supported") case *core.DeployAccountTransaction: if !t.Nonce.IsZero() { - return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce) + return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce.String()) } case *core.DeclareTransaction: nonce, err := state.ContractNonce(t.SenderAddress) diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 97e1beb7c4..9e8ba20828 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebblev2" _ "github.com/NethermindEth/juno/encoder/registry" @@ -195,7 +196,11 @@ func TestWait(t *testing.T) { defer dbCloser() mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) - bc := blockchain.New(testDB, &networks.Sepolia) + bc := blockchain.New( + testDB, + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) block0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) stateUpdate0, err := gw.StateUpdate(t.Context(), 0) diff --git a/migration/deprecated/migration_pkg_test.go b/migration/deprecated/migration_pkg_test.go index a720380da1..208e082412 100644 --- a/migration/deprecated/migration_pkg_test.go +++ b/migration/deprecated/migration_pkg_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" @@ -89,7 +90,11 @@ func TestRelocateContractStorageRootKeys(t *testing.T) { func TestRecalculateBloomFilters(t *testing.T) { testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) diff --git a/mocks/mock_gateway_handler.go b/mocks/mock_gateway.go similarity index 94% rename from mocks/mock_gateway_handler.go rename to mocks/mock_gateway.go index aae1f2011a..a918086072 100644 --- a/mocks/mock_gateway_handler.go +++ b/mocks/mock_gateway.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/NethermindEth/juno/rpc (interfaces: Gateway) +// Source: github.com/NethermindEth/juno/rpc/rpccore (interfaces: Gateway) // // Generated by this command: // @@ -20,6 +20,7 @@ import ( type MockGateway struct { ctrl *gomock.Controller recorder *MockGatewayMockRecorder + isgomock struct{} } // MockGatewayMockRecorder is the mock recorder for MockGateway. diff --git a/node/genesis.go b/node/genesis.go index bcc47f3e7f..dba16e4994 100644 --- a/node/genesis.go +++ b/node/genesis.go @@ -20,6 +20,7 @@ func buildGenesis( v vm.VM, maxSteps uint64, maxGas uint64, + useNewState bool, compiler compiler.Compiler, ) error { if _, err := bc.Height(); !errors.Is(err, db.ErrKeyNotFound) { @@ -42,6 +43,7 @@ func buildGenesis( bc.Network(), maxSteps, maxGas, + useNewState, compiler, ) if err != nil { diff --git a/node/node.go b/node/node.go index fa81188285..629956b536 100644 --- a/node/node.go +++ b/node/node.go @@ -132,6 +132,7 @@ type Config struct { RPCRequestTimeout time.Duration `mapstructure:"rpc-request-timeout"` MaxConcurrentCompilations uint `mapstructure:"max-concurrent-compilations"` + NewState bool `mapstructure:"new-state"` } type Node struct { @@ -197,15 +198,19 @@ func New(cfg *Config, version string, logLevel *log.Level) (*Node, error) { if err != nil { return nil, fmt.Errorf("open DB: %w", err) } + ua := fmt.Sprintf("Juno/%s Starknet Client", version) services := make([]service.Service, 0) earlyServices := make([]service.Service, 0) - opts := make([]blockchain.Option, 0, 1) + opts := make([]blockchain.Option, 0, 2) if cfg.Metrics { opts = append(opts, blockchain.WithListener(makeBlockchainMetrics())) } + opts = append(opts, blockchain.WithNewState( + cfg.NewState, + )) chain := blockchain.New(database, &cfg.Network, opts...) // Verify that cfg.Network is compatible with the database. @@ -632,6 +637,7 @@ func (n *Node) Run(ctx context.Context) { vm.New(&chainInfo, false, n.logger), n.cfg.RPCCallMaxSteps, n.cfg.RPCCallMaxGas, + n.cfg.NewState, n.compiler, ) if err != nil { diff --git a/node/node_test.go b/node/node_test.go index fc9fd4b51d..6f9cb14f46 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/blockchain/networks" "github.com/NethermindEth/juno/clients/feeder" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/pebblev2" "github.com/NethermindEth/juno/node" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -76,7 +77,11 @@ func TestNetworkVerificationOnNonEmptyDB(t *testing.T) { logger := log.NewNopZapLogger() database, err := pebblev2.New(dbPath) require.NoError(t, err) - chain := blockchain.New(database, &network) + chain := blockchain.New( + database, + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) ctx, cancel := context.WithCancel(t.Context()) dataSource := sync.NewFeederGatewayDataSource(chain, adaptfeeder.New(feeder.NewTestClient(t, &network))) syncer := sync.New(chain, dataSource, logger, 0, 0, false, database). @@ -95,6 +100,7 @@ func TestNetworkVerificationOnNonEmptyDB(t *testing.T) { DatabasePath: dbPath, DBCompression: "zstd", Network: test.network, + NewState: statetestutils.UseNewState(), DisableL1Verification: true, SubmittedTransactionsCacheEntryTTL: time.Second, }, "v0.1", logLevel) diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index 33e918bd75..f794c1d05a 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/blockchain/networks" "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" junoplugin "github.com/NethermindEth/juno/plugin" @@ -40,7 +41,11 @@ func TestPlugin(t *testing.T) { require.NoError(t, err) plugin.EXPECT().NewBlock(block, su, gomock.Any()) } - bc := blockchain.New(testDB, &networks.Integration) + bc := blockchain.New( + testDB, + &networks.Integration, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, integGw) synchronizer := sync.New( bc, @@ -57,7 +62,11 @@ func TestPlugin(t *testing.T) { cancel() t.Run("resync to mainnet with the same db", func(t *testing.T) { - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) // Ensure current head is Integration head head, err := bc.HeadsHeader() diff --git a/rpc/v10/block_test.go b/rpc/v10/block_test.go index 3cc93e3fc8..94bd1de7af 100644 --- a/rpc/v10/block_test.go +++ b/rpc/v10/block_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -626,7 +627,11 @@ func TestBlockWithTxHashes_ErrorCases(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) mockSyncReader := mocks.NewMockSyncReader(mockCtrl) handler := rpc.New(chain, mockSyncReader, nil, logger) @@ -755,7 +760,11 @@ func TestBlockWithTxs_ErrorCases(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) mockSyncReader := mocks.NewMockSyncReader(mockCtrl) handler := rpc.New(chain, mockSyncReader, nil, logger) @@ -902,7 +911,11 @@ func TestBlockWithReceipts_ErrorCases(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) mockSyncReader := mocks.NewMockSyncReader(mockCtrl) handler := rpc.New(chain, mockSyncReader, nil, logger) diff --git a/rpc/v10/events_test.go b/rpc/v10/events_test.go index bae6f8aa91..91bc027642 100644 --- a/rpc/v10/events_test.go +++ b/rpc/v10/events_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -178,7 +179,11 @@ func setupTestChain( ) (*blockchain.Blockchain, *adaptfeeder.Feeder) { t.Helper() testDB := memory.New() - chain := blockchain.New(testDB, network) + chain := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, network) gw := adaptfeeder.New(client) diff --git a/rpc/v10/state_update_test.go b/rpc/v10/state_update_test.go index add4500bd2..18fbe607da 100644 --- a/rpc/v10/state_update_test.go +++ b/rpc/v10/state_update_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -36,7 +37,11 @@ func TestStateUpdate_ErrorCases(t *testing.T) { n := &networks.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PreConfirmed().Return(nil, db.ErrKeyNotFound) diff --git a/rpc/v10/storage.go b/rpc/v10/storage.go index 287760c4c4..e6956ef0e9 100644 --- a/rpc/v10/storage.go +++ b/rpc/v10/storage.go @@ -8,6 +8,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -303,21 +305,26 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp } func getClassProof(tr core.Trie, classes []felt.Felt) ([]*HashToNode, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", tr) - } - - classProof := trie.NewProofNodeSet() - for _, class := range classes { - if err := t.Prove(&class, classProof); err != nil { - return nil, err + switch t := tr.(type) { + case *trie.Trie: + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } } + return adaptDeprecatedTrieProofNodes(classProof), nil + case *trie2.Trie: + classProof := trie2.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } + } + return adaptTrieProofNodes(classProof) + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) } - - return adaptProofNodes(classProof), nil } func getContractProof( @@ -325,23 +332,77 @@ func getContractProof( state core.StateReader, contracts []felt.Felt, ) (*ContractProof, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { + switch t := tr.(type) { + case *trie.Trie: + return getContractProofWithDeprecatedTrie(t, state, contracts) + case *trie2.Trie: + return getContractProofWithTrie(t, state, contracts) + default: return nil, fmt.Errorf("unknown trie type: %T", tr) } +} +func getContractProofWithDeprecatedTrie( + tr *trie.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { contractProof := trie.NewProofNodeSet() - contractLeavesData := make([]*LeafData, len(contracts)) - for i, contract := range contracts { - if err := t.Prove(&contract, contractProof); err != nil { + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + + return &ContractProof{ + Nodes: adaptDeprecatedTrieProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractProofWithTrie( + tr *trie2.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { + contractProof := trie2.NewProofNodeSet() + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { return nil, err } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + + nodes, err := adaptTrieProofNodes(contractProof) + if err != nil { + return nil, err + } + + return &ContractProof{ + Nodes: nodes, + LeavesData: contractLeavesData, + }, nil +} +func buildContractLeavesData( + state core.StateReader, + contracts []felt.Felt, +) ([]*LeafData, error) { + contractLeavesData := make([]*LeafData, len(contracts)) + for i, contract := range contracts { classHash, err := state.ContractClassHash(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, db.ErrKeyNotFound) { continue } return nil, err @@ -368,11 +429,7 @@ func getContractProof( StorageRoot: &storageRoot, } } - - return &ContractProof{ - Nodes: adaptProofNodes(contractProof), - LeavesData: contractLeavesData, - }, nil + return contractLeavesData, nil } func getContractStorageProof( @@ -386,24 +443,36 @@ func getContractStorageProof( return nil, err } - t, ok := contractStorageTrie.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", t) - } - contractStorageProof := trie.NewProofNodeSet() - for _, key := range storageKey.Keys { - if err := t.Prove(&key, contractStorageProof); err != nil { + switch t := contractStorageTrie.(type) { + case *trie.Trie: + contractStorageProof := trie.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) + case *trie2.Trie: + contractStorageProof := trie2.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + nodes, err := adaptTrieProofNodes(contractStorageProof) + if err != nil { return nil, err } + contractStorageRes[i] = nodes + default: + return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } - - contractStorageRes[i] = adaptProofNodes(contractStorageProof) } return contractStorageRes, nil } -func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptDeprecatedTrieProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -433,6 +502,60 @@ func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } +func adaptTrieProofNodes(proof *trie2.ProofNodeSet) ([]*HashToNode, error) { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trienode.BinaryNode: + leftChild, err := nodeFelt(n.Children[0]) + if err != nil { + return nil, err + } + rightChild, err := nodeFelt(n.Children[1]) + if err != nil { + return nil, err + } + node = &BinaryNode{ + Left: &leftChild, + Right: &rightChild, + } + case *trienode.EdgeNode: + pathFelt := n.Path.Felt() + child, err := nodeFelt(n.Child) + if err != nil { + return nil, err + } + + node = &EdgeNode{ + Path: pathFelt.String(), + Length: int(n.Path.Len()), + Child: &child, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes, nil +} + +func nodeFelt(n trienode.Node) (felt.Felt, error) { + switch n := n.(type) { + case *trienode.HashNode: + return felt.Felt(*n), nil + case *trienode.ValueNode: + return felt.Felt(*n), nil + default: + return felt.Felt{}, fmt.Errorf("unknown node type: %T", n) + } +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v10/storage_test.go b/rpc/v10/storage_test.go index 9bad7f2178..142e1c4679 100644 --- a/rpc/v10/storage_test.go +++ b/rpc/v10/storage_test.go @@ -15,7 +15,11 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -565,11 +569,63 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - tempTrie := emptyTrie(t) - _, _ = tempTrie.Put(key, value) - _, _ = tempTrie.Put(key2, value2) - _ = tempTrie.Commit() - trieRoot, _ := tempTrie.Hash() + var classTrie, contractTrie core.Trie + var trieRoot felt.Felt + + if !statetestutils.UseNewState() { + tempTrie := emptyDeprecatedTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ = tempTrie.Hash() + classTrie = tempTrie + contractTrie = tempTrie + } else { + newComm := felt.FromUint64[felt.StateRootHash](1) + createTrie := func( + t *testing.T, + id trieutils.TrieID, + trieDB *trie2.TestNodeDatabase, + ) *trie2.Trie { + tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) + _ = tr.Update(key, value) + _ = tr.Update(key2, value2) + require.NoError(t, err) + _, nodes := tr.Commit() + err = trieDB.Update((*felt.Felt)(&newComm), &felt.Zero, trienode.NewMergeNodeSet(nodes)) + require.NoError(t, err) + return tr + } + + trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) + createTrie(t, trieutils.NewClassTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + contractTrie2 := createTrie(t, trieutils.NewContractTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + tmpTrieRoot, err := contractTrie2.Hash() + require.NoError(t, err) + trieRoot = tmpTrieRoot + + // recreate because the previous ones are committed + classTrie2, err := trie2.New( + trieutils.NewClassTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + contractTrie2, err = trie2.New( + trieutils.NewContractTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + classTrie = classTrie2 + contractTrie = contractTrie2 + } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -579,8 +635,8 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() - mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() - mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(classTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(contractTrie, nil).AnyTimes() logger := log.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, logger) @@ -733,7 +789,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -743,7 +799,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -759,8 +815,8 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -773,7 +829,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -783,7 +839,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil).Times(1) classHash := felt.NewFromUint64[felt.Felt](1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil).Times(1) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil).Times(1) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -795,7 +851,7 @@ func TestStorageProof(t *testing.T) { require.Equal(t, nonce, ld.Nonce) require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run( "contract leaf StorageRoot is the contract storage trie root, not the global contracts trie root", @@ -805,7 +861,7 @@ func TestStorageProof(t *testing.T) { // Build a separate storage trie with different contents so its root differs from // the contracts trie root. - contractStorageTrie := emptyTrie(t) + contractStorageTrie := emptyDeprecatedTrie(t) storageKey := felt.NewFromUint64[felt.Felt](99) storageVal := felt.NewFromUint64[felt.Felt](999) _, _ = contractStorageTrie.Put(storageKey, storageVal) @@ -854,7 +910,7 @@ func TestStorageProof(t *testing.T) { Return(headBlock.Header, nil) contract := felt.NewFromUint64[felt.Felt](0xabcd) - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -863,7 +919,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl // Similar code, but testing a different case t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { @@ -871,7 +927,7 @@ func TestStorageProof(t *testing.T) { Return(headBlock.Header, nil) contract := felt.NewFromUint64[felt.Felt](0xabcd) - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -880,7 +936,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -890,7 +946,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil) classHash := felt.NewFromUint64[felt.Felt](1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil) proof, rpcErr := handler.StorageProof( &blockLatest, @@ -1291,7 +1347,11 @@ func TestStorageProof_StorageRoots(t *testing.T) { logger := log.NewNopZapLogger() testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, logger, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), time.Second) @@ -1512,7 +1572,7 @@ func verifyIf( require.Equal(t, leaf, *value) } -func emptyTrie(t *testing.T) *trie.Trie { +func emptyDeprecatedTrie(t *testing.T) *trie.Trie { memdb := memory.New() txn := memdb.NewIndexedBatch() @@ -1521,6 +1581,15 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } +func emptyTrie(t *testing.T) core.Trie { + if statetestutils.UseNewState() { + tempTrie, err := trie2.NewEmptyPedersen() + require.NoError(t, err) + return tempTrie + } + return emptyDeprecatedTrie(t) +} + func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { stateVersion := felt.NewFromBytes[felt.Felt]([]byte(`STARKNET_STATE_V0`)) if classRoot.IsZero() { diff --git a/rpc/v10/trace_test.go b/rpc/v10/trace_test.go index f58c2829f8..db955954de 100644 --- a/rpc/v10/trace_test.go +++ b/rpc/v10/trace_test.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -658,7 +659,11 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpcv10.New(chain, nil, nil, logger) if description == "pre_confirmed" { diff --git a/rpc/v8/block_test.go b/rpc/v8/block_test.go index 39e315b10f..f2036e582b 100644 --- a/rpc/v8/block_test.go +++ b/rpc/v8/block_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -263,7 +264,11 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpc.New(chain, mockSyncReader, nil, logger) @@ -454,7 +459,11 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpc.New(chain, mockSyncReader, nil, logger) diff --git a/rpc/v8/events_test.go b/rpc/v8/events_test.go index af935333e9..a06a87fa4b 100644 --- a/rpc/v8/events_test.go +++ b/rpc/v8/events_test.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" "github.com/NethermindEth/juno/rpc/rpccore" @@ -21,7 +22,11 @@ import ( func TestEvents(t *testing.T) { testDB := memory.New() n := &networks.Sepolia - chain := blockchain.New(testDB, n) + chain := blockchain.New( + testDB, + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) diff --git a/rpc/v8/state_update_test.go b/rpc/v8/state_update_test.go index 17eeab0e39..7449e8f5ac 100644 --- a/rpc/v8/state_update_test.go +++ b/rpc/v8/state_update_test.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" "github.com/NethermindEth/juno/rpc/rpccore" @@ -33,7 +34,11 @@ func TestStateUpdate(t *testing.T) { n := &networks.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpc.New(chain, mockSyncReader, nil, nil) diff --git a/rpc/v8/storage.go b/rpc/v8/storage.go index 55d47e4fb8..771105d989 100644 --- a/rpc/v8/storage.go +++ b/rpc/v8/storage.go @@ -7,6 +7,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -208,21 +210,26 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp } func getClassProof(tr core.Trie, classes []felt.Felt) ([]*HashToNode, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", tr) - } - - classProof := trie.NewProofNodeSet() - for _, class := range classes { - if err := t.Prove(&class, classProof); err != nil { - return nil, err + switch t := tr.(type) { + case *trie.Trie: + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } } + return adaptDeprecatedTrieProofNodes(classProof), nil + case *trie2.Trie: + classProof := trie2.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } + } + return adaptTrieProofNodes(classProof) + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) } - - return adaptProofNodes(classProof), nil } func getContractProof( @@ -230,23 +237,27 @@ func getContractProof( state core.StateReader, contracts []felt.Felt, ) (*ContractProof, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { + switch t := tr.(type) { + case *trie.Trie: + return getContractProofWithDeprecatedTrie(t, state, contracts) + case *trie2.Trie: + return getContractProofWithTrie(t, state, contracts) + default: return nil, fmt.Errorf("unknown trie type: %T", tr) } +} - contractProof := trie.NewProofNodeSet() +func buildContractLeavesData( + state core.StateReader, + contracts []felt.Felt, +) ([]*LeafData, error) { contractLeavesData := make([]*LeafData, len(contracts)) - for i, contract := range contracts { - if err := t.Prove(&contract, contractProof); err != nil { - return nil, err - } + for i, contract := range contracts { classHash, err := state.ContractClassHash(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, db.ErrKeyNotFound) { continue } return nil, err @@ -274,8 +285,58 @@ func getContractProof( } } + return contractLeavesData, nil +} + +func getContractProofWithDeprecatedTrie( + tr *trie.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { + contractProof := trie.NewProofNodeSet() + + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + + return &ContractProof{ + Nodes: adaptDeprecatedTrieProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractProofWithTrie( + tr *trie2.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { + contractProof := trie2.NewProofNodeSet() + + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + + nodes, err := adaptTrieProofNodes(contractProof) + if err != nil { + return nil, err + } + return &ContractProof{ - Nodes: adaptProofNodes(contractProof), + Nodes: nodes, LeavesData: contractLeavesData, }, nil } @@ -291,24 +352,38 @@ func getContractStorageProof( return nil, err } - t, ok := contractStorageTrie.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", t) - } - contractStorageProof := trie.NewProofNodeSet() - for _, key := range storageKey.Keys { - if err := t.Prove(&key, contractStorageProof); err != nil { + switch t := contractStorageTrie.(type) { + case *trie.Trie: + contractStorageProof := trie.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) + case *trie2.Trie: + contractStorageProof := trie2.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + + nodes, err := adaptTrieProofNodes(contractStorageProof) + if err != nil { return nil, err } - } - contractStorageRes[i] = adaptProofNodes(contractStorageProof) + contractStorageRes[i] = nodes + default: + return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) + } } return contractStorageRes, nil } -func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptDeprecatedTrieProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -338,6 +413,59 @@ func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } +func adaptTrieProofNodes(proof *trie2.ProofNodeSet) ([]*HashToNode, error) { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trienode.BinaryNode: + leftChild, err := nodeFelt(n.Children[0]) + if err != nil { + return nil, err + } + rightChild, err := nodeFelt(n.Children[1]) + if err != nil { + return nil, err + } + node = &BinaryNode{ + Left: &leftChild, + Right: &rightChild, + } + case *trienode.EdgeNode: + pathFelt := n.Path.Felt() + child, err := nodeFelt(n.Child) + if err != nil { + return nil, err + } + node = &EdgeNode{ + Path: pathFelt.String(), + Length: int(n.Path.Len()), + Child: &child, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes, nil +} + +func nodeFelt(n trienode.Node) (felt.Felt, error) { + switch n := n.(type) { + case *trienode.HashNode: + return felt.Felt(*n), nil + case *trienode.ValueNode: + return felt.Felt(*n), nil + default: + return felt.Zero, fmt.Errorf("unknown node type: %T", n) + } +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v8/storage_test.go b/rpc/v8/storage_test.go index 7c0de0275c..2fb8b254ce 100644 --- a/rpc/v8/storage_test.go +++ b/rpc/v8/storage_test.go @@ -12,7 +12,11 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -213,11 +217,62 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - tempTrie := emptyTrie(t) - _, _ = tempTrie.Put(key, value) - _, _ = tempTrie.Put(key2, value2) - _ = tempTrie.Commit() - trieRoot, _ := tempTrie.Hash() + var classTrie, contractTrie core.Trie + trieRoot := felt.Zero + + if !statetestutils.UseNewState() { + tempTrie := emptyDeprecatedTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ = tempTrie.Root() + classTrie = tempTrie + contractTrie = tempTrie + } else { + newComm := felt.FromUint64[felt.StateRootHash](1) + createTrie := func( + t *testing.T, + id trieutils.TrieID, + trieDB *trie2.TestNodeDatabase, + ) *trie2.Trie { + tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) + _ = tr.Update(key, value) + _ = tr.Update(key2, value2) + require.NoError(t, err) + _, nodes := tr.Commit() + err = trieDB.Update((*felt.Felt)(&newComm), &felt.Zero, trienode.NewMergeNodeSet(nodes)) + require.NoError(t, err) + return tr + } + + trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) + createTrie(t, trieutils.NewClassTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + contractTrie2 := createTrie(t, trieutils.NewContractTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + tmpTrieRoot, _ := contractTrie2.Hash() + trieRoot = tmpTrieRoot + + // recreate because the previous ones are committed + classTrie2, err := trie2.New( + trieutils.NewClassTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + contractTrie2, err = trie2.New( + trieutils.NewContractTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + classTrie = classTrie2 + contractTrie = contractTrie2 + } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -227,8 +282,8 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() - mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() - mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(classTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(contractTrie, nil).AnyTimes() logger := log.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, logger) @@ -365,7 +420,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -374,7 +429,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -389,14 +444,13 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). Return(headBlock.Header, nil) mockState.EXPECT().ContractClassHash(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(1) - mockState.EXPECT().ContractNonce(noSuchKey).Return(felt.Zero, db.ErrKeyNotFound).Times(0) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*noSuchKey}, nil) require.Nil(t, rpcErr) @@ -404,7 +458,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -413,7 +467,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil).Times(1) classHash := new(felt.Felt).SetUint64(1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil).Times(1) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil).Times(1) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -425,14 +479,14 @@ func TestStorageProof(t *testing.T) { require.Equal(t, nonce, ld.Nonce) require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("contract leaf StorageRoot is the contract storage trie root, not the global contracts trie root", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). Return(headBlock.Header, nil) // Build a separate storage trie with different contents so its root differs from the contracts trie root. - contractStorageTrie := emptyTrie(t) + contractStorageTrie := emptyDeprecatedTrie(t) storageKey := felt.NewFromUint64[felt.Felt](99) storageVal := felt.NewFromUint64[felt.Felt](999) _, _ = contractStorageTrie.Put(storageKey, storageVal) @@ -479,7 +533,7 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). Return(headBlock.Header, nil) contract := felt.NewUnsafeFromString[felt.Felt]("0xabcd") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -488,14 +542,14 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). Return(headBlock.Header, nil) contract := felt.NewUnsafeFromString[felt.Felt]("0xadd0") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -504,7 +558,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -513,7 +567,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil) classHash := new(felt.Felt).SetUint64(1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil) proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -747,7 +801,11 @@ func TestStorageProof_StorageRoots(t *testing.T) { logger := log.NewNopZapLogger() testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, logger, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), time.Second) @@ -926,7 +984,7 @@ func verifyIf( require.Equal(t, leaf, *value) } -func emptyTrie(t *testing.T) *trie.Trie { +func emptyDeprecatedTrie(t *testing.T) *trie.Trie { memdb := memory.New() txn := memdb.NewIndexedBatch() @@ -935,6 +993,15 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } +func emptyTrie(t *testing.T) core.Trie { + if statetestutils.UseNewState() { + tempTrie, err := trie2.NewEmptyPedersen() + require.NoError(t, err) + return tempTrie + } + return emptyDeprecatedTrie(t) +} + func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) if classRoot.IsZero() { diff --git a/rpc/v8/subscriptions_test.go b/rpc/v8/subscriptions_test.go index 329039f485..3cd93639e7 100644 --- a/rpc/v8/subscriptions_test.go +++ b/rpc/v8/subscriptions_test.go @@ -17,6 +17,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" pendingpkg "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -549,10 +550,18 @@ func TestSubscribeNewHeadsHistorical(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &networks.Mainnet) + chain = blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) syncer := newFakeSyncer() ctx, cancel := context.WithCancel(t.Context()) diff --git a/rpc/v8/trace_test.go b/rpc/v8/trace_test.go index d434107978..3989e94545 100644 --- a/rpc/v8/trace_test.go +++ b/rpc/v8/trace_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -563,7 +564,11 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpc.New(chain, nil, nil, logger) update, httpHeader, rpcErr := handler.TraceBlockTransactions(t.Context(), &blockID) diff --git a/rpc/v9/block_test.go b/rpc/v9/block_test.go index d98a5d8467..0f32d9e299 100644 --- a/rpc/v9/block_test.go +++ b/rpc/v9/block_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -283,7 +284,11 @@ func TestBlockWithTxHashes(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) @@ -481,7 +486,11 @@ func TestBlockWithTxs(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) diff --git a/rpc/v9/events_test.go b/rpc/v9/events_test.go index b0d5e5eff4..3672ab5c56 100644 --- a/rpc/v9/events_test.go +++ b/rpc/v9/events_test.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -171,7 +172,11 @@ func setupTestChain( ) (*blockchain.Blockchain, *adaptfeeder.Feeder) { t.Helper() testDB := memory.New() - chain := blockchain.New(testDB, network) + chain := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, network) gw := adaptfeeder.New(client) diff --git a/rpc/v9/state_update_test.go b/rpc/v9/state_update_test.go index a24bcdfa1c..72e190c01e 100644 --- a/rpc/v9/state_update_test.go +++ b/rpc/v9/state_update_test.go @@ -9,6 +9,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -36,7 +37,11 @@ func TestStateUpdate(t *testing.T) { n := &networks.Mainnet for description, id := range errTests { t.Run(description, func(t *testing.T) { - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) if description == "pre_confirmed" { mockSyncReader = mocks.NewMockSyncReader(mockCtrl) mockSyncReader.EXPECT().PreConfirmed().Return(nil, db.ErrKeyNotFound) diff --git a/rpc/v9/storage.go b/rpc/v9/storage.go index 0778130d03..23a959111b 100644 --- a/rpc/v9/storage.go +++ b/rpc/v9/storage.go @@ -7,6 +7,8 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/jsonrpc" "github.com/NethermindEth/juno/rpc/rpccore" @@ -209,21 +211,26 @@ func (h *Handler) isBlockSupported(blockID *BlockID, chainHeight uint64) *jsonrp } func getClassProof(tr core.Trie, classes []felt.Felt) ([]*HashToNode, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", tr) - } - - classProof := trie.NewProofNodeSet() - for _, class := range classes { - if err := t.Prove(&class, classProof); err != nil { - return nil, err + switch t := tr.(type) { + case *trie.Trie: + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } } + return adaptDeprecatedTrieProofNodes(classProof), nil + case *trie2.Trie: + classProof := trie2.NewProofNodeSet() + for _, class := range classes { + if err := t.Prove(&class, classProof); err != nil { + return nil, err + } + } + return adaptTrieProofNodes(classProof) + default: + return nil, fmt.Errorf("unknown trie type: %T", tr) } - - return adaptProofNodes(classProof), nil } func getContractProof( @@ -231,23 +238,27 @@ func getContractProof( state core.StateReader, contracts []felt.Felt, ) (*ContractProof, error) { - // TODO(maksym): remove after trie2 integration. RPC packages shouldn't - // care about which trie implementation is being used and the output format should be the same - t, ok := tr.(*trie.Trie) - if !ok { + switch t := tr.(type) { + case *trie.Trie: + return getContractProofWithDeprecatedTrie(t, state, contracts) + case *trie2.Trie: + return getContractProofWithTrie(t, state, contracts) + default: return nil, fmt.Errorf("unknown trie type: %T", tr) } +} - contractProof := trie.NewProofNodeSet() +func buildContractLeavesData( + state core.StateReader, + contracts []felt.Felt, +) ([]*LeafData, error) { contractLeavesData := make([]*LeafData, len(contracts)) - for i, contract := range contracts { - if err := t.Prove(&contract, contractProof); err != nil { - return nil, err - } + for i, contract := range contracts { classHash, err := state.ContractClassHash(&contract) if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + // contract does not exist, skip getting leaf data + if errors.Is(err, db.ErrKeyNotFound) { continue } return nil, err @@ -275,8 +286,58 @@ func getContractProof( } } + return contractLeavesData, nil +} + +func getContractProofWithDeprecatedTrie( + tr *trie.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { + contractProof := trie.NewProofNodeSet() + + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + return &ContractProof{ - Nodes: adaptProofNodes(contractProof), + Nodes: adaptDeprecatedTrieProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractProofWithTrie( + tr *trie2.Trie, + state core.StateReader, + contracts []felt.Felt, +) (*ContractProof, error) { + contractProof := trie2.NewProofNodeSet() + + for _, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + } + + contractLeavesData, err := buildContractLeavesData(state, contracts) + if err != nil { + return nil, err + } + + nodes, err := adaptTrieProofNodes(contractProof) + if err != nil { + return nil, err + } + + return &ContractProof{ + Nodes: nodes, LeavesData: contractLeavesData, }, nil } @@ -292,24 +353,36 @@ func getContractStorageProof( return nil, err } - t, ok := contractStorageTrie.(*trie.Trie) - if !ok { - return nil, fmt.Errorf("unknown trie type: %T", t) - } - contractStorageProof := trie.NewProofNodeSet() - for _, key := range storageKey.Keys { - if err := t.Prove(&key, contractStorageProof); err != nil { + switch t := contractStorageTrie.(type) { + case *trie.Trie: + contractStorageProof := trie.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + contractStorageRes[i] = adaptDeprecatedTrieProofNodes(contractStorageProof) + case *trie2.Trie: + contractStorageProof := trie2.NewProofNodeSet() + for _, key := range storageKey.Keys { + if err := t.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + nodes, err := adaptTrieProofNodes(contractStorageProof) + if err != nil { return nil, err } + contractStorageRes[i] = nodes + default: + return nil, fmt.Errorf("unknown trie type: %T", contractStorageTrie) } - - contractStorageRes[i] = adaptProofNodes(contractStorageProof) } return contractStorageRes, nil } -func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { +func adaptDeprecatedTrieProofNodes(proof *trie.ProofNodeSet) []*HashToNode { nodes := make([]*HashToNode, proof.Size()) nodeList := proof.List() for i, hash := range proof.Keys() { @@ -339,6 +412,58 @@ func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { return nodes } +func adaptTrieProofNodes(proof *trie2.ProofNodeSet) ([]*HashToNode, error) { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trienode.BinaryNode: + leftChild, err := nodeFelt(n.Children[0]) + if err != nil { + return nil, err + } + rightChild, err := nodeFelt(n.Children[1]) + if err != nil { + return nil, err + } + node = &BinaryNode{ + Left: &leftChild, + Right: &rightChild, + } + case *trienode.EdgeNode: + pathFelt := n.Path.Felt() + child, err := nodeFelt(n.Child) + if err != nil { + return nil, err + } + node = &EdgeNode{ + Path: pathFelt.String(), + Length: int(n.Path.Len()), + Child: &child, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes, nil +} + +func nodeFelt(n trienode.Node) (felt.Felt, error) { + switch n := n.(type) { + case *trienode.HashNode: + return felt.Felt(*n), nil + case *trienode.ValueNode: + return felt.Felt(*n), nil + } + return felt.Zero, fmt.Errorf("unknown node type: %T", n) +} + type StorageKeys struct { Contract *felt.Felt `json:"contract_address"` Keys []felt.Felt `json:"storage_keys"` diff --git a/rpc/v9/storage_test.go b/rpc/v9/storage_test.go index 5d12438c13..0d27c81ff7 100644 --- a/rpc/v9/storage_test.go +++ b/rpc/v9/storage_test.go @@ -13,7 +13,11 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/core/trie2" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/jsonrpc" @@ -239,11 +243,63 @@ func TestStorageProof(t *testing.T) { blockNumber = uint64(1313) ) - tempTrie := emptyTrie(t) - _, _ = tempTrie.Put(key, value) - _, _ = tempTrie.Put(key2, value2) - _ = tempTrie.Commit() - trieRoot, _ := tempTrie.Hash() + var classTrie, contractTrie core.Trie + trieRoot := felt.Zero + + if !statetestutils.UseNewState() { + tempTrie := emptyDeprecatedTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ = tempTrie.Root() + classTrie = tempTrie + contractTrie = tempTrie + } else { + newComm := felt.FromUint64[felt.StateRootHash](1) + createTrie := func( + t *testing.T, + id trieutils.TrieID, + trieDB *trie2.TestNodeDatabase, + ) *trie2.Trie { + tr, err := trie2.New(id, 251, crypto.Pedersen, trieDB) + _ = tr.Update(key, value) + _ = tr.Update(key2, value2) + require.NoError(t, err) + _, nodes := tr.Commit() + err = trieDB.Update((*felt.Felt)(&newComm), &felt.Zero, trienode.NewMergeNodeSet(nodes)) + require.NoError(t, err) + return tr + } + + trieDB := trie2.NewTestNodeDatabase(memory.New(), trie2.PathScheme) + createTrie(t, trieutils.NewClassTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + contractTrie2 := createTrie(t, trieutils.NewContractTrieID( + felt.FromUint64[felt.StateRootHash](0), + ), &trieDB) + tmpTrieRoot, err := contractTrie2.Hash() + require.NoError(t, err) + trieRoot = tmpTrieRoot + + // recreate because the previous ones are committed + classTrie2, err := trie2.New( + trieutils.NewClassTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + contractTrie2, err = trie2.New( + trieutils.NewContractTrieID(newComm), + 251, + crypto.Pedersen, + &trieDB, + ) + require.NoError(t, err) + classTrie = classTrie2 + contractTrie = contractTrie2 + } headBlock := &core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}} @@ -253,8 +309,8 @@ func TestStorageProof(t *testing.T) { mockReader.EXPECT().Head().Return(headBlock, nil).AnyTimes() mockReader.EXPECT().BlockByNumber(blockNumber).Return(headBlock, nil).AnyTimes() mockReader.EXPECT().Height().Return(blockNumber, nil).AnyTimes() - mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() - mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(classTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(contractTrie, nil).AnyTimes() logger := log.NewNopZapLogger() handler := rpc.New(mockReader, nil, nil, logger) @@ -392,7 +448,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ClassesProof, classTrie.HashFn()) }) t.Run("class trie hash exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -402,7 +458,7 @@ func TestStorageProof(t *testing.T) { require.Nil(t, rpcErr) require.NotNil(t, proof) arityTest(t, proof, 3, 0, 0, 0) - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) }) t.Run("only unique proof nodes are returned", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -418,8 +474,8 @@ func TestStorageProof(t *testing.T) { require.Len(t, rootNodes, 1) // verify we can still prove any of the keys in query - verifyIf(t, &trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) - verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ClassesProof, classTrie.HashFn()) + verifyIf(t, &trieRoot, key2, value2, proof.ClassesProof, classTrie.HashFn()) }) t.Run("storage trie address does not exist in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -432,7 +488,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 3, 1, 0) require.Nil(t, proof.ContractsProof.LeavesData[0]) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("storage trie address exists in a trie", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -442,7 +498,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil).Times(1) classHash := felt.NewFromUint64[felt.Felt](1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil).Times(1) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil).Times(1) proof, rpcErr := handler.StorageProof(&blockLatest, nil, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -454,14 +510,14 @@ func TestStorageProof(t *testing.T) { require.Equal(t, nonce, ld.Nonce) require.Equal(t, classHash, ld.ClassHash) - verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsProof.Nodes, classTrie.HashFn()) }) t.Run("contract leaf StorageRoot is the contract storage trie root, not the global contracts trie root", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). Return(headBlock.Header, nil) // Build a separate storage trie with different contents so its root differs from the contracts trie root. - contractStorageTrie := emptyTrie(t) + contractStorageTrie := emptyDeprecatedTrie(t) storageKey := felt.NewFromUint64[felt.Felt](99) storageVal := felt.NewFromUint64[felt.Felt](999) _, _ = contractStorageTrie.Put(storageKey, storageVal) @@ -510,7 +566,7 @@ func TestStorageProof(t *testing.T) { Return(headBlock.Header, nil) contract := felt.NewFromUint64[felt.Felt](0xabcd) - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*noSuchKey}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -519,7 +575,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) //nolint:dupl t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { @@ -527,7 +583,7 @@ func TestStorageProof(t *testing.T) { Return(headBlock.Header, nil) contract := felt.NewUnsafeFromString[felt.Felt]("0xadd0") - mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + mockState.EXPECT().ContractStorageTrie(contract).Return(contractTrie, nil).Times(1) storageKeys := []rpc.StorageKeys{{Contract: contract, Keys: []felt.Felt{*key}}} proof, rpcErr := handler.StorageProof(&blockLatest, nil, nil, storageKeys) @@ -536,7 +592,7 @@ func TestStorageProof(t *testing.T) { arityTest(t, proof, 0, 0, 0, 1) require.Len(t, proof.ContractsStorageProofs[0], 3) - verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + verifyIf(t, &trieRoot, key, value, proof.ContractsStorageProofs[0], contractTrie.HashFn()) }) t.Run("class & storage tries proofs requested", func(t *testing.T) { mockReader.EXPECT().BlockHeaderByNumber(blockNumber). @@ -546,7 +602,7 @@ func TestStorageProof(t *testing.T) { mockState.EXPECT().ContractNonce(key).Return(*nonce, nil) classHash := felt.NewFromUint64[felt.Felt](1234) mockState.EXPECT().ContractClassHash(key).Return(*classHash, nil) - mockState.EXPECT().ContractStorageTrie(key).Return(tempTrie, nil) + mockState.EXPECT().ContractStorageTrie(key).Return(contractTrie, nil) proof, rpcErr := handler.StorageProof(&blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) require.Nil(t, rpcErr) @@ -780,7 +836,11 @@ func TestStorageProof_StorageRoots(t *testing.T) { logger := log.NewNopZapLogger() testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New(bc, dataSource, logger, time.Duration(0), time.Duration(0), false, testDB) ctx, cancel := context.WithTimeout(t.Context(), time.Second) @@ -959,7 +1019,7 @@ func verifyIf( require.Equal(t, leaf, *value) } -func emptyTrie(t *testing.T) *trie.Trie { +func emptyDeprecatedTrie(t *testing.T) *trie.Trie { memdb := memory.New() txn := memdb.NewIndexedBatch() @@ -968,6 +1028,15 @@ func emptyTrie(t *testing.T) *trie.Trie { return tempTrie } +func emptyTrie(t *testing.T) core.Trie { + if statetestutils.UseNewState() { + tempTrie, err := trie2.NewEmptyPedersen() + require.NoError(t, err) + return tempTrie + } + return emptyDeprecatedTrie(t) +} + func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { stateVersion := felt.NewFromBytes[felt.Felt]([]byte(`STARKNET_STATE_V0`)) if classRoot.IsZero() { diff --git a/rpc/v9/subscriptions_test.go b/rpc/v9/subscriptions_test.go index d55c612504..85ce54193e 100644 --- a/rpc/v9/subscriptions_test.go +++ b/rpc/v9/subscriptions_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/feed" @@ -1170,10 +1171,18 @@ func TestSubscribeNewHeadsHistorical(t *testing.T) { require.NoError(t, err) testDB := memory.New() - chain := blockchain.New(testDB, &networks.Mainnet) + chain := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) assert.NoError(t, chain.Store(block0, &emptyCommitments, stateUpdate0, nil)) - chain = blockchain.New(testDB, &networks.Mainnet) + chain = blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) syncer := newFakeSyncer() ctx, cancel := context.WithCancel(t.Context()) diff --git a/rpc/v9/trace_test.go b/rpc/v9/trace_test.go index 80c6a5d8bd..f0bcff05d9 100644 --- a/rpc/v9/trace_test.go +++ b/rpc/v9/trace_test.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" @@ -760,7 +761,11 @@ func TestTraceBlockTransactions(t *testing.T) { t.Run(description, func(t *testing.T) { logger := log.NewNopZapLogger() n := &networks.Mainnet - chain := blockchain.New(memory.New(), n) + chain := blockchain.New( + memory.New(), + n, + blockchain.WithNewState(statetestutils.UseNewState()), + ) handler := rpc.New(chain, nil, nil, logger) if description == "pre_confirmed" { diff --git a/sequencer/sequencer_test.go b/sequencer/sequencer_test.go index 4bbe3ffb65..667880c088 100644 --- a/sequencer/sequencer_test.go +++ b/sequencer/sequencer_test.go @@ -11,6 +11,7 @@ import ( "github.com/NethermindEth/juno/builder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/mempool" @@ -31,7 +32,11 @@ func getEmptySequencer(t *testing.T, blockTime time.Duration, seqAddr *felt.Felt mockCtrl := gomock.NewController(t) mockVM := mocks.NewMockVM(mockCtrl) network := &networks.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) emptyStateDiff := core.EmptyStateDiff() require.NoError(t, bc.StoreGenesis(&emptyStateDiff, nil)) privKey, err := ecdsa.GenerateKey(rand.Reader) @@ -105,7 +110,11 @@ func getGenesisSequencer( testDB := memory.New() network := &networks.Mainnet - bc := blockchain.New(testDB, network) + bc := blockchain.New( + testDB, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) logger := log.NewNopZapLogger() privKey, err := ecdsa.GenerateKey(rand.Reader) require.NoError(t, err) @@ -130,6 +139,7 @@ func getGenesisSequencer( bc.Network(), vm.DefaultMaxSteps, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) diff --git a/sync/pending_polling_test.go b/sync/pending_polling_test.go index 7e6fe81736..1b5b53ca47 100644 --- a/sync/pending_polling_test.go +++ b/sync/pending_polling_test.go @@ -14,6 +14,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" "github.com/NethermindEth/juno/utils/log" @@ -61,7 +62,11 @@ func (m *MockDataSource) PreConfirmedBlockByNumber( func TestPollPreLatest(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) dataSource := NewFeederGatewayDataSource(bc, gw) @@ -235,7 +240,11 @@ func TestPollPreLatest(t *testing.T) { func TestPollPreConfirmedLoop(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Sepolia) + bc := blockchain.New( + testDB, + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Sepolia) gw := adaptfeeder.New(client) dataSource := NewFeederGatewayDataSource(bc, gw) @@ -330,7 +339,11 @@ func TestPollPreConfirmedLoop(t *testing.T) { func TestPollPendingData(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Sepolia) + bc := blockchain.New( + testDB, + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Sepolia) gw := adaptfeeder.New(client) dataSource := NewFeederGatewayDataSource(bc, gw) @@ -410,7 +423,11 @@ func TestPollPendingData(t *testing.T) { func TestPollPendingDataPreConfirmedPolling(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Sepolia) + bc := blockchain.New( + testDB, + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) client := feeder.NewTestClient(t, &networks.Sepolia) gw := adaptfeeder.New(client) dataSource := NewFeederGatewayDataSource(bc, gw) @@ -471,7 +488,11 @@ func TestPollPendingDataPreConfirmedPolling(t *testing.T) { func TestStorePreConfirmed(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) logger := log.NewNopZapLogger() client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) diff --git a/sync/reorg_test.go b/sync/reorg_test.go index 643b555ffb..383ebf2b20 100644 --- a/sync/reorg_test.go +++ b/sync/reorg_test.go @@ -16,6 +16,7 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/pending" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/genesis" "github.com/NethermindEth/juno/starknet/compiler" @@ -98,7 +99,11 @@ func initGenesis(t *testing.T) (*memory.Database, sync.CommittedBlock) { t.Helper() database := memory.New() - bc := blockchain.New(database, network) + bc := blockchain.New( + database, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) genesisConfig, err := genesis.Read("../genesis/genesis_prefund_accounts.json") require.NoError(t, err) @@ -119,6 +124,7 @@ func initGenesis(t *testing.T) (*memory.Database, sync.CommittedBlock) { bc.Network(), vm.DefaultMaxGas, vm.DefaultMaxGas, + statetestutils.UseNewState(), compiler.NewUnsafe(), ) require.NoError(t, err) @@ -143,7 +149,11 @@ func newBlockGenerator( database *memory.Database, sequencer uint64, ) *blockGenerator { - bc := blockchain.New(database, network) + bc := blockchain.New( + database, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) builder := newTestBuilder(log.NewNopZapLogger(), bc) return &blockGenerator{ @@ -242,7 +252,11 @@ func setup( logger, err := log.NewZapLogger(log.NewLevel(logLevel)) require.NoError(t, err) - blockchain := blockchain.New(synchronizerDatabase, network) + blockchain := blockchain.New( + synchronizerDatabase, + network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) wg := gosync.WaitGroup{} ctx, cancel := context.WithCancel(t.Context()) diff --git a/sync/sync.go b/sync/sync.go index 655a069fec..a1b50d0567 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -300,7 +300,8 @@ func (s *Synchronizer) handlePluginRevertBlock() { err = s.plugin.RevertBlock( &junoplugin.BlockAndStateUpdate{Block: fromBlock, StateUpdate: fromSU}, toBlockAndStateUpdate, - &reverseStateDiff) + &reverseStateDiff, + ) if err != nil { s.logger.Error("Plugin RevertBlock failure:", zap.Error(err)) } diff --git a/sync/sync_test.go b/sync/sync_test.go index 2b82e2cfd3..d70e0f3814 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -12,6 +12,7 @@ import ( "github.com/NethermindEth/juno/clients/feeder" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/mocks" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -56,7 +57,11 @@ func TestSyncBlocks(t *testing.T) { logger := log.NewNopZapLogger() t.Run("sync multiple blocks in an empty db", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New( bc, @@ -77,7 +82,11 @@ func TestSyncBlocks(t *testing.T) { t.Run("sync multiple blocks in a non-empty db", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) b0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) s0, err := gw.StateUpdate(t.Context(), 0) @@ -104,7 +113,11 @@ func TestSyncBlocks(t *testing.T) { t.Run("sync multiple blocks, with an unreliable gw", func(t *testing.T) { testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) mockSNData := mocks.NewMockStarknetData(mockCtrl) @@ -181,7 +194,11 @@ func TestReorg(t *testing.T) { testDB := memory.New() // sync to Sepolia for 2 blocks - bc := blockchain.New(testDB, &networks.Sepolia) + bc := blockchain.New( + testDB, + &networks.Sepolia, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, sepoliaGw) synchronizer := sync.New(bc, dataSource, log.NewNopZapLogger(), 0, 0, false, testDB) @@ -190,7 +207,11 @@ func TestReorg(t *testing.T) { cancel() t.Run("resync to mainnet with the same db", func(t *testing.T) { - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) // Ensure current head is Sepolia head head, err := bc.HeadsHeader() @@ -244,7 +265,11 @@ func TestSubscribeNewHeads(t *testing.T) { testDB := memory.New() logger := log.NewNopZapLogger() network := networks.Mainnet - chain := blockchain.New(testDB, &network) + chain := blockchain.New( + testDB, + &network, + blockchain.WithNewState(statetestutils.UseNewState()), + ) feeder := feeder.NewTestClient(t, &network) gw := adaptfeeder.New(feeder) dataSource := sync.NewFeederGatewayDataSource(chain, gw) @@ -273,7 +298,11 @@ func TestPreConfirmedAfterSync(t *testing.T) { testDB := memory.New() logger := log.NewNopZapLogger() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) dataSource := sync.NewFeederGatewayDataSource(bc, gw) synchronizer := sync.New( bc, @@ -313,7 +342,11 @@ func TestPreConfirmed(t *testing.T) { t.Run("Returns pre_confirmed data when available", func(t *testing.T) { t.Parallel() testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + blockchain.WithNewState(statetestutils.UseNewState()), + ) b0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) s0, err := gw.StateUpdate(t.Context(), 0) @@ -338,7 +371,12 @@ func TestPreConfirmed(t *testing.T) { t.Run("Returns empty pre_confirmed when nothing stored", func(t *testing.T) { t.Parallel() testDB := memory.New() - bc := blockchain.New(testDB, &networks.Mainnet) + bc := blockchain.New( + testDB, + &networks.Mainnet, + + blockchain.WithNewState(statetestutils.UseNewState()), + ) b0, err := gw.BlockByNumber(t.Context(), 0) require.NoError(t, err) s0, err := gw.StateUpdate(t.Context(), 0) diff --git a/vm/vm_test.go b/vm/vm_test.go index d8e0442e72..6f0b5e20c2 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -9,6 +9,10 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/deprecatedstate" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/state" + statetestutils "github.com/NethermindEth/juno/core/state/testutils" + "github.com/NethermindEth/juno/core/trie2/triedb" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/rpc/rpccore" adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" @@ -19,7 +23,7 @@ import ( func TestCallDeprecatedCairo(t *testing.T) { testDB := memory.New() - txn := testDB.NewIndexedBatch() + batch := testDB.NewBatch() client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -29,7 +33,8 @@ func TestCallDeprecatedCairo(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := deprecatedstate.New(txn) + testState, err := NewState(t, &felt.Zero, testDB, batch) + require.NoError(t, err) require.NoError(t, testState.Update(&core.Header{Number: 0}, &core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: felt.NewUnsafeFromString[felt.Felt]("0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), @@ -41,6 +46,7 @@ func TestCallDeprecatedCairo(t *testing.T) { }, map[felt.Felt]core.ClassDefinition{ *classHash: simpleClass, }, false)) + require.NoError(t, batch.Write()) entryPoint := felt.NewUnsafeFromString[felt.Felt]("0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") @@ -67,6 +73,16 @@ func TestCallDeprecatedCairo(t *testing.T) { require.NoError(t, err) assert.Equal(t, []*felt.Felt{&felt.Zero}, ret.Result) + // for new state, each block needs a fresh batch and a state rooted at the previous block's root + if statetestutils.UseNewState() { + batch = testDB.NewBatch() + newRoot := felt.NewUnsafeFromString[felt.Felt]( + "0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8", + ) + testState, err = NewState(t, newRoot, testDB, batch) + require.NoError(t, err) + } + require.NoError(t, testState.Update(&core.Header{Number: 1}, &core.StateUpdate{ OldRoot: felt.NewUnsafeFromString[felt.Felt]("0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), NewRoot: felt.NewUnsafeFromString[felt.Felt]("0x4a948783e8786ba9d8edaf42de972213bd2deb1b50c49e36647f1fef844890f"), @@ -78,6 +94,9 @@ func TestCallDeprecatedCairo(t *testing.T) { }, }, }, nil, false)) + if statetestutils.UseNewState() { + require.NoError(t, batch.Write()) + } ret, err = New(&chainInfo, false, nil).Call( &CallInfo{ @@ -100,7 +119,7 @@ func TestCallDeprecatedCairo(t *testing.T) { func TestCallDeprecatedCairoMaxSteps(t *testing.T) { testDB := memory.New() - txn := testDB.NewIndexedBatch() + batch := testDB.NewBatch() client := feeder.NewTestClient(t, &networks.Mainnet) gw := adaptfeeder.New(client) @@ -110,7 +129,9 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := deprecatedstate.New(txn) + testState, err := NewState(t, &felt.Zero, testDB, batch) + require.NoError(t, err) + require.NoError(t, testState.Update(&core.Header{Number: 0}, &core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: felt.NewUnsafeFromString[felt.Felt]("0x3d452fbb3c3a32fe85b1a3fbbcdec316d5fc940cefc028ee808ad25a15991c8"), @@ -122,6 +143,7 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { }, map[felt.Felt]core.ClassDefinition{ *classHash: simpleClass, }, false)) + require.NoError(t, batch.Write()) entryPoint := felt.NewUnsafeFromString[felt.Felt]("0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695") feeTokens := networks.DefaultFeeTokenAddresses @@ -147,7 +169,7 @@ func TestCallDeprecatedCairoMaxSteps(t *testing.T) { func TestCallCairo(t *testing.T) { testDB := memory.New() - txn := testDB.NewIndexedBatch() + batch := testDB.NewBatch() client := feeder.NewTestClient(t, &networks.Goerli) gw := adaptfeeder.New(client) @@ -162,7 +184,8 @@ func TestCallCairo(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - state := deprecatedstate.New(txn) + state, err := NewState(t, &felt.Zero, testDB, batch) + require.NoError(t, err) firstStateUpdate := core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: felt.NewUnsafeFromString[felt.Felt]( @@ -176,6 +199,7 @@ func TestCallCairo(t *testing.T) { } declaredClass := map[felt.Felt]core.ClassDefinition{*classHash: simpleClass} require.NoError(t, state.Update(&core.Header{Number: 0}, &firstStateUpdate, declaredClass, false)) + require.NoError(t, batch.Write()) logLevel := log.NewLevel(log.ERROR) logger, err := log.NewZapLogger(logLevel) @@ -230,7 +254,17 @@ func TestCallCairo(t *testing.T) { }, }, } + // for new state, each block needs a fresh batch and a state rooted at the previous block's root + if statetestutils.UseNewState() { + batch = testDB.NewBatch() + state, err = NewState(t, firstStateUpdate.NewRoot, testDB, batch) + require.NoError(t, err) + } + require.NoError(t, state.Update(&core.Header{Number: 1}, &secondStateUpdate, nil, false)) + if statetestutils.UseNewState() { + require.NoError(t, batch.Write()) + } ret, err = vm.Call( &callInfo, @@ -247,7 +281,7 @@ func TestCallCairo(t *testing.T) { func TestCallInfoErrorHandling(t *testing.T) { testDB := memory.New() - txn := testDB.NewIndexedBatch() + batch := testDB.NewBatch() client := feeder.NewTestClient(t, &networks.Sepolia) gw := adaptfeeder.New(client) @@ -256,7 +290,8 @@ func TestCallInfoErrorHandling(t *testing.T) { simpleClass, err := gw.Class(t.Context(), classHash) require.NoError(t, err) - testState := deprecatedstate.New(txn) + testState, err := NewState(t, &felt.Zero, testDB, batch) + require.NoError(t, err) require.NoError(t, testState.Update(&core.Header{Number: 0}, &core.StateUpdate{ OldRoot: &felt.Zero, NewRoot: felt.NewUnsafeFromString[felt.Felt]("0xa6258de574e5540253c4a52742137d58b9e8ad8f584115bee46d9d18255c42"), @@ -268,6 +303,7 @@ func TestCallInfoErrorHandling(t *testing.T) { }, map[felt.Felt]core.ClassDefinition{ *classHash: simpleClass, }, false)) + require.NoError(t, batch.Write()) logLevel := log.NewLevel(log.ERROR) logger, err := log.NewZapLogger(logLevel) @@ -324,9 +360,10 @@ func TestCallInfoErrorHandling(t *testing.T) { func TestExecute(t *testing.T) { testDB := memory.New() - txn := testDB.NewIndexedBatch() + batch := testDB.NewBatch() - state := deprecatedstate.New(txn) + state, err := NewState(t, &felt.Zero, testDB, batch) + require.NoError(t, err) t.Run("empty transaction list", func(t *testing.T) { feeTokens := networks.DefaultFeeTokenAddresses @@ -387,3 +424,20 @@ func TestSetVersionedConstants(t *testing.T) { assert.ErrorContains(t, SetVersionedConstants("not_exists.json"), "no such file or directory") }) } + +func NewState( + t *testing.T, + stateRoot *felt.Felt, + testDB db.KeyValueStore, + batch db.Batch, +) (core.State, error) { + if !statetestutils.UseNewState() { + //nolint:staticcheck,nolintlint // used by old state + txn := testDB.NewIndexedBatch() + deprecatedState := deprecatedstate.New(txn) + return deprecatedState, nil + } + triedb := triedb.New(testDB, nil) + stateDB := state.NewStateDB(testDB, triedb) + return state.New(stateRoot, stateDB, batch) +}