diff --git a/.gitignore b/.gitignore index 14a1822..f894503 100644 --- a/.gitignore +++ b/.gitignore @@ -24,8 +24,5 @@ # Executables *.exe -# Internal Test files -*_test.go - # internal tools /tools diff --git a/pkg/connector/auth_recovery.go b/pkg/connector/auth_recovery.go new file mode 100644 index 0000000..143cf8f --- /dev/null +++ b/pkg/connector/auth_recovery.go @@ -0,0 +1,70 @@ +package connector + +import ( + "context" + "fmt" + + "github.com/highesttt/matrix-line-messenger/pkg/line" +) + +type lineCallDeps[T any] struct { + newClient func() *line.Client + recover func(context.Context) error + isAuthError func(error) bool + call func(*line.Client) (T, error) +} + +func callLineWithRecovery[T any](ctx context.Context, client *line.Client, deps lineCallDeps[T]) (*line.Client, T, error) { + if client == nil { + client = deps.newClient() + } + res, err := deps.call(client) + if err == nil || !deps.isAuthError(err) { + return client, res, err + } + + if errRecover := deps.recover(ctx); errRecover != nil { + var zero T + return client, zero, fmt.Errorf("failed to recover token after LINE auth error: %w", errRecover) + } + + client = deps.newClient() + res, err = deps.call(client) + return client, res, err +} + +func (lc *LineClient) isTokenError(err error) bool { + if line.IsNoUsableE2EEGroupKey(err) || line.IsNoUsableE2EEPublicKey(err) { + return false + } + return line.IsAuthError(err) +} + +func (lc *LineClient) callLine(ctx context.Context, call func(*line.Client) error) (*line.Client, error) { + return lc.callLineUsing(ctx, nil, call) +} + +func (lc *LineClient) callLineUsing(ctx context.Context, client *line.Client, call func(*line.Client) error) (*line.Client, error) { + client, _, err := callLineWithRecovery(ctx, client, lineCallDeps[struct{}]{ + newClient: func() *line.Client { return lc.newClient() }, + recover: lc.recoverToken, + isAuthError: lc.isTokenError, + call: func(client *line.Client) (struct{}, error) { + return struct{}{}, call(client) + }, + }) + return client, err +} + +func callLineResult[T any](lc *LineClient, ctx context.Context, call func(*line.Client) (T, error)) (*line.Client, T, error) { + return callLineResultUsing(lc, ctx, nil, call) +} + +func callLineResultUsing[T any](lc *LineClient, ctx context.Context, client *line.Client, call func(*line.Client) (T, error)) (*line.Client, T, error) { + return callLineWithRecovery(ctx, client, lineCallDeps[T]{ + newClient: func() *line.Client { return lc.newClient() }, + recover: lc.recoverToken, + isAuthError: lc.isTokenError, + call: call, + }) +} diff --git a/pkg/connector/auth_recovery_test.go b/pkg/connector/auth_recovery_test.go new file mode 100644 index 0000000..029b4e3 --- /dev/null +++ b/pkg/connector/auth_recovery_test.go @@ -0,0 +1,261 @@ +package connector + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/highesttt/matrix-line-messenger/pkg/line" +) + +var ( + errAuthRequired = errors.New(`API error 400: {"code":10051,"message":"RESPONSE_ERROR","data":{"name":"TalkException","code":119,"reason":"Access token refresh required"}}`) + errNotMember = errors.New(`API error 400: {"code":10051,"data":{"name":"TalkException","code":10,"reason":"not a member"}}`) + errNetwork = errors.New("request failed: dial tcp: i/o timeout") +) + +func TestCallLineWithRecovery(t *testing.T) { + tests := []struct { + name string + callErrors []error + recoverErr error + wantCalls int + wantRecover int + wantErr error + wantErrPrefix string + }{ + { + name: "success without recovery", + callErrors: []error{nil}, + wantCalls: 1, + }, + { + name: "non auth error is returned without recovery", + callErrors: []error{errNotMember}, + wantCalls: 1, + wantRecover: 0, + wantErr: errNotMember, + }, + { + name: "network error is returned without recovery", + callErrors: []error{errNetwork}, + wantCalls: 1, + wantRecover: 0, + wantErr: errNetwork, + }, + { + name: "auth error recovers and retries once", + callErrors: []error{errAuthRequired, nil}, + wantCalls: 2, + wantRecover: 1, + }, + { + name: "recovery failure is returned without retry", + callErrors: []error{errAuthRequired}, + recoverErr: errors.New("refresh failed"), + wantCalls: 1, + wantRecover: 1, + wantErrPrefix: "failed to recover token after LINE auth error", + }, + { + name: "retry auth error is not retried again", + callErrors: []error{errAuthRequired, errAuthRequired}, + wantCalls: 2, + wantRecover: 1, + wantErr: errAuthRequired, + }, + { + name: "retry non auth error is returned to caller", + callErrors: []error{errAuthRequired, errors.New("Extension does not support file upload")}, + wantCalls: 2, + wantRecover: 1, + wantErrPrefix: "Extension does not support file upload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var calls int + var recoveries int + + _, _, err := callLineWithRecovery(context.Background(), nil, lineCallDeps[struct{}]{ + newClient: func() *line.Client { + return line.NewClient("token") + }, + recover: func(context.Context) error { + recoveries++ + return tt.recoverErr + }, + isAuthError: line.IsAuthError, + call: func(*line.Client) (struct{}, error) { + err := tt.callErrors[calls] + calls++ + return struct{}{}, err + }, + }) + + if calls != tt.wantCalls { + t.Fatalf("calls = %d, want %d", calls, tt.wantCalls) + } + if recoveries != tt.wantRecover { + t.Fatalf("recoveries = %d, want %d", recoveries, tt.wantRecover) + } + if tt.wantErr != nil && !errors.Is(err, tt.wantErr) { + t.Fatalf("err = %v, want %v", err, tt.wantErr) + } + if tt.wantErrPrefix != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErrPrefix) { + t.Fatalf("err = %v, want containing %q", err, tt.wantErrPrefix) + } + } + if tt.wantErr == nil && tt.wantErrPrefix == "" && err != nil { + t.Fatalf("unexpected err: %v", err) + } + }) + } +} + +func TestCallLineWithRecoveryReusesClientUntilRecovery(t *testing.T) { + ctx := context.Background() + initialClient := line.NewClient("initial") + refreshedClient := line.NewClient("refreshed") + var newClients int + var calls []string + + client, _, err := callLineWithRecovery(ctx, initialClient, lineCallDeps[struct{}]{ + newClient: func() *line.Client { + newClients++ + return refreshedClient + }, + recover: func(context.Context) error { + return nil + }, + isAuthError: line.IsAuthError, + call: func(client *line.Client) (struct{}, error) { + calls = append(calls, client.AccessToken) + if len(calls) == 1 { + return struct{}{}, errAuthRequired + } + return struct{}{}, nil + }, + }) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if client != refreshedClient { + t.Fatal("expected recovered client to be returned") + } + if newClients != 1 { + t.Fatalf("new clients = %d, want 1", newClients) + } + if len(calls) != 2 || calls[0] != "initial" || calls[1] != "refreshed" { + t.Fatalf("calls used clients %v, want [initial refreshed]", calls) + } +} + +func TestCallLineWithRecoveryUsesProvidedClientWithoutRecreating(t *testing.T) { + ctx := context.Background() + initialClient := line.NewClient("initial") + var newClients int + + client, _, err := callLineWithRecovery(ctx, initialClient, lineCallDeps[struct{}]{ + newClient: func() *line.Client { + newClients++ + return line.NewClient("unexpected") + }, + recover: func(context.Context) error { return nil }, + isAuthError: line.IsAuthError, + call: func(client *line.Client) (struct{}, error) { + if client.AccessToken != "initial" { + t.Fatalf("client token = %q, want initial", client.AccessToken) + } + return struct{}{}, nil + }, + }) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if client != initialClient { + t.Fatal("expected provided client to be returned") + } + if newClients != 0 { + t.Fatalf("new clients = %d, want 0", newClients) + } +} + +func TestLineClientIsTokenErrorExcludesE2EEErrors(t *testing.T) { + lc := &LineClient{} + if !lc.isTokenError(errAuthRequired) { + t.Fatal("expected auth-required error to be classified as token error") + } + if lc.isTokenError(line.ErrNoUsableE2EEGroupKey) { + t.Fatal("E2EE group key errors must not trigger token recovery") + } + if lc.isTokenError(line.ErrNoUsableE2EEPublicKey) { + t.Fatal("E2EE public key errors must not trigger token recovery") + } +} + +func TestRunTokenRecoverySkipsRecentRecovery(t *testing.T) { + lc := &LineClient{recoverTime: time.Now()} + var calls int + + err := lc.runTokenRecovery(context.Background(), func(context.Context) error { + calls++ + return nil + }) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if calls != 0 { + t.Fatalf("recovery calls = %d, want 0", calls) + } +} + +func TestRunTokenRecoverySerializesConcurrentRecovery(t *testing.T) { + var lc LineClient + var calls int32 + started := make(chan struct{}) + release := make(chan struct{}) + + recover := func(context.Context) error { + if atomic.AddInt32(&calls, 1) == 1 { + close(started) + <-release + } + return nil + } + + var wg sync.WaitGroup + errs := make(chan error, 4) + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + errs <- lc.runTokenRecovery(context.Background(), recover) + }() + } + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("first recovery did not start") + } + + close(release) + wg.Wait() + close(errs) + + for err := range errs { + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("recovery calls = %d, want 1", got) + } +} diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 2221840..d9ada74 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -30,6 +30,10 @@ type LineClient struct { sentReqSeqs map[int]time.Time lastReqSeq int + tokenMu sync.RWMutex + recoverMu sync.Mutex + recoverTime time.Time + // cacheMu protects peerKeys, blockedUsers, contactCache, mediaFlowCache, // noE2EEGroups, groupMemberCache, and generatedGroupNameCache. // Hold it only around map accesses; never across network calls. @@ -67,6 +71,37 @@ type cachedContact struct { } const defaultMediaFlowTTL = 6 * time.Hour +const recentTokenRecoveryWindow = 10 * time.Second + +func (lc *LineClient) getAccessToken() string { + lc.tokenMu.RLock() + defer lc.tokenMu.RUnlock() + return lc.AccessToken +} + +func (lc *LineClient) getTokens() (string, string) { + lc.tokenMu.RLock() + defer lc.tokenMu.RUnlock() + return lc.AccessToken, lc.RefreshToken +} + +func (lc *LineClient) setTokens(accessToken, refreshToken string) (string, string) { + lc.tokenMu.Lock() + defer lc.tokenMu.Unlock() + lc.AccessToken = accessToken + if refreshToken != "" { + lc.RefreshToken = refreshToken + } + return lc.AccessToken, lc.RefreshToken +} + +func (lc *LineClient) hasAccessToken() bool { + return lc.getAccessToken() != "" +} + +func (lc *LineClient) newClient() *line.Client { + return line.NewClient(lc.getAccessToken()) +} func (lc *LineClient) avatarFromPicturePath(picturePath string) *bridgev2.Avatar { if picturePath == "" { @@ -98,8 +133,9 @@ func (lc *LineClient) shouldUseE2EEMediaFlow(chatMid string, contentType int) bo } lc.cacheMu.Unlock() - client := line.NewClient(lc.AccessToken) - resp, err := client.DetermineMediaMessageFlow(chatMid) + _, resp, err := callLineResult(lc, context.Background(), func(client *line.Client) (*line.MediaMessageFlowResponse, error) { + return client.DetermineMediaMessageFlow(chatMid) + }) if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Str("chat_mid", chatMid). Msg("Failed to determine media flow, defaulting to E2EE upload") @@ -141,28 +177,26 @@ var _ bridgev2.BackfillingNetworkAPI = (*LineClient)(nil) var _ bridgev2.ReactionHandlingNetworkAPI = (*LineClient)(nil) func (lc *LineClient) refreshAndSave(ctx context.Context) error { - if lc.RefreshToken == "" { + accessToken, refreshToken := lc.getTokens() + if refreshToken == "" { return fmt.Errorf("no refresh token available") } - client := line.NewClient(lc.AccessToken) - res, err := client.RefreshAccessToken(lc.RefreshToken) + client := line.NewClient(accessToken) + res, err := client.RefreshAccessToken(refreshToken) if err != nil { return fmt.Errorf("failed to refresh token: %w", err) } - lc.AccessToken = res.AccessToken - if res.RefreshToken != "" { - lc.RefreshToken = res.RefreshToken - } + accessToken, refreshToken = lc.setTokens(res.AccessToken, res.RefreshToken) // Rotating the main access token invalidates any OBS token derived from it, // so drop the cached one — the next OBS call will mint a fresh one. line.InvalidateOBSTokenCache() meta := lc.UserLogin.Metadata.(*UserLoginMetadata) - meta.AccessToken = lc.AccessToken - meta.RefreshToken = lc.RefreshToken + meta.AccessToken = accessToken + meta.RefreshToken = refreshToken err = lc.UserLogin.Save(ctx) if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to save refreshed tokens to DB") @@ -174,23 +208,39 @@ func (lc *LineClient) refreshAndSave(ctx context.Context) error { } func (lc *LineClient) isRefreshRequired(err error) bool { - return strings.Contains(err.Error(), "\"code\":119") || strings.Contains(err.Error(), "Access token refresh required") + return line.IsRefreshRequired(err) } func (lc *LineClient) isLoggedOut(err error) bool { - msg := err.Error() - return strings.Contains(msg, "V3_TOKEN_CLIENT_LOGGED_OUT") + return line.IsLoggedOut(err) } // recoverToken attempts to restore a valid session by refreshing, then re-logging in. // Returns nil on success. On failure the caller should send StateBadCredentials. func (lc *LineClient) recoverToken(ctx context.Context) error { - if err := lc.refreshAndSave(ctx); err == nil { - lc.UserLogin.Bridge.Log.Info().Msg("Token recovered via refresh") + return lc.runTokenRecovery(ctx, func(ctx context.Context) error { + if err := lc.refreshAndSave(ctx); err == nil { + lc.UserLogin.Bridge.Log.Info().Msg("Token recovered via refresh") + return nil + } + lc.UserLogin.Bridge.Log.Info().Msg("Refresh failed, attempting re-login with stored credentials...") + return lc.tryLogin(ctx) + }) +} + +func (lc *LineClient) runTokenRecovery(ctx context.Context, recover func(context.Context) error) error { + lc.recoverMu.Lock() + defer lc.recoverMu.Unlock() + + if !lc.recoverTime.IsZero() && time.Since(lc.recoverTime) < recentTokenRecoveryWindow { return nil } - lc.UserLogin.Bridge.Log.Info().Msg("Refresh failed, attempting re-login with stored credentials...") - return lc.tryLogin(ctx) + + if err := recover(ctx); err != nil { + return err + } + lc.recoverTime = time.Now() + return nil } func (lc *LineClient) Connect(ctx context.Context) { @@ -219,7 +269,7 @@ func (lc *LineClient) Connect(ctx context.Context) { lc.Mid = meta.Mid } } - if lc.AccessToken == "" { + if !lc.hasAccessToken() { if err := lc.tryLogin(ctx); err != nil { lc.UserLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateBadCredentials, @@ -242,7 +292,7 @@ func (lc *LineClient) Connect(ctx context.Context) { return } - lc.UserLogin.Bridge.Log.Info().Int("token_len", len(lc.AccessToken)).Msg("LINE client connected; notifying bridge") + lc.UserLogin.Bridge.Log.Info().Int("token_len", len(lc.getAccessToken())).Msg("LINE client connected; notifying bridge") lc.UserLogin.BridgeState.Send(status.BridgeState{ StateEvent: status.StateConnected, }) @@ -262,7 +312,7 @@ func (lc *LineClient) Connect(ctx context.Context) { } // Storage key is optional for runtime decrypt/encrypt; try it for file support - client := line.NewClient(lc.AccessToken) + client := lc.newClient() ei3, err := client.GetEncryptedIdentityV3() if err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to fetch EncryptedIdentityV3") @@ -358,15 +408,17 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { res = waitRes client = waitClient } - lc.AccessToken = client.AccessToken + accessToken := client.AccessToken + refreshToken := "" if res.TokenV3IssueResult != nil { if res.TokenV3IssueResult.AccessToken != "" { - lc.AccessToken = res.TokenV3IssueResult.AccessToken + accessToken = res.TokenV3IssueResult.AccessToken } if res.TokenV3IssueResult.RefreshToken != "" { - lc.RefreshToken = res.TokenV3IssueResult.RefreshToken + refreshToken = res.TokenV3IssueResult.RefreshToken } } + accessToken, refreshToken = lc.setTokens(accessToken, refreshToken) // Re-login replaces the main access token, which invalidates any cached // OBS token derived from the previous one. @@ -381,8 +433,8 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { // Save the new tokens and updated certificate to metadata if meta, ok := lc.UserLogin.Metadata.(*UserLoginMetadata); ok { - meta.AccessToken = lc.AccessToken - meta.RefreshToken = lc.RefreshToken + meta.AccessToken = accessToken + meta.RefreshToken = refreshToken if res.Certificate != "" { meta.Certificate = res.Certificate } @@ -396,7 +448,7 @@ func (lc *LineClient) tryLogin(ctx context.Context) error { } func (lc *LineClient) ensureValidToken(ctx context.Context) error { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() _, err := client.GetProfile() if err == nil { return nil @@ -428,7 +480,7 @@ func (lc *LineClient) Disconnect() { lc.wg.Wait() } -func (lc *LineClient) IsLoggedIn() bool { return lc.AccessToken != "" } +func (lc *LineClient) IsLoggedIn() bool { return lc.hasAccessToken() } func (lc *LineClient) GetUserID() networkid.UserID { return makeUserID(lc.Mid) diff --git a/pkg/connector/creategroup.go b/pkg/connector/creategroup.go index eb6ddbc..4e44ccb 100644 --- a/pkg/connector/creategroup.go +++ b/pkg/connector/creategroup.go @@ -26,7 +26,7 @@ func (lc *LineClient) CreateGroup(ctx context.Context, params *bridgev2.GroupCre name = params.Name.Name } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() var chat *line.Chat var err error chatType := 1 // ROOM: members join automatically. @@ -34,7 +34,7 @@ func (lc *LineClient) CreateGroup(ctx context.Context, params *bridgev2.GroupCre chat, err = client.CreateChat(participantMids, lineName, chatType) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chat, err = client.CreateChat(participantMids, lineName, chatType) } } @@ -150,7 +150,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb } members = otherMembers - client := line.NewClient(lc.AccessToken) + client := lc.newClient() // Fetch current E2EE public keys for all other members as a batch. If the batch // call fails (e.g. server 500 for a specific member), fall back to fetching @@ -163,7 +163,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() pubKeys, err = client.GetLastE2EEPublicKeys(pubKeysReq) } } @@ -184,7 +184,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb } if lc.isRefreshRequired(nErr) || lc.isLoggedOut(nErr) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() res, nErr = client.NegotiateE2EEPublicKey(mid) } } @@ -259,7 +259,7 @@ func (lc *LineClient) registerGroupKey(ctx context.Context, chatMid string, memb if err := client.RegisterE2EEGroupKey(1, chatMid, apiMembers, keyIds, encryptedKeys); err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() err = client.RegisterE2EEGroupKey(1, chatMid, apiMembers, keyIds, encryptedKeys) } } diff --git a/pkg/connector/e2ee_keys.go b/pkg/connector/e2ee_keys.go index 5f453b5..77fc43a 100644 --- a/pkg/connector/e2ee_keys.go +++ b/pkg/connector/e2ee_keys.go @@ -21,7 +21,7 @@ func (lc *LineClient) fetchAndUnwrapGroupKey(ctx context.Context, chatMid string return fmt.Errorf("E2EE manager not initialized") } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() fetch := func() (*line.E2EEGroupSharedKey, error) { if groupKeyID > 0 { return client.GetE2EEGroupSharedKey(chatMid, groupKeyID) @@ -44,7 +44,7 @@ func (lc *LineClient) fetchAndUnwrapGroupKey(ctx context.Context, chatMid string // Token recovery for other error types if err != nil && !line.IsNoUsableE2EEGroupKey(err) && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() sharedKey, err = fetch() } else { return fmt.Errorf("failed to recover token before fetching group key: %w", errRecover) @@ -107,7 +107,7 @@ func (lc *LineClient) ensurePeerKey(_ context.Context, mid string) (int, string, return cached.raw, cached.pub, nil } } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() res, err := client.NegotiateE2EEPublicKey(mid) if err != nil { // Cache negative result so we don't keep hitting the API @@ -165,12 +165,12 @@ func (lc *LineClient) clearGroupNoE2EE(chatMid string) { // Invitees are included because group key registration must happen before they accept, // otherwise the key won't be available when they start sending messages. func (lc *LineClient) getChatMemberMIDs(ctx context.Context, chatMid string) ([]string, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() chats, err := client.GetChats([]string{chatMid}, true, true) if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chats, err = client.GetChats([]string{chatMid}, true, true) } } @@ -319,7 +319,7 @@ func (lc *LineClient) ensurePeerKeyByID(_ context.Context, mid string, keyID int return cached.raw, cached.pub, nil } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() // keyVersion 1 res, err := client.GetE2EEPublicKey(mid, 1, keyID) if err != nil { diff --git a/pkg/connector/handle_message.go b/pkg/connector/handle_message.go index 41ef9e5..4813320 100644 --- a/pkg/connector/handle_message.go +++ b/pkg/connector/handle_message.go @@ -30,7 +30,7 @@ func (lc *LineClient) newMessageHandler() *handlers.Handler { RecoverToken: lc.recoverToken, IsRefreshRequired: lc.isRefreshRequired, IsLoggedOut: lc.isLoggedOut, - NewClient: func() *line.Client { return line.NewClient(lc.AccessToken) }, + NewClient: func() *line.Client { return lc.newClient() }, DecryptMedia: lc.decryptImageData, } } diff --git a/pkg/connector/handlers/audio.go b/pkg/connector/handlers/audio.go index 7fe1822..02ae5af 100644 --- a/pkg/connector/handlers/audio.go +++ b/pkg/connector/handlers/audio.go @@ -51,7 +51,6 @@ func (h *Handler) ConvertAudio(ctx context.Context, portal *bridgev2.Portal, int client = newClient audioData, err = client.DownloadOBSWithSIDOptions(ctx, oid, data.ID, sid, downloadOptions) } - _ = client if err != nil { h.Log.Warn(). diff --git a/pkg/connector/handlers/file.go b/pkg/connector/handlers/file.go index 2788c53..9763476 100644 --- a/pkg/connector/handlers/file.go +++ b/pkg/connector/handlers/file.go @@ -37,6 +37,12 @@ func (h *Handler) ConvertFile(ctx context.Context, portal *bridgev2.Portal, inte } downloadOptions := lineOBSDownloadOptions(data.ContentMetadata, isPlainMedia) fileData, err := client.DownloadOBSWithSIDOptions(ctx, oid, data.ID, sid, downloadOptions) + + if newClient, ok := h.tryRecoverClient(ctx, err); ok { + client = newClient + fileData, err = client.DownloadOBSWithSIDOptions(ctx, oid, data.ID, sid, downloadOptions) + } + if err != nil { h.Log.Warn(). Err(err). diff --git a/pkg/connector/handlers/handler.go b/pkg/connector/handlers/handler.go index 12b1a66..88e78a0 100644 --- a/pkg/connector/handlers/handler.go +++ b/pkg/connector/handlers/handler.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "strings" "github.com/rs/zerolog" @@ -33,7 +32,7 @@ func (h *Handler) tryRecoverClient(ctx context.Context, err error) (*line.Client if err == nil { return nil, false } - if !strings.Contains(err.Error(), "401") && !h.IsRefreshRequired(err) && !h.IsLoggedOut(err) { + if !line.IsUnauthorizedStatus(err) && !h.IsRefreshRequired(err) && !h.IsLoggedOut(err) { return nil, false } if errRecover := h.RecoverToken(ctx); errRecover != nil { diff --git a/pkg/connector/handlers/video.go b/pkg/connector/handlers/video.go index 8dbf05c..c91cd3c 100644 --- a/pkg/connector/handlers/video.go +++ b/pkg/connector/handlers/video.go @@ -53,7 +53,6 @@ func (h *Handler) ConvertVideo(ctx context.Context, portal *bridgev2.Portal, int client = newClient videoData, err = client.DownloadOBSWithSIDOptions(ctx, oid, data.ID, sid, downloadOptions) } - _ = client if err != nil { h.Log.Warn(). diff --git a/pkg/connector/reaction.go b/pkg/connector/reaction.go index b571861..fba1a3b 100644 --- a/pkg/connector/reaction.go +++ b/pkg/connector/reaction.go @@ -365,9 +365,11 @@ func (lc *LineClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.Ma return nil, err } - client := line.NewClient(lc.AccessToken) reqSeq := lc.nextReqSeq() - if err = client.React(int64(reqSeq), targetID, ref.reactionType()); err != nil { + _, err = lc.callLine(ctx, func(client *line.Client) error { + return client.React(int64(reqSeq), targetID, ref.reactionType()) + }) + if err != nil { if line.IsInvalidPaidReactionType(err) { return nil, unsupportedMatrixReactionError(key) } @@ -391,9 +393,10 @@ func (lc *LineClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridg if err != nil { return err } - client := line.NewClient(lc.AccessToken) reqSeq := lc.nextReqSeq() - err = client.CancelReaction(int64(reqSeq), targetID) + _, err = lc.callLine(ctx, func(client *line.Client) error { + return client.CancelReaction(int64(reqSeq), targetID) + }) if line.IsNotAMemberError(err) { return reactionNotAMemberError() } diff --git a/pkg/connector/send_message.go b/pkg/connector/send_message.go index ec42b8e..b586790 100644 --- a/pkg/connector/send_message.go +++ b/pkg/connector/send_message.go @@ -31,7 +31,26 @@ type mentionEntry struct { var mentionLinkRegex = regexp.MustCompile(`]*href="https://matrix\.to/#/([^"]+)"[^>]*>([^<]+)`) func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (*bridgev2.MatrixMessageResponse, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() + callLineErr := func(call func(*line.Client) error) error { + var err error + client, err = lc.callLineUsing(ctx, client, call) + return err + } + callLineString := func(call func(*line.Client) (string, error)) (string, error) { + var err error + var res string + client, res, err = callLineResultUsing(lc, ctx, client, call) + return res, err + } + sendLineMessage := func(reqSeq int, lineMsg *line.Message) (*line.Message, error) { + var err error + var sentMsg *line.Message + client, sentMsg, err = callLineResultUsing(lc, ctx, client, func(client *line.Client) (*line.Message, error) { + return client.SendMessage(int64(reqSeq), lineMsg) + }) + return sentMsg, err + } portalMid := string(msg.Portal.ID) fromMid := lc.midOrFallback() @@ -205,7 +224,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat return nil, fmt.Errorf("failed to encrypt image data: %w", err) } - oid, err := client.UploadOBS(uploadData) + oid, err := callLineString(func(client *line.Client) (string, error) { + return client.UploadOBS(uploadData) + }) if err != nil { return nil, fmt.Errorf("failed to upload image to OBS: %w", err) } @@ -217,7 +238,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to encrypt thumbnail, continuing without it") } else { previewOID := fmt.Sprintf("%s__ud-preview", oid) - if err := client.UploadOBSWithOID(thumbToUpload, previewOID); err != nil { + if err := callLineErr(func(client *line.Client) error { + return client.UploadOBSWithOID(thumbToUpload, previewOID) + }); err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to upload preview, continuing without it") } else { mediaThumbInfo := map[string]interface{}{ @@ -303,7 +326,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat return nil, fmt.Errorf("failed to encrypt file data: %w", err) } - oid, err := client.UploadOBSWithSID(uploadData, "emf") + oid, err := callLineString(func(client *line.Client) (string, error) { + return client.UploadOBSWithSID(uploadData, "emf") + }) if err != nil { return nil, fmt.Errorf("failed to upload file to OBS: %w", err) } @@ -384,7 +409,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat return nil, fmt.Errorf("failed to encrypt video data: %w", err) } - oid, err := client.UploadOBSWithSID(uploadData, "emv") + oid, err := callLineString(func(client *line.Client) (string, error) { + return client.UploadOBSWithSID(uploadData, "emv") + }) if err != nil { return nil, fmt.Errorf("failed to upload video to OBS: %w", err) } @@ -392,7 +419,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat chunkHashes := generateChunkHashes(uploadData[:len(uploadData)-32]) if len(chunkHashes) > 0 { hashOID := fmt.Sprintf("%s__ud-hash", oid) - if err := client.UploadOBSWithOIDAndSID(chunkHashes, hashOID, "emv"); err != nil { + if err := callLineErr(func(client *line.Client) error { + return client.UploadOBSWithOIDAndSID(chunkHashes, hashOID, "emv") + }); err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to upload video hash, continuing without it") } else { lc.UserLogin.Bridge.Log.Info(). @@ -418,7 +447,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to encrypt video thumbnail, continuing without it") } else { previewOID := fmt.Sprintf("%s__ud-preview", oid) - if err := client.UploadOBSWithOIDAndSID(thumbToUpload, previewOID, "emv"); err != nil { + if err := callLineErr(func(client *line.Client) error { + return client.UploadOBSWithOIDAndSID(thumbToUpload, previewOID, "emv") + }); err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to upload video preview, continuing without it") } else { mediaThumbInfo := map[string]interface{}{ @@ -480,7 +511,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat return nil, fmt.Errorf("failed to encrypt audio data: %w", err) } - oid, err := client.UploadOBSWithSID(uploadData, "ema") + oid, err := callLineString(func(client *line.Client) (string, error) { + return client.UploadOBSWithSID(uploadData, "ema") + }) if err != nil { return nil, fmt.Errorf("failed to upload audio to OBS: %w", err) } @@ -622,7 +655,7 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat reqSeq := int(now % 1_000_000_000) lc.trackReqSeq(reqSeq) - sentMsg, err := client.SendMessage(int64(reqSeq), lineMsg) + sentMsg, err := sendLineMessage(reqSeq, lineMsg) // LINE rejects some file types from the Chrome Extension client. // Retry by wrapping the file in a ZIP archive (matching Chrome Extension behavior). @@ -647,7 +680,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat if encErr != nil { return nil, fmt.Errorf("failed to encrypt zipped file: %w", encErr) } - oid, uploadErr := client.UploadOBSWithSID(uploadData, "emf") + oid, uploadErr := callLineString(func(client *line.Client) (string, error) { + return client.UploadOBSWithSID(uploadData, "emf") + }) if uploadErr != nil { return nil, fmt.Errorf("failed to upload zipped file to OBS: %w", uploadErr) } @@ -682,7 +717,7 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat retryReqSeq := int(time.Now().UnixMilli() % 1_000_000_000) lc.trackReqSeq(retryReqSeq) - sentMsg, err = client.SendMessage(int64(retryReqSeq), lineMsg) + sentMsg, err = sendLineMessage(retryReqSeq, lineMsg) } // If LINE rejects with "group key is not registered" (code 99), @@ -709,7 +744,7 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat if err == nil { retryReqSeq := int(time.Now().UnixMilli() % 1_000_000_000) lc.trackReqSeq(retryReqSeq) - sentMsg, err = client.SendMessage(int64(retryReqSeq), lineMsg) + sentMsg, err = sendLineMessage(retryReqSeq, lineMsg) } } else { lc.UserLogin.Bridge.Log.Warn().Str("chat_mid", portalMid). @@ -733,7 +768,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat obsType = "file" } - if err := client.UploadOBSPlain(plainMediaData, sentMsg.ID, obsType); err != nil { + if err := callLineErr(func(client *line.Client) error { + return client.UploadOBSPlain(plainMediaData, sentMsg.ID, obsType) + }); err != nil { return nil, fmt.Errorf("failed to upload plain media to OBS: %w", err) } lc.UserLogin.Bridge.Log.Info(). @@ -744,7 +781,9 @@ func (lc *LineClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Mat if plainThumbData != nil { previewID := fmt.Sprintf("%s__ud-preview", sentMsg.ID) - if err := client.UploadOBSPlain(plainThumbData, previewID, obsType); err != nil { + if err := callLineErr(func(client *line.Client) error { + return client.UploadOBSPlain(plainThumbData, previewID, obsType) + }); err != nil { lc.UserLogin.Bridge.Log.Warn().Err(err).Msg("Failed to upload plain media thumbnail, continuing without it") } } @@ -793,12 +832,12 @@ func contentTypeForMsgType(msgType event.MessageType) int { } func (lc *LineClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { - client := line.NewClient(lc.AccessToken) - reqSeq := int(time.Now().UnixMilli() % 1_000_000_000) lc.trackReqSeq(reqSeq) - err := client.UnsendMessage(int64(reqSeq), string(msg.TargetMessage.ID)) + _, err := lc.callLine(ctx, func(client *line.Client) error { + return client.UnsendMessage(int64(reqSeq), string(msg.TargetMessage.ID)) + }) if err != nil && strings.Contains(err.Error(), "message too old") { return bridgev2.WrapErrorInStatus(fmt.Errorf("message too old to unsend on LINE (24h limit)")). WithStatus(event.MessageStatusFail). @@ -810,8 +849,6 @@ func (lc *LineClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridge } func (lc *LineClient) HandleMatrixLeaveRoom(ctx context.Context, portal *bridgev2.Portal) error { - client := line.NewClient(lc.AccessToken) - reqSeq := int(time.Now().UnixMilli() % 1_000_000_000) // A still-pending invite (the user declined the Request without accepting it) must be @@ -819,12 +856,18 @@ func (lc *LineClient) HandleMatrixLeaveRoom(ctx context.Context, portal *bridgev // clears MessageRequest once the invite is accepted, so an accepted chat falls through to // the normal SendChatRemoved leave path below. if portal.MessageRequest { - return client.RejectChatInvitation(int64(reqSeq), string(portal.ID)) + _, err := lc.callLine(ctx, func(client *line.Client) error { + return client.RejectChatInvitation(int64(reqSeq), string(portal.ID)) + }) + return err } lc.trackReqSeq(reqSeq) - return client.SendChatRemoved(int64(reqSeq), string(portal.ID), "0", 0) + _, err := lc.callLine(ctx, func(client *line.Client) error { + return client.SendChatRemoved(int64(reqSeq), string(portal.ID), "0", 0) + }) + return err } // Compile-time assertion that LineClient handles Beeper message-request acceptance. @@ -833,9 +876,11 @@ var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*LineClient)(nil) // HandleMatrixAcceptMessageRequest is called when the user accepts a Request in Beeper (a // pending LINE group invitation). It accepts the invitation on the LINE side, joining the chat. func (lc *LineClient) HandleMatrixAcceptMessageRequest(ctx context.Context, msg *bridgev2.MatrixAcceptMessageRequest) error { - client := line.NewClient(lc.AccessToken) reqSeq := int64(time.Now().UnixMilli() % 1_000_000_000) - return client.AcceptChatInvitation(reqSeq, string(msg.Portal.ID)) + _, err := lc.callLine(ctx, func(client *line.Client) error { + return client.AcceptChatInvitation(reqSeq, string(msg.Portal.ID)) + }) + return err } func (lc *LineClient) buildMentionMetadata(ctx context.Context, body, formattedBody string, mentions *event.Mentions) map[string]string { diff --git a/pkg/connector/sync.go b/pkg/connector/sync.go index 9768215..603d7f3 100644 --- a/pkg/connector/sync.go +++ b/pkg/connector/sync.go @@ -31,11 +31,11 @@ const ( ) func (lc *LineClient) refreshBlockedContacts(ctx context.Context) ([]string, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() blockedMIDs, err := client.GetBlockedContactIds() if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() blockedMIDs, err = client.GetBlockedContactIds() } } @@ -114,7 +114,7 @@ func (lc *LineClient) saveBlockedContactsSnapshot(ctx context.Context) { func (lc *LineClient) syncDMChats(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newClient() opts := line.MessageBoxesOptions{ ActiveOnly: true, MessageBoxCountLimit: 100, @@ -125,7 +125,7 @@ func (lc *LineClient) syncDMChats(ctx context.Context) { res, err := client.GetMessageBoxes(opts) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() res, err = client.GetMessageBoxes(opts) } } @@ -265,11 +265,11 @@ func (lc *LineClient) FetchMessages(ctx context.Context, params bridgev2.FetchMe limit = 50 } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() msgs, err := client.GetRecentMessagesV2(chatMID, limit) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() msgs, err = client.GetRecentMessagesV2(chatMID, limit) } } @@ -331,7 +331,7 @@ func (lc *LineClient) FetchMessages(ctx context.Context, params bridgev2.FetchMe func (lc *LineClient) prefetchMessages(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newClient() opts := line.MessageBoxesOptions{ ActiveOnly: true, MessageBoxCountLimit: 100, @@ -342,7 +342,7 @@ func (lc *LineClient) prefetchMessages(ctx context.Context) { res, err := client.GetMessageBoxes(opts) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() res, err = client.GetMessageBoxes(opts) } } @@ -404,11 +404,11 @@ func (lc *LineClient) prefetchMessages(ctx context.Context) { // (live) message path. Used by prefetchMessages on startup. func (lc *LineClient) backfillRecentMessages(ctx context.Context, chatMID string, limit int) { start := time.Now() - client := line.NewClient(lc.AccessToken) + client := lc.newClient() msgs, err := client.GetRecentMessagesV2(chatMID, limit) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() msgs, err = client.GetRecentMessagesV2(chatMID, limit) } } @@ -450,11 +450,11 @@ func (lc *LineClient) backfillRecentMessages(ctx context.Context, chatMID string func (lc *LineClient) syncChats(ctx context.Context) { defer lc.wg.Done() - client := line.NewClient(lc.AccessToken) + client := lc.newClient() midsResp, err := client.GetAllChatMids(true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() midsResp, err = client.GetAllChatMids(true, true) } } @@ -478,7 +478,7 @@ func (lc *LineClient) syncChats(ctx context.Context) { chatsResp, err := client.GetChats(batch, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chatsResp, err = client.GetChats(batch, true, true) } } @@ -742,11 +742,11 @@ func (lc *LineClient) cacheGroupMembersFromRecentMessages(ctx context.Context, c if len(lc.getCachedGroupMembers(chatMid)) > 1 { return } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() msgs, err := client.GetRecentMessagesV2(chatMid, 50) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() msgs, err = client.GetRecentMessagesV2(chatMid, 50) } } @@ -841,13 +841,13 @@ func (lc *LineClient) pollLoop(ctx context.Context) { defer lc.wg.Done() var localRev int64 = 0 - client := line.NewClient(lc.AccessToken) + client := lc.newClient() lc.UserLogin.Bridge.Log.Info().Msg("Starting LINE SSE loop...") rev, err := client.GetLastOpRevision() if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() rev, err = client.GetLastOpRevision() } else { lc.UserLogin.Bridge.Log.Warn().Err(errRecover).Msg("Failed to recover token for getLastOpRevision") @@ -936,7 +936,7 @@ func (lc *LineClient) pollLoop(ctx context.Context) { }) return } - client = line.NewClient(lc.AccessToken) + client = lc.newClient() } } time.Sleep(3 * time.Second) @@ -1414,11 +1414,11 @@ func (lc *LineClient) clearReactionDedupEntries(msgID string, removeOnly bool) { func (lc *LineClient) syncSingleChat(ctx context.Context, op line.Operation) { chatMid := op.Param1 - client := line.NewClient(lc.AccessToken) + client := lc.newClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } @@ -1475,11 +1475,11 @@ func (lc *LineClient) syncSingleChat(ctx context.Context, op line.Operation) { // checkChatMembership calls GetAllChatMids to verify whether the bridge user // is a member or invitee of the given chat. func (lc *LineClient) checkChatMembership(ctx context.Context, chatMid string) (isMember, isInvitee bool) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() midsResp, err := client.GetAllChatMids(true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() midsResp, err = client.GetAllChatMids(true, true) } } @@ -1550,11 +1550,11 @@ func (lc *LineClient) handleMemberJoin(chatMid, joinerMid string) { } func (lc *LineClient) handleInvite(ctx context.Context, chatMid string, opType OperationType) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } @@ -1598,11 +1598,11 @@ func (lc *LineClient) handleInvite(ctx context.Context, chatMid string, opType O } func (lc *LineClient) handleInviteForSelf(ctx context.Context, chatMid string) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() chatsResp, err := client.GetChats([]string{chatMid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() chatsResp, err = client.GetChats([]string{chatMid}, true, true) } } diff --git a/pkg/connector/userinfo.go b/pkg/connector/userinfo.go index a6d4388..ffb7589 100644 --- a/pkg/connector/userinfo.go +++ b/pkg/connector/userinfo.go @@ -37,8 +37,10 @@ func (lc *LineClient) HandleMatrixReadReceipt(ctx context.Context, read *bridgev return nil } - client := line.NewClient(lc.AccessToken) - return client.SendChatChecked(string(read.Portal.ID), targetID) + _, err := lc.callLine(ctx, func(client *line.Client) error { + return client.SendChatChecked(string(read.Portal.ID), targetID) + }) + return err } func (lc *LineClient) GetCapabilities(ctx context.Context, portal *bridgev2.Portal) *event.RoomFeatures { @@ -140,11 +142,11 @@ func (lc *LineClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) mid := string(portal.ID) lowerMid := strings.ToLower(mid) if strings.HasPrefix(lowerMid, "c") || strings.HasPrefix(lowerMid, "r") { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() res, err := client.GetChats([]string{mid}, true, true) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() res, err = client.GetChats([]string{mid}, true, true) } } @@ -207,11 +209,11 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { // Use GetProfile for our own user data if mid == lc.Mid || mid == string(lc.UserLogin.ID) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() profile, err := client.GetProfile() if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() profile, err = client.GetProfile() } } @@ -223,11 +225,11 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { return line.Contact{Mid: mid, DisplayName: mid} } - client := line.NewClient(lc.AccessToken) + client := lc.newClient() res, err := client.GetContactsV2([]string{mid}) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() res, err = client.GetContactsV2([]string{mid}) } } @@ -243,7 +245,7 @@ func (lc *LineClient) getContact(ctx context.Context, mid string) line.Contact { buddy, err := client.GetBuddyProfile(mid) if err != nil && (lc.isRefreshRequired(err) || lc.isLoggedOut(err)) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() buddy, err = client.GetBuddyProfile(mid) } } @@ -314,7 +316,7 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev // Try by LINE user ID first lowerQuery := strings.ToLower(strings.TrimSpace(query)) if lowerQuery != "" { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() contact, err := client.FindContactByUserid(lowerQuery) if err == nil && contact != nil && contact.Mid != "" { if r := lc.midToResolveIdentifier(ctx, contact.Mid); r != nil { @@ -327,12 +329,12 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev } // Search contacts by display name - client := line.NewClient(lc.AccessToken) + client := lc.newClient() allMids, err := client.GetAllContactIds() if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() allMids, err = client.GetAllContactIds() } } @@ -373,12 +375,12 @@ func (lc *LineClient) SearchUsers(ctx context.Context, query string) ([]*bridgev var _ bridgev2.UserSearchingNetworkAPI = (*LineClient)(nil) func (lc *LineClient) GetContactList(ctx context.Context) ([]*bridgev2.ResolveIdentifierResponse, error) { - client := line.NewClient(lc.AccessToken) + client := lc.newClient() allMids, err := client.GetAllContactIds() if err != nil { if lc.isRefreshRequired(err) || lc.isLoggedOut(err) { if errRecover := lc.recoverToken(ctx); errRecover == nil { - client = line.NewClient(lc.AccessToken) + client = lc.newClient() allMids, err = client.GetAllContactIds() } } diff --git a/pkg/line/errors.go b/pkg/line/errors.go index 411c33b..438bef2 100644 --- a/pkg/line/errors.go +++ b/pkg/line/errors.go @@ -13,6 +13,45 @@ var ( ErrGroupKeyNotFound = errors.New("group key not found") ) +// IsRefreshRequired returns true when LINE reports that the access token must +// be refreshed before the request can be retried. +func IsRefreshRequired(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "\"code\":119") || + strings.Contains(msg, "access token refresh required") +} + +func IsLoggedOut(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "V3_TOKEN_CLIENT_LOGGED_OUT") +} + +func IsUnauthorizedStatus(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "api error 401") || + strings.Contains(msg, "api error 403") || + strings.Contains(msg, "http 401") || + strings.Contains(msg, "http 403") || + strings.Contains(msg, "sse error: 401") || + strings.Contains(msg, "sse error: 403") || + strings.Contains(msg, "obs upload failed (401)") || + strings.Contains(msg, "obs upload failed (403)") || + strings.Contains(msg, "obs download failed (401)") || + strings.Contains(msg, "obs download failed (403)") +} + +func IsAuthError(err error) bool { + return IsRefreshRequired(err) || IsLoggedOut(err) || IsUnauthorizedStatus(err) +} + // IsNoUsableE2EEPublicKey returns true when a peer has Letter Sealing disabled // (negotiateE2EEPublicKey returns empty allowedTypes / specVersion -1, or no key data). func IsNoUsableE2EEPublicKey(err error) bool { diff --git a/pkg/line/errors_test.go b/pkg/line/errors_test.go new file mode 100644 index 0000000..e907186 --- /dev/null +++ b/pkg/line/errors_test.go @@ -0,0 +1,93 @@ +package line + +import ( + "errors" + "testing" +) + +func TestIsRefreshRequired(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "talk exception code 119", + err: errors.New(`API error 400: {"code":10051,"message":"RESPONSE_ERROR","data":{"name":"TalkException","code":119,"reason":"Access token refresh required"}}`), + want: true, + }, + { + name: "refresh text", + err: errors.New("Access token refresh required"), + want: true, + }, + { + name: "refresh text lower case", + err: errors.New("access token refresh required"), + want: true, + }, + { + name: "other talk exception", + err: errors.New(`API error 400: {"code":10051,"data":{"name":"TalkException","code":10,"reason":"not a member"}}`), + want: false, + }, + { + name: "nil", + err: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsRefreshRequired(tt.err); got != tt.want { + t.Fatalf("IsRefreshRequired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsLoggedOut(t *testing.T) { + if !IsLoggedOut(errors.New("V3_TOKEN_CLIENT_LOGGED_OUT")) { + t.Fatal("expected logged-out error to be detected") + } + if IsLoggedOut(errors.New("Access token refresh required")) { + t.Fatal("refresh-required error should not be classified as logged out") + } +} + +func TestIsUnauthorizedStatus(t *testing.T) { + for _, err := range []error{ + errors.New("API error 401: unauthorized"), + errors.New("API error 403: forbidden"), + errors.New("HTTP 401: unauthorized"), + errors.New("HTTP 403: forbidden"), + errors.New("SSE error: 401"), + errors.New("SSE error: 403"), + errors.New("OBS upload failed (401): unauthorized"), + errors.New("OBS upload failed (403): forbidden"), + errors.New("OBS download failed (401): unauthorized"), + errors.New("OBS download failed (403): forbidden"), + errors.New("api error 401: unauthorized"), + errors.New("http 403: forbidden"), + errors.New("sse ERROR: 401"), + errors.New("obs DOWNLOAD failed (403): forbidden"), + } { + if !IsUnauthorizedStatus(err) { + t.Fatalf("expected %q to be unauthorized", err) + } + } + + if IsUnauthorizedStatus(errors.New("HTTP 404: not found")) { + t.Fatal("404 should not be classified as unauthorized") + } + for _, err := range []error{ + errors.New("request failed: dial tcp: i/o timeout"), + errors.New("OBS upload request failed: connection reset by peer"), + errors.New("OBS download request failed: context deadline exceeded"), + } { + if IsUnauthorizedStatus(err) { + t.Fatalf("network error %q should not be classified as unauthorized", err) + } + } +}