diff --git a/apps/solana/transaction.go b/apps/solana/transaction.go index d15dfdc..b6f95aa 100644 --- a/apps/solana/transaction.go +++ b/apps/solana/transaction.go @@ -425,7 +425,9 @@ func ExtractTransfersFromTransaction(ctx context.Context, tx *solana.Transaction innerInstructions[inner.Index] = sis } - for _, balance := range meta.PreTokenBalances { + bs := meta.PreTokenBalances + bs = append(bs, meta.PostTokenBalances...) + for _, balance := range bs { if account, err := msg.Account(balance.AccountIndex); err == nil { tokenAccounts[account] = token.Account{ Owner: *balance.Owner, @@ -465,6 +467,36 @@ func ExtractTransfersFromTransaction(ctx context.Context, tx *solana.Transaction return transfers, nil } +func ExtractTransferFromTransactionByIndex(ctx context.Context, tx *solana.Transaction, meta *rpc.TransactionMeta, index int64) *Transfer { + if meta.Err != nil { + panic(fmt.Sprint(meta.Err)) + } + msg := tx.Message + + var ( + tokenAccounts = map[solana.PublicKey]token.Account{} + owners = []*solana.PublicKey{} + ) + + bs := meta.PreTokenBalances + bs = append(bs, meta.PostTokenBalances...) + for _, balance := range bs { + if account, err := msg.Account(balance.AccountIndex); err == nil { + tokenAccounts[account] = token.Account{ + Owner: *balance.Owner, + Mint: balance.Mint, + } + if !slices.ContainsFunc(owners, func(owner *solana.PublicKey) bool { + return owner.Equals(*balance.Owner) + }) { + owners = append(owners, balance.Owner) + } + } + } + + return extractTransfersFromInstruction(&msg, msg.Instructions[index], tokenAccounts, owners, nil) +} + func ExtractMintsFromTransaction(tx *solana.Transaction) []string { var assets []string for index, ix := range tx.Message.Instructions { diff --git a/cmd/computer.go b/cmd/computer.go index 78f01ec..733a87b 100644 --- a/cmd/computer.go +++ b/cmd/computer.go @@ -81,7 +81,7 @@ func ComputerBootCmd(c *cli.Context) error { return err } defer kd.Close() - computer := computer.NewNode(kd, group, messenger, mc.Computer, client, mw) + computer := computer.NewNode(kd, group, messenger, mc, client, mw) computer.Boot(ctx, version) if mmc := mc.Computer.MonitorConversationId; mmc != "" { diff --git a/cmd/monitor.go b/cmd/monitor.go index d0ec079..ba667ec 100644 --- a/cmd/monitor.go +++ b/cmd/monitor.go @@ -158,7 +158,7 @@ func bundleComputerState(ctx context.Context, node *computer.Node, mixin *mixin. state = state + fmt.Sprintf("💍 Payer %s Balance: %s SOL\n", node.SolanaPayer(), b) limit := mc.NewIntegerFromString("0.5") - if xinBalance.Cmp(limit) < 0 || solBalance.Cmp(limit) < 0 || mc.NewIntegerFromString(b).Cmp(limit) < 0 { + if node.Prod() && (xinBalance.Cmp(limit) < 0 || solBalance.Cmp(limit) < 0 || mc.NewIntegerFromString(b).Cmp(limit) < 0) { state = state + "@40518661\n" } diff --git a/config/dev.go b/config/dev.go index 49c6cb9..40971bb 100644 --- a/config/dev.go +++ b/config/dev.go @@ -5,19 +5,11 @@ import ( "net/http" _ "net/http/pprof" + computer "github.com/MixinNetwork/computer/solana" "github.com/MixinNetwork/mixin/logger" ) -const MainNetworkName = "main" -const TestNetworkName = "test" - -type DevConfig struct { - ProfilePort int `toml:"profile-port"` - LogLevel int `toml:"log-level"` - Network string `toml:"network"` -} - -func handleDevConfig(c *DevConfig) { +func handleDevConfig(c *computer.DevConfig) { logger.SetLevel(logger.INFO) if c == nil { return @@ -28,6 +20,6 @@ func handleDevConfig(c *DevConfig) { go http.ListenAndServe(l, http.DefaultServeMux) } if c.Network == "" { - c.Network = MainNetworkName + c.Network = computer.MainNetworkName } } diff --git a/config/reader.go b/config/reader.go index 463b874..fd8a673 100644 --- a/config/reader.go +++ b/config/reader.go @@ -10,12 +10,7 @@ import ( "github.com/pelletier/go-toml" ) -type Configuration struct { - Computer *computer.Configuration `toml:"computer"` - Dev *DevConfig `toml:"dev"` -} - -func ReadConfiguration(path, role string) (*Configuration, error) { +func ReadConfiguration(path, role string) (*computer.Config, error) { if strings.HasPrefix(path, "~/") { usr, _ := user.Current() path = filepath.Join(usr.HomeDir, (path)[2:]) @@ -24,7 +19,7 @@ func ReadConfiguration(path, role string) (*Configuration, error) { if err != nil { return nil, err } - var conf Configuration + var conf computer.Config err = toml.Unmarshal(f, &conf) if err != nil { return nil, err diff --git a/solana/computer_test.go b/solana/computer_test.go index 219bc83..4e8fe69 100644 --- a/solana/computer_test.go +++ b/solana/computer_test.go @@ -147,6 +147,7 @@ func TestCompaction(t *testing.T) { if i == 2 { require.Equal("", ar.Compaction) require.Len(ar.Transactions, 1) + require.Len(call.GetRefundIds(), 1) } else { require.Equal(mtg.StorageAssetId, ar.Compaction) require.Len(ar.Transactions, 0) @@ -173,6 +174,22 @@ func TestDepositCompaction(t *testing.T) { err = node.store.LockNonceAccountWithMix(ctx, nonce.Address, user.MixAddress) require.Nil(err) + now := time.Now() + id := uuid.Must(uuid.NewV4()).String() + for _, node := range nodes { + call := &store.SystemCall{ + RequestId: id, + Superior: id, + Type: store.CallTypeDeposit, + State: common.RequestStateDone, + Hash: sql.NullString{Valid: true, String: "6b33cca6e650a1c2abe4122a466eb7d02f7faa47ee935c80536178bacd913a56"}, + CreatedAt: now, + UpdatedAt: now, + } + err := node.store.TestWriteCall(ctx, call) + require.Nil(err) + } + sequence += 10 out, err := testWriteOutputForNodes(ctx, mds, conf.AppId, mtg.StorageAssetId, "", "", sequence, decimal.RequireFromString("0.90432841")) out.OutputId = "329346e1-34c2-4de0-8e35-729518eda8bd" @@ -195,14 +212,18 @@ func TestDepositCompaction(t *testing.T) { sequence += 100 _, err = testWriteOutputForNodes(ctx, mds, conf.AppId, mtg.StorageAssetId, "", "", uint64(sequence), decimal.RequireFromString("0.36")) require.Nil(err) + a.Sequence = sequence - out.Sequence = sequence for _, node := range nodes { a.TestAttachActionToGroup(node.group) - txs, compaction := node.processDeposit(ctx, a) + txs, compaction := node.processDeposit(ctx, a, i > 0) if i == 2 { require.Len(txs, 1) require.Equal("", compaction) + + call, err := node.store.ReadSystemCallByRequestId(ctx, id, 0) + require.Nil(err) + require.Len(call.GetRefundIds(), 1) } else { require.Len(txs, 0) require.Equal(mtg.StorageAssetId, compaction) @@ -307,6 +328,7 @@ func testObserverConfirmPostProcessCall(ctx context.Context, require *require.As call, err := node.store.ReadSystemCallByRequestId(ctx, sub.Superior, common.RequestStateDone) require.Nil(err) require.NotNil(call) + require.Len(call.GetRefundIds(), 1) ar, _, err := node.store.ReadActionResult(ctx, nid, nid) require.Nil(err) @@ -915,9 +937,7 @@ func testPrepare(require *require.Assertions, signerTest bool) (context.Context, func testBuildNode(ctx context.Context, require *require.Assertions, root string, i int) (*Node, *mtg.SQLite3Store) { f, _ := os.ReadFile("../config/example.toml") - var conf struct { - Computer *Configuration `toml:"computer"` - } + var conf Config err := toml.Unmarshal(f, &conf) require.Nil(err) @@ -948,7 +968,7 @@ func testBuildNode(ctx context.Context, require *require.Assertions, root string require.Nil(err) group.EnableDebug() - node := NewNode(kd, group, nil, conf.Computer, nil, nil) + node := NewNode(kd, group, nil, &conf, nil, nil) group.AttachWorker(node.conf.AppId, node) return node, md } diff --git a/solana/group.go b/solana/group.go index 20a95f5..1572364 100644 --- a/solana/group.go +++ b/solana/group.go @@ -39,7 +39,7 @@ func (node *Node) processAction(ctx context.Context, out *mtg.Action) ([]*mtg.Tr isDeposit := node.verifyKernelTransaction(ctx, out) if isDeposit { - return node.processDeposit(ctx, out) + return node.processDeposit(ctx, out, out.Restored()) } req, err := node.parseRequest(out) @@ -58,7 +58,7 @@ func (node *Node) processAction(ctx context.Context, out *mtg.Action) ([]*mtg.Tr } if ar != nil { if req.Restored && ar.Compaction != "" { - err = node.store.ResetRequest(ctx, req) + err = node.store.ResetRequest(ctx, req.Id, req.Sequence) if err != nil { panic(err) } diff --git a/solana/http.go b/solana/http.go index 88bd0f8..e377a56 100644 --- a/solana/http.go +++ b/solana/http.go @@ -331,11 +331,10 @@ func buildSystemCallView(call *store.SystemCall) map[string]any { "state": state, "hash": call.Hash.String, } - uid := "" if call.Type == store.CallTypeMain { - uid = call.UserIdFromPublicPath() + v["user_id"] = call.UserIdFromPublicPath() + v["refund_traces"] = call.GetRefundIds() } - v["user_id"] = uid return v } diff --git a/solana/interface.go b/solana/interface.go index a72d618..f146061 100644 --- a/solana/interface.go +++ b/solana/interface.go @@ -8,6 +8,17 @@ import ( "github.com/MixinNetwork/safe/mtg" ) +type Config struct { + Computer *Configuration `toml:"computer"` + Dev *DevConfig `toml:"dev"` +} + +type DevConfig struct { + ProfilePort int `toml:"profile-port"` + LogLevel int `toml:"log-level"` + Network string `toml:"network"` +} + type Configuration struct { AppId string `toml:"app-id"` StoreDir string `toml:"store-dir"` diff --git a/solana/mvm.go b/solana/mvm.go index 40f4c4f..187fc35 100644 --- a/solana/mvm.go +++ b/solana/mvm.go @@ -683,7 +683,7 @@ func (node *Node) processObserverCreateDepositCall(ctx context.Context, req *sto } // deposit from Solana to mtg deposit entry -func (node *Node) processDeposit(ctx context.Context, out *mtg.Action) ([]*mtg.Transaction, string) { +func (node *Node) processDeposit(ctx context.Context, out *mtg.Action, restored bool) ([]*mtg.Transaction, string) { logger.Printf("node.processDeposit(%v)", out) ar, handled, err := node.store.ReadActionResult(ctx, out.OutputId, out.OutputId) logger.Printf("store.ReadActionResult(%s %s) => %v %t %v", out.OutputId, out.OutputId, ar, handled, err) @@ -691,7 +691,15 @@ func (node *Node) processDeposit(ctx context.Context, out *mtg.Action) ([]*mtg.T panic(err) } if ar != nil { - return ar.Transactions, ar.Compaction + if restored { + err = node.store.ResetRequest(ctx, out.OutputId, out.Sequence) + if err != nil { + panic(err) + } + handled = false + } else { + return ar.Transactions, ar.Compaction + } } if handled { err = node.store.FailAction(ctx, &store.Request{ @@ -704,16 +712,16 @@ func (node *Node) processDeposit(ctx context.Context, out *mtg.Action) ([]*mtg.T return nil, "" } - var ts []*solanaApp.Transfer + var t *solanaApp.Transfer var tx *solana.Transaction var meta *rpc.TransactionMeta if common.CheckTestEnvironment(ctx) { - ts = append(ts, &solanaApp.Transfer{ + t = &solanaApp.Transfer{ AssetId: out.AssetId, Receiver: node.SolanaDepositEntry().String(), Sender: "GTQaVWXJyTyqauC4XgrDKUeVhSFkbS94YnbTnVCbFRiF", Value: new(big.Int).SetInt64(90432841), - }) + } } else { if len(out.DepositHash.String) < 16 { panic(out.TransactionHash) @@ -731,84 +739,112 @@ func (node *Node) processDeposit(ctx context.Context, out *mtg.Action) ([]*mtg.T if err != nil { panic(err) } - ts, err = solanaApp.ExtractTransfersFromTransaction(ctx, tx, rpcTx.Meta, nil) - if err != nil { - panic(err) - } + t = solanaApp.ExtractTransferFromTransactionByIndex(ctx, tx, rpcTx.Meta, out.DepositIndex.Int64) + logger.Printf("solana.ExtractTransferFromTransactionByIndex(%s %s %d) => %v", out.OutputId, out.DepositHash.String, out.DepositIndex.Int64, t) + } + if t == nil || t.AssetId != out.AssetId || t.Receiver != node.SolanaDepositEntry().String() { + return node.failDepositRequest(ctx, out, "") + } + asset, err := common.SafeReadAssetUntilSufficient(ctx, t.AssetId) + if err != nil { + panic(err) + } + var expected mc.Integer + if asset.ChainID == common.SafeSolanaChainId { + expected = mc.NewIntegerFromString(decimal.NewFromBigInt(t.Value, -int32(asset.Precision)).String()) + } else { + expected = mc.NewIntegerFromString(decimal.NewFromBigInt(t.Value, -int32(solanaApp.AssetDecimal)).String()) + } + actual := mc.NewIntegerFromString(out.Amount.String()) + if expected.Cmp(actual) != 0 { + logger.Printf("invalid deposit amount: %s %s", actual.String(), out.Amount.String()) + return node.failDepositRequest(ctx, out, "") } - var txs []*mtg.Transaction - var compaction string - for i, t := range ts { - logger.Printf("%d-th transfer: %v", i, t) - if t.AssetId != out.AssetId { - continue - } - if t.Receiver != node.SolanaDepositEntry().String() { - continue + // user == nil: transfer solana withdrawn assets from mtg to mtg deposit entry by post call for failed prepare call + // user != nil: transfer or burn assets from user account to mtg deposit entry by post call or deposit call + user, err := node.store.ReadUserByChainAddress(ctx, t.Sender) + logger.Printf("store.ReadUserByAddress(%s) => %v %v", t.Sender, user, err) + if err != nil { + panic(err) + } + var call *store.SystemCall + if user == nil { + memo := solanaApp.ExtractMemoFromTransaction(ctx, tx, meta, node.SolanaPayer()) + logger.Printf("solana.ExtractMemoFromTransaction(%s) => %s", tx.Signatures[0].String(), memo) + if memo == "" { + return node.failDepositRequest(ctx, out, "") } - user, err := node.store.ReadUserByChainAddress(ctx, t.Sender) - logger.Printf("store.ReadUserByAddress(%s) => %v %v", t.Sender, user, err) + call, err = node.store.ReadSystemCallByRequestId(ctx, memo, common.RequestStateFailed) + logger.Printf("store.ReadSystemCallByRequestId(%s) => %v %v", memo, call, err) if err != nil { panic(err) - } else if user == nil { - memo := solanaApp.ExtractMemoFromTransaction(ctx, tx, meta, node.SolanaPayer()) - logger.Printf("solana.ExtractMemoFromTransaction(%s) => %s", tx.Signatures[0].String(), memo) - if memo == "" { - continue - } - call, err := node.store.ReadSystemCallByRequestId(ctx, memo, common.RequestStateFailed) - logger.Printf("store.ReadSystemCallByRequestId(%s) => %v %v", memo, call, err) - if err != nil { - panic(err) - } - if call == nil || call.Type != store.CallTypePrepare { - continue - } - superir, err := node.store.ReadSystemCallByRequestId(ctx, call.Superior, common.RequestStateFailed) - if err != nil { - panic(err) - } - user, err = node.store.ReadUser(ctx, superir.UserIdFromPublicPath()) - if err != nil { - panic(err) - } } - mix, err := bot.NewMixAddressFromString(user.MixAddress) + if call == nil || call.Type != store.CallTypePrepare { + return node.failDepositRequest(ctx, out, "") + } + superior, err := node.store.ReadSystemCallByRequestId(ctx, call.Superior, common.RequestStateFailed) + logger.Printf("store.ReadSystemCallByRequestId(%s) => %v %v", call.Superior, superior, err) if err != nil { panic(err) } - asset, err := common.SafeReadAssetUntilSufficient(ctx, t.AssetId) + user, err = node.store.ReadUser(ctx, superior.UserIdFromPublicPath()) if err != nil { panic(err) } - var expected mc.Integer - if asset.ChainID == common.SafeSolanaChainId { - expected = mc.NewIntegerFromString(decimal.NewFromBigInt(t.Value, -int32(asset.Precision)).String()) - } else { - expected = mc.NewIntegerFromString(decimal.NewFromBigInt(t.Value, -int32(solanaApp.AssetDecimal)).String()) + call = superior + } else { + call, err = node.store.ReadSystemCallByHash(ctx, out.DepositHash.String) + logger.Printf("store.ReadSystemCallByHash(%s) => %v %v", out.DepositHash.String, call, err) + if err != nil { + panic(err) } - actual := mc.NewIntegerFromString(out.Amount.String()) - if expected.Cmp(actual) != 0 { - continue + if call == nil || call.State != common.RequestStateDone { + return node.failDepositRequest(ctx, out, "") } - id := common.UniqueId(out.DepositHash.String, fmt.Sprintf("deposit-%d", i)) - id = common.UniqueId(id, t.Receiver) - hash := out.DepositHash.String - tx := node.buildTransaction(ctx, out, node.conf.AppId, t.AssetId, mix.Members(), int(mix.Threshold), out.Amount.String(), []byte(hash), id) - if tx == nil { - return nil, t.AssetId + switch call.Type { + case store.CallTypeDeposit: + case store.CallTypePostProcess: + superior, err := node.store.ReadSystemCallByRequestId(ctx, call.Superior, 0) + if err != nil { + panic(err) + } + call = superior + default: + return node.failDepositRequest(ctx, out, "") } - txs = append(txs, tx) } + mix, err := bot.NewMixAddressFromString(user.MixAddress) + if err != nil { + panic(err) + } + id := common.UniqueId(out.DepositHash.String, fmt.Sprint(out.DepositIndex.Int64)) + id = common.UniqueId(id, t.Receiver) + mtx := node.buildTransaction(ctx, out, node.conf.AppId, t.AssetId, mix.Members(), int(mix.Threshold), out.Amount.String(), []byte(out.DepositHash.String), id) + if mtx == nil { + return node.failDepositRequest(ctx, out, t.AssetId) + } + txs := []*mtg.Transaction{mtx} + old := call.GetRefundIds() + old = append(old, mtx.TraceId) + call.RefundTraces = sql.NullString{Valid: true, String: strings.Join(old, ",")} - err = node.store.WriteDepositRequestIfNotExist(ctx, out, common.RequestStateDone, txs, compaction) - logger.Printf("store.WriteDepositRequestIfNotExist(%v %d %s) => %v", out, len(txs), compaction, err) + err = node.store.WriteDepositRequestIfNotExist(ctx, out, common.RequestStateDone, call, []*mtg.Transaction{mtx}, "") + logger.Printf("store.WriteDepositRequestIfNotExist(%v %s) => %v", out, mtx.TraceId, err) if err != nil { panic(err) } - return txs, compaction + return txs, "" +} + +func (node *Node) failDepositRequest(ctx context.Context, out *mtg.Action, compaction string) ([]*mtg.Transaction, string) { + logger.Printf("node.failDepositRequest(%v %s)", out, compaction) + err := node.store.FailDepositRequestIfNotExist(ctx, out, compaction) + if err != nil { + panic(err) + } + return nil, compaction } func (node *Node) refundAndFailRequest(ctx context.Context, req *store.Request, members []string, threshod int, call *store.SystemCall, os []*store.UserOutput) ([]*mtg.Transaction, string) { @@ -993,6 +1029,7 @@ func (node *Node) confirmBurnRelatedSystemCall(ctx context.Context, req *store.R changes := node.buildUserBalanceChangesFromMeta(ctx, tx, rpcTx.Meta, solana.MPK(user.ChainAddress)) var txs []*mtg.Transaction + var ids []string bs := solanaApp.ExtractBurnsFromTransaction(ctx, tx) for _, burn := range bs { address := burn.GetMintAccount().PublicKey.String() @@ -1031,7 +1068,11 @@ func (node *Node) confirmBurnRelatedSystemCall(ctx context.Context, req *store.R return node.failRequest(ctx, req, da.AssetId) } txs = append(txs, tx) + ids = append(ids, tx.TraceId) } + old := call.GetRefundIds() + old = append(old, ids...) + call.RefundTraces = sql.NullString{Valid: true, String: strings.Join(old, ",")} err = node.store.ConfirmBurnRelatedSystemCallWithRequest(ctx, req, call, txs) if err != nil { diff --git a/solana/node.go b/solana/node.go index f5b4d59..8d43727 100644 --- a/solana/node.go +++ b/solana/node.go @@ -19,9 +19,15 @@ import ( "github.com/fox-one/mixin-sdk-go/v2" ) +const ( + MainNetworkName = "main" + TestNetworkName = "test" +) + type Node struct { - id party.ID - threshold int + id party.ID + threshold int + networkName string conf *Configuration group *mtg.Group @@ -36,20 +42,22 @@ type Node struct { wallet *common.MixinWallet } -func NewNode(store *store.SQLite3Store, group *mtg.Group, network Network, conf *Configuration, mixin *mixin.Client, wallet *common.MixinWallet) *Node { +func NewNode(store *store.SQLite3Store, group *mtg.Group, network Network, cf *Config, mixin *mixin.Client, wallet *common.MixinWallet) *Node { + conf := cf.Computer node := &Node{ - id: party.ID(conf.MTG.App.AppId), - threshold: conf.Threshold, - conf: conf, - group: group, - network: network, - mutex: new(sync.Mutex), - sessions: make(map[string]*MultiPartySession), - operations: make(map[string]bool), - store: store, - mixin: mixin, - wallet: wallet, - solana: solanaApp.NewClient(conf.SolanaRPC), + id: party.ID(conf.MTG.App.AppId), + threshold: conf.Threshold, + networkName: cf.Dev.Network, + conf: conf, + group: group, + network: network, + mutex: new(sync.Mutex), + sessions: make(map[string]*MultiPartySession), + operations: make(map[string]bool), + store: store, + mixin: mixin, + wallet: wallet, + solana: solanaApp.NewClient(conf.SolanaRPC), } members := node.GetMembers() @@ -62,6 +70,11 @@ func NewNode(store *store.SQLite3Store, group *mtg.Group, network Network, conf } func (node *Node) Boot(ctx context.Context, version string) { + err := node.store.Migrate(ctx) + if err != nil { + panic(err) + } + go node.bootObserver(ctx, version) go node.bootSigner(ctx) go node.mtgBalanceCheckLoop(ctx) @@ -106,6 +119,10 @@ func (node *Node) findMember(m string) int { return slices.Index(node.GetMembers(), m) } +func (node *Node) Prod() bool { + return node.networkName == MainNetworkName +} + func (node *Node) synced(ctx context.Context) bool { if common.CheckTestEnvironment(ctx) { return true diff --git a/store/call.go b/store/call.go index 6fa2601..8c8da98 100644 --- a/store/call.go +++ b/store/call.go @@ -38,15 +38,16 @@ type SystemCall struct { Signature sql.NullString RequestSignerAt sql.NullTime Hash sql.NullString + RefundTraces sql.NullString CreatedAt time.Time UpdatedAt time.Time } -var systemCallCols = []string{"id", "superior_id", "request_hash", "call_type", "nonce_account", "public", "skip_postprocess", "message_hash", "raw", "state", "withdrawal_traces", "signature", "request_signer_at", "hash", "created_at", "updated_at"} +var systemCallCols = []string{"id", "superior_id", "request_hash", "call_type", "nonce_account", "public", "skip_postprocess", "message_hash", "raw", "state", "withdrawal_traces", "signature", "request_signer_at", "hash", "refund_traces", "created_at", "updated_at"} func systemCallFromRow(row Row) (*SystemCall, error) { var c SystemCall - err := row.Scan(&c.RequestId, &c.Superior, &c.RequestHash, &c.Type, &c.NonceAccount, &c.Public, &c.SkipPostProcess, &c.MessageHash, &c.Raw, &c.State, &c.WithdrawalTraces, &c.Signature, &c.RequestSignerAt, &c.Hash, &c.CreatedAt, &c.UpdatedAt) + err := row.Scan(&c.RequestId, &c.Superior, &c.RequestHash, &c.Type, &c.NonceAccount, &c.Public, &c.SkipPostProcess, &c.MessageHash, &c.Raw, &c.State, &c.WithdrawalTraces, &c.Signature, &c.RequestSignerAt, &c.Hash, &c.RefundTraces, &c.CreatedAt, &c.UpdatedAt) if err == sql.ErrNoRows { return nil, nil } @@ -57,6 +58,10 @@ func (c *SystemCall) GetWithdrawalIds() []string { return util.SplitIds(c.WithdrawalTraces.String, ",") } +func (c *SystemCall) GetRefundIds() []string { + return util.SplitIds(c.RefundTraces.String, ",") +} + func (c *SystemCall) UserIdFromPublicPath() string { data := common.DecodeHexOrPanic(c.Public) if len(data) != 16 { @@ -193,12 +198,14 @@ func (s *SQLite3Store) RefundOutputsWithRequest(ctx context.Context, req *Reques } defer common.Rollback(tx) - if call != nil { - query := "UPDATE system_calls SET state=?, updated_at=? WHERE id=? AND state=? AND withdrawal_traces IS NULL" - _, err = tx.ExecContext(ctx, query, common.RequestStateFailed, req.CreatedAt, call.RequestId, common.RequestStateInitial) - if err != nil { - return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) - } + var ids []string + for _, tx := range txs { + ids = append(ids, tx.TraceId) + } + query := "UPDATE system_calls SET state=?, refund_traces=?, updated_at=? WHERE id=? AND (state=? OR state=?) AND withdrawal_traces IS NULL" + _, err = tx.ExecContext(ctx, query, common.RequestStateFailed, strings.Join(ids, ","), req.CreatedAt, call.RequestId, common.RequestStateInitial, common.RequestStateFailed) + if err != nil { + return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) } for _, o := range os { @@ -285,8 +292,14 @@ func (s *SQLite3Store) ConfirmBurnRelatedSystemCallWithRequest(ctx context.Conte } defer common.Rollback(tx) - query := "UPDATE system_calls SET state=?, hash=?, updated_at=? WHERE id=? AND state=?" - err = s.execOne(ctx, tx, query, call.State, call.Hash, req.CreatedAt, call.RequestId, common.RequestStatePending) + query := "UPDATE system_calls SET state=?, hash=?, refund_traces=?, updated_at=? WHERE id=? AND state=?" + err = s.execOne(ctx, tx, query, call.State, call.Hash, call.RefundTraces, req.CreatedAt, call.RequestId, common.RequestStatePending) + if err != nil { + return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) + } + + query = "UPDATE system_calls SET refund_traces=?, updated_at=? WHERE id=?" + err = s.execOne(ctx, tx, query, call.RefundTraces, req.CreatedAt, call.Superior) if err != nil { return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) } @@ -332,8 +345,12 @@ func (s *SQLite3Store) FailSystemCallWithRequest(ctx context.Context, req *Reque } // if a prepare system call failed, fail its main call if call.Type == CallTypePrepare { - query = "UPDATE system_calls SET state=?, updated_at=? WHERE superior_id=? AND call_type=? AND state=?" - _, err = tx.ExecContext(ctx, query, common.RequestStateFailed, req.CreatedAt, call.Superior, CallTypeMain, common.RequestStatePending) + var ids []string + for _, tx := range txs { + ids = append(ids, tx.TraceId) + } + query = "UPDATE system_calls SET state=?, refund_traces=?, updated_at=? WHERE superior_id=? AND call_type=? AND state=?" + _, err = tx.ExecContext(ctx, query, common.RequestStateFailed, strings.Join(ids, ","), req.CreatedAt, call.Superior, CallTypeMain, common.RequestStatePending) if err != nil { return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) } @@ -595,7 +612,7 @@ func (s *SQLite3Store) CheckUnfinishedPreviousMainCalls(ctx context.Context, cal } func (s *SQLite3Store) writeSystemCall(ctx context.Context, tx *sql.Tx, call *SystemCall) error { - vals := []any{call.RequestId, call.Superior, call.RequestHash, call.Type, call.NonceAccount, call.Public, call.SkipPostProcess, call.MessageHash, call.Raw, call.State, call.WithdrawalTraces, call.Signature, call.RequestSignerAt, call.Hash, call.CreatedAt, call.UpdatedAt} + vals := []any{call.RequestId, call.Superior, call.RequestHash, call.Type, call.NonceAccount, call.Public, call.SkipPostProcess, call.MessageHash, call.Raw, call.State, call.WithdrawalTraces, call.Signature, call.RequestSignerAt, call.Hash, call.RefundTraces, call.CreatedAt, call.UpdatedAt} err := s.execOne(ctx, tx, buildInsertionSQL("system_calls", systemCallCols), vals...) if err != nil { return fmt.Errorf("INSERT system_calls %v", err) diff --git a/store/migrate.go b/store/migrate.go new file mode 100644 index 0000000..b96e1ac --- /dev/null +++ b/store/migrate.go @@ -0,0 +1,40 @@ +package store + +import ( + "context" + "database/sql" + "time" + + "github.com/MixinNetwork/safe/common" +) + +func (s *SQLite3Store) Migrate(ctx context.Context) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer common.Rollback(tx) + + key, val := "SCHEMA:VERSION:REFUND_TRACES", "" + row := tx.QueryRowContext(ctx, "SELECT value FROM properties WHERE key=?", key) + err = row.Scan(&val) + if err == nil || err != sql.ErrNoRows { + return err + } + now := time.Now().UTC() + + query := "ALTER TABLE system_calls ADD COLUMN refund_traces VARCHAR;" + _, err = tx.ExecContext(ctx, query) + if err != nil { + return err + } + _, err = tx.ExecContext(ctx, "INSERT INTO properties (key, value, created_at, updated_at) VALUES (?, ?, ?, ?)", key, query, now, now) + if err != nil { + return err + } + + return tx.Commit() +} diff --git a/store/request.go b/store/request.go index 6c996b1..7d0ab46 100644 --- a/store/request.go +++ b/store/request.go @@ -96,7 +96,7 @@ func (s *SQLite3Store) WriteRequestIfNotExist(ctx context.Context, req *Request) return tx.Commit() } -func (s *SQLite3Store) WriteDepositRequestIfNotExist(ctx context.Context, out *mtg.Action, state int, txs []*mtg.Transaction, compaction string) error { +func (s *SQLite3Store) WriteDepositRequestIfNotExist(ctx context.Context, out *mtg.Action, state int, call *SystemCall, txs []*mtg.Transaction, compaction string) error { s.mutex.Lock() defer s.mutex.Unlock() @@ -107,14 +107,28 @@ func (s *SQLite3Store) WriteDepositRequestIfNotExist(ctx context.Context, out *m defer common.Rollback(tx) existed, err := s.checkExistence(ctx, tx, "SELECT request_id FROM requests WHERE request_id=?", out.OutputId) - if err != nil || existed { + if err != nil { return err } - vals := []any{out.OutputId, out.TransactionHash, out.OutputIndex, out.AssetId, out.Amount, 0, 0, "", state, out.SequencerCreatedAt, out.SequencerCreatedAt, out.Sequence} - err = s.execOne(ctx, tx, buildInsertionSQL("requests", requestCols), vals...) + if existed { + err := s.execOne(ctx, tx, "UPDATE requests SET state=?, updated_at=? WHERE request_id=? AND state=?", + common.RequestStateDone, time.Now().UTC(), out.OutputId, common.RequestStateInitial) + if err != nil { + return fmt.Errorf("UPDATE requests %v", err) + } + } else { + vals := []any{out.OutputId, out.TransactionHash, out.OutputIndex, out.AssetId, out.Amount, 0, 0, "", state, out.SequencerCreatedAt, out.SequencerCreatedAt, out.Sequence} + err = s.execOne(ctx, tx, buildInsertionSQL("requests", requestCols), vals...) + if err != nil { + return fmt.Errorf("INSERT requests %v", err) + } + } + + query := "UPDATE system_calls SET refund_traces=? WHERE id=?" + err = s.execOne(ctx, tx, query, call.RefundTraces, call.RequestId) if err != nil { - return fmt.Errorf("INSERT requests %v", err) + return fmt.Errorf("SQLite3Store UPDATE system_calls %v", err) } err = s.writeActionResult(ctx, tx, out.OutputId, compaction, txs, out.OutputId) @@ -132,6 +146,35 @@ func (s *SQLite3Store) WriteDepositRequestIfNotExist(ctx context.Context, out *m return tx.Commit() } +func (s *SQLite3Store) FailDepositRequestIfNotExist(ctx context.Context, out *mtg.Action, compaction string) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer common.Rollback(tx) + + existed, err := s.checkExistence(ctx, tx, "SELECT request_id FROM requests WHERE request_id=?", out.OutputId) + if err != nil || existed { + return err + } + + vals := []any{out.OutputId, out.TransactionHash, out.OutputIndex, out.AssetId, out.Amount, 0, 0, "", common.RequestStateFailed, out.SequencerCreatedAt, out.SequencerCreatedAt, out.Sequence} + err = s.execOne(ctx, tx, buildInsertionSQL("requests", requestCols), vals...) + if err != nil { + return fmt.Errorf("INSERT requests %v", err) + } + + err = s.writeActionResult(ctx, tx, out.OutputId, compaction, nil, out.OutputId) + if err != nil { + return err + } + + return tx.Commit() +} + func (s *SQLite3Store) ReadRequest(ctx context.Context, id string) (*Request, error) { query := fmt.Sprintf("SELECT %s FROM requests WHERE request_id=?", strings.Join(requestCols, ",")) row := s.db.QueryRowContext(ctx, query, id) @@ -170,7 +213,7 @@ func (s *SQLite3Store) FailRequest(ctx context.Context, req *Request, compaction return tx.Commit() } -func (s *SQLite3Store) ResetRequest(ctx context.Context, req *Request) error { +func (s *SQLite3Store) ResetRequest(ctx context.Context, reqId string, reqSequence uint64) error { s.mutex.Lock() defer s.mutex.Unlock() @@ -181,7 +224,7 @@ func (s *SQLite3Store) ResetRequest(ctx context.Context, req *Request) error { defer common.Rollback(tx) _, err = tx.ExecContext(ctx, "UPDATE requests SET state=?, sequence=?, updated_at=? WHERE request_id=? AND state=?", - common.RequestStateInitial, req.Sequence, time.Now().UTC(), req.Id, common.RequestStateFailed) + common.RequestStateInitial, reqSequence, time.Now().UTC(), reqId, common.RequestStateFailed) if err != nil { return fmt.Errorf("UPDATE requests %v", err) } diff --git a/store/schema.sql b/store/schema.sql index 92adb2b..62817ca 100644 --- a/store/schema.sql +++ b/store/schema.sql @@ -156,6 +156,7 @@ CREATE TABLE IF NOT EXISTS system_calls ( signature VARCHAR, request_signer_at TIMESTAMP, hash VARCHAR, + refund_traces VARCHAR, created_at TIMESTAMP NOT NULL, updated_at TIMESTAMP NOT NULL, PRIMARY KEY ('id') diff --git a/store/test.go b/store/test.go index 3fcd914..7efd682 100644 --- a/store/test.go +++ b/store/test.go @@ -60,7 +60,7 @@ func (s *SQLite3Store) TestWriteCall(ctx context.Context, call *SystemCall) erro } defer common.Rollback(tx) - vals := []any{call.RequestId, call.Superior, call.RequestHash, call.Type, call.NonceAccount, call.Public, call.SkipPostProcess, call.MessageHash, call.Raw, call.State, call.WithdrawalTraces, call.Signature, call.RequestSignerAt, call.Hash, call.CreatedAt, call.UpdatedAt} + vals := []any{call.RequestId, call.Superior, call.RequestHash, call.Type, call.NonceAccount, call.Public, call.SkipPostProcess, call.MessageHash, call.Raw, call.State, call.WithdrawalTraces, call.Signature, call.RequestSignerAt, call.Hash, call.RefundTraces, call.CreatedAt, call.UpdatedAt} err = s.execOne(ctx, tx, buildInsertionSQL("system_calls", systemCallCols), vals...) if err != nil { return fmt.Errorf("INSERT system_calls %v", err)