diff --git a/apps/tradinggoose/.env.example b/apps/tradinggoose/.env.example index 6e0e3fa37..c12be0f53 100644 --- a/apps/tradinggoose/.env.example +++ b/apps/tradinggoose/.env.example @@ -12,8 +12,7 @@ # - Configure system-managed integration OAuth apps in Admin Integrations instead: # google-drive, google-email, github-repo, microsoft, slack, reddit, etc. # - Platform-managed vars like NODE_ENV, NEXT_RUNTIME, and VERCEL are omitted on purpose. -# - Internal React Email preview vars like EMAILS_DIR_ABSOLUTE_PATH and -# PREVIEW_SERVER_LOCATION are also omitted on purpose. +# - Internal React Email preview vars are omitted on purpose. # # Notes: # - Boolean-like flags should use true/false unless noted otherwise. @@ -232,14 +231,6 @@ KB_CONFIG_DELAY_BETWEEN_DOCUMENTS="50" # NEXT_PUBLIC_POSTHOG_KEY="" NEXT_PUBLIC_POSTHOG_DISABLED="1" -############################################################################### -# Email preview / local tooling -############################################################################### - -# Optional: only used by React Email preview and similar local tooling when -# NEXT_PUBLIC_APP_URL is not available in that process. -# EMAILS_PREVIEW_BASE_URL="http://localhost:3000" - ############################################################################### # Deployment and local infrastructure helpers ############################################################################### @@ -255,6 +246,3 @@ NEXT_PUBLIC_POSTHOG_DISABLED="1" # Optional: tell the app it is running inside a Docker image build. DOCKER_BUILD="false" - -# Optional: landing-page preview/dev helper used by getBaseUrl(). -# NEXT_PUBLIC_IS_PREVIEW_DEVELOPMENT="false" diff --git a/apps/tradinggoose/app/(auth)/auth-locale-redirects.test.tsx b/apps/tradinggoose/app/(auth)/auth-locale-redirects.test.tsx index 3d7499e60..e2e29dacf 100644 --- a/apps/tradinggoose/app/(auth)/auth-locale-redirects.test.tsx +++ b/apps/tradinggoose/app/(auth)/auth-locale-redirects.test.tsx @@ -15,6 +15,7 @@ import { VerifyContent } from './verify/verify-content' const mockPush = vi.hoisted(() => vi.fn()) const mockSignUpEmail = vi.hoisted(() => vi.fn()) const mockSignInEmail = vi.hoisted(() => vi.fn()) +const mockSignOut = vi.hoisted(() => vi.fn()) const mockSendVerificationOtp = vi.hoisted(() => vi.fn()) const mockRefetchSession = vi.hoisted(() => vi.fn()) const mockUseVerification = vi.hoisted(() => vi.fn()) @@ -45,6 +46,7 @@ vi.mock('@/i18n/navigation', () => ({ useRouter: () => ({ push: mockPush, }), + usePathname: () => '/login', })) vi.mock('@/lib/auth-client', () => ({ @@ -55,6 +57,7 @@ vi.mock('@/lib/auth-client', () => ({ signIn: { email: mockSignInEmail, }, + signOut: mockSignOut, emailOtp: { sendVerificationOtp: mockSendVerificationOtp, }, @@ -107,7 +110,11 @@ vi.mock('@/components/ui/label', () => ({ ...props }: React.LabelHTMLAttributes & { children?: React.ReactNode - }) => , + }) => ( + + ), })) vi.mock('@/components/ui/dialog', () => ({ @@ -149,6 +156,7 @@ describe('auth locale redirects', () => { mockPush.mockReset() mockSignUpEmail.mockReset() mockSignInEmail.mockReset() + mockSignOut.mockReset() mockSendVerificationOtp.mockReset() mockRefetchSession.mockReset() mockUseVerification.mockReset() @@ -162,6 +170,7 @@ describe('auth locale redirects', () => { root.unmount() }) container.remove() + vi.useRealTimers() vi.restoreAllMocks() reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = false }) @@ -204,6 +213,18 @@ describe('auth locale redirects', () => { }) } + async function renderLogin(locale: 'en' | 'es' | 'zh' = 'en') { + await renderWithLocale( + locale, + + ) + } + it.each(['es', 'zh'] as const)( 'pushes the canonical verify path after signup for %s', async (locale) => { @@ -226,6 +247,8 @@ describe('auth locale redirects', () => { await setInputValue('#password', 'Password1!') await submitRenderedForm() + expect(mockRefetchSession).not.toHaveBeenCalled() + expect(mockFetch).not.toHaveBeenCalled() expect(mockPush).toHaveBeenCalledWith('/verify?fromSignup=true') } ) @@ -235,15 +258,7 @@ describe('auth locale redirects', () => { async (locale) => { mockSignInEmail.mockRejectedValue({ code: 'EMAIL_NOT_VERIFIED' }) - await renderWithLocale( - locale, - - ) + await renderLogin(locale) await setInputValue('#email', 'ada@example.com') await setInputValue('#password', 'Password1!') @@ -253,6 +268,75 @@ describe('auth locale redirects', () => { } ) + it('runs reauth cleanup on arrival and waits before direct login starts', async () => { + vi.useFakeTimers() + testState.searchParams = new URLSearchParams('reauth=1&callbackUrl=%2Fworkspace') + const cleanupSignalRef: { current: AbortSignal | null } = { current: null } + mockSignOut.mockImplementation((options) => { + cleanupSignalRef.current = options?.fetchOptions?.signal ?? null + return new Promise(() => {}) + }) + mockSignInEmail.mockResolvedValue({}) + + await renderLogin() + + expect(mockSignOut).toHaveBeenCalledTimes(1) + expect(container.querySelector('form')).toBeInstanceOf(HTMLFormElement) + + await setInputValue('#email', 'ada@example.com') + await setInputValue('#password', 'Password1!') + + await submitRenderedForm() + + expect(mockSignInEmail).not.toHaveBeenCalled() + + await act(async () => { + await vi.runOnlyPendingTimersAsync() + await Promise.resolve() + }) + + expect(cleanupSignalRef.current?.aborted).toBe(true) + expect(mockSignInEmail).toHaveBeenCalledTimes(1) + }) + + it.each([ + 'FAILED_TO_CREATE_SESSION', + 'UNABLE_TO_CREATE_SESSION', + 'FAILED_TO_GET_SESSION', + 'SESSION_EXPIRED', + ])('runs reauth cleanup when direct login returns %s', async (errorCode) => { + mockSignInEmail.mockResolvedValue({ error: { code: errorCode } }) + mockSignOut.mockReturnValue(new Promise(() => {})) + + await renderLogin() + + await setInputValue('#email', 'ada@example.com') + await setInputValue('#password', 'Password1!') + await submitRenderedForm() + + expect(mockSignInEmail).toHaveBeenCalledTimes(1) + expect(mockSignOut).toHaveBeenCalledTimes(1) + expect(container.textContent).toContain(getPublicCopy('en').auth.login.errors.unableToSignInNow) + }) + + it('keeps invalid credential failures on the login form', async () => { + mockSignInEmail.mockResolvedValue({ + error: { code: 'INVALID_CREDENTIALS', status: 401 }, + }) + + await renderLogin() + + await setInputValue('#email', 'ada@example.com') + await setInputValue('#password', 'wrong-password') + await submitRenderedForm() + + expect(mockSignOut).not.toHaveBeenCalled() + expect(container.querySelector('form')).toBeInstanceOf(HTMLFormElement) + expect(container.textContent).toContain( + getPublicCopy('en').auth.login.errors.invalidCredentials + ) + }) + it('pushes the canonical signup path from the verify screen back action', async () => { mockUseVerification.mockReturnValue({ otp: '', @@ -272,11 +356,7 @@ describe('auth locale redirects', () => { await renderWithLocale( 'en', - + ) const backButton = Array.from(container.querySelectorAll('button')).find( diff --git a/apps/tradinggoose/app/(auth)/auth-provider-callbacks.test.tsx b/apps/tradinggoose/app/(auth)/auth-provider-callbacks.test.tsx new file mode 100644 index 000000000..4fa491029 --- /dev/null +++ b/apps/tradinggoose/app/(auth)/auth-provider-callbacks.test.tsx @@ -0,0 +1,198 @@ +/** + * @vitest-environment jsdom + */ + +import type React from 'react' +import { act } from 'react' +import { NextIntlClientProvider } from 'next-intl' +import { createRoot, type Root } from 'react-dom/client' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { getAuthErrorCallbackPath } from '@/lib/auth/auth-error-copy' +import { getPublicCopy } from '@/i18n/public-copy' +import { SocialLoginButtons } from './components/social-login-buttons' +import SSOForm from './sso/sso-form' + +const mockSocialSignIn = vi.hoisted(() => vi.fn()) +const mockSsoSignIn = vi.hoisted(() => vi.fn()) +const testState = vi.hoisted(() => ({ + searchParams: new URLSearchParams(), +})) + +vi.mock('next/navigation', () => ({ + useSearchParams: () => ({ + get: (key: string) => testState.searchParams.get(key), + }), +})) + +vi.mock('@/i18n/navigation', () => ({ + Link: ({ + children, + href, + ...props + }: React.AnchorHTMLAttributes & { + children?: React.ReactNode + href: string + }) => ( + + {children} + + ), +})) + +vi.mock('@/lib/auth-client', () => ({ + client: { + signIn: { + social: mockSocialSignIn, + sso: mockSsoSignIn, + }, + }, +})) + +vi.mock('@/components/ui/button', () => ({ + Button: ({ + children, + ...props + }: React.ButtonHTMLAttributes & { + children?: React.ReactNode + }) => , +})) + +vi.mock('@/components/ui/input', () => ({ + Input: (props: React.InputHTMLAttributes) => , +})) + +vi.mock('@/components/ui/label', () => ({ + Label: ({ + children, + ...props + }: React.LabelHTMLAttributes & { + children?: React.ReactNode + }) => ( + + ), +})) + +vi.mock('@/components/ui/alert', () => ({ + Alert: ({ children }: { children?: React.ReactNode }) =>
{children}
, + AlertDescription: ({ children }: { children?: React.ReactNode }) =>
{children}
, +})) + +vi.mock('@/components/icons/icons', () => ({ + GithubIcon: () => , + GoogleIcon: () => , +})) + +vi.mock('@/app/(auth)/components/auth-page-header', () => ({ + AuthPageHeader: () => null, +})) + +vi.mock('@/app/(auth)/components/auth-waitlist-note', () => ({ + AuthWaitlistNote: () => null, +})) + +vi.mock('@/app/fonts/inter', () => ({ + inter: { className: '' }, +})) + +describe('auth provider callback routing', () => { + let container: HTMLDivElement + let root: Root + const reactActEnvironment = globalThis as typeof globalThis & { + IS_REACT_ACT_ENVIRONMENT?: boolean + } + + beforeEach(() => { + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true + testState.searchParams = new URLSearchParams() + mockSocialSignIn.mockResolvedValue({}) + mockSsoSignIn.mockResolvedValue({}) + container = document.createElement('div') + document.body.appendChild(container) + root = createRoot(container) + }) + + afterEach(() => { + act(() => { + root.unmount() + }) + container.remove() + vi.clearAllMocks() + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = false + }) + + it.each([ + ['Google', 'google'], + ['GitHub', 'github'], + ])('routes %s OAuth callback failures to the auth error page', async (buttonText, provider) => { + await act(async () => { + root.render( + + + + ) + }) + + const button = Array.from(container.querySelectorAll('button')).find((candidate) => + candidate.textContent?.includes(buttonText) + ) + if (!(button instanceof HTMLButtonElement)) { + throw new Error(`Expected ${buttonText} button to render`) + } + + await act(async () => { + button.click() + }) + + expect(mockSocialSignIn).toHaveBeenCalledWith({ + provider, + callbackURL: '/workspace', + errorCallbackURL: getAuthErrorCallbackPath('/workspace'), + }) + }) + + it('routes SSO callback failures to the auth error page', async () => { + testState.searchParams = new URLSearchParams({ callbackUrl: '/workspace' }) + + await act(async () => { + root.render( + + + + ) + }) + + const input = container.querySelector('input[name="email"]') + if (!(input instanceof HTMLInputElement)) { + throw new Error('Expected SSO email input to render') + } + + const valueSetter = Object.getOwnPropertyDescriptor(HTMLInputElement.prototype, 'value')?.set + valueSetter?.call(input, 'user@example.com') + + await act(async () => { + input.dispatchEvent(new Event('input', { bubbles: true })) + }) + + const form = container.querySelector('form') + if (!(form instanceof HTMLFormElement)) { + throw new Error('Expected SSO form to render') + } + + await act(async () => { + form.dispatchEvent(new Event('submit', { bubbles: true, cancelable: true })) + }) + + expect(mockSsoSignIn).toHaveBeenCalledWith({ + email: 'user@example.com', + callbackURL: '/workspace', + errorCallbackURL: getAuthErrorCallbackPath('/workspace'), + }) + }) +}) diff --git a/apps/tradinggoose/app/(auth)/components/social-login-buttons.tsx b/apps/tradinggoose/app/(auth)/components/social-login-buttons.tsx index 2bcfeddbf..c7cc886f4 100644 --- a/apps/tradinggoose/app/(auth)/components/social-login-buttons.tsx +++ b/apps/tradinggoose/app/(auth)/components/social-login-buttons.tsx @@ -1,6 +1,7 @@ 'use client' import { type ReactNode, useEffect, useState } from 'react' +import { useMessages } from 'next-intl' import { GithubIcon, GoogleIcon } from '@/components/icons/icons' import { Alert, AlertDescription } from '@/components/ui/alert' import { Button } from '@/components/ui/button' @@ -8,7 +9,6 @@ import { useAuthRedirectUrls } from '@/lib/auth/redirect-urls' import { client } from '@/lib/auth-client' import { createLogger } from '@/lib/logs/console/logger' import { inter } from '@/app/fonts/inter' -import { useMessages } from 'next-intl' import { formatTemplate } from '@/i18n/utils' const logger = createLogger('SocialLoginButtons') @@ -18,6 +18,7 @@ interface SocialLoginButtonsProps { googleAvailable: boolean callbackURL?: string isProduction: boolean + beforeSignIn?: () => Promise children?: ReactNode } @@ -26,6 +27,7 @@ export function SocialLoginButtons({ googleAvailable, callbackURL, isProduction: _isProduction, + beforeSignIn, children, }: SocialLoginButtonsProps) { const [isGithubLoading, setIsGithubLoading] = useState(false) @@ -36,6 +38,7 @@ export function SocialLoginButtons({ const copy = useMessages() const socialCopy = copy.auth.social const resolvedCallbackURL = authRedirectUrls.providerCallbackPath(callbackURL) + const errorCallbackURL = authRedirectUrls.providerErrorPath(resolvedCallbackURL) useEffect(() => { setMounted(true) @@ -82,9 +85,11 @@ export function SocialLoginButtons({ setIsGithubLoading(true) setErrorMessage('') try { + await beforeSignIn?.() const result = await client.signIn.social({ provider: 'github', callbackURL: resolvedCallbackURL, + errorCallbackURL, }) if (result?.error) { @@ -105,9 +110,11 @@ export function SocialLoginButtons({ setIsGoogleLoading(true) setErrorMessage('') try { + await beforeSignIn?.() const result = await client.signIn.social({ provider: 'google', callbackURL: resolvedCallbackURL, + errorCallbackURL, }) if (result?.error) { diff --git a/apps/tradinggoose/app/(auth)/components/sso-login-button.tsx b/apps/tradinggoose/app/(auth)/components/sso-login-button.tsx index dd2e79e67..e1fccf425 100644 --- a/apps/tradinggoose/app/(auth)/components/sso-login-button.tsx +++ b/apps/tradinggoose/app/(auth)/components/sso-login-button.tsx @@ -1,15 +1,16 @@ 'use client' +import { useMessages } from 'next-intl' import { Button } from '@/components/ui/button' import { getEnv, isTruthy } from '@/lib/env' import { cn } from '@/lib/utils' -import { useMessages } from 'next-intl' import { useRouter } from '@/i18n/navigation' import { normalizeCallbackUrl } from '@/i18n/utils' interface SSOLoginButtonProps { callbackURL?: string className?: string + beforeSignIn?: () => Promise // Visual variant for button styling and placement contexts // - 'primary' matches the main auth action button style // - 'outline' matches social provider buttons @@ -19,6 +20,7 @@ interface SSOLoginButtonProps { export function SSOLoginButton({ callbackURL, className, + beforeSignIn, variant = 'outline', }: SSOLoginButtonProps) { const router = useRouter() @@ -31,7 +33,8 @@ export function SSOLoginButton({ const resolvedCallbackURL = callbackURL ? normalizeCallbackUrl(callbackURL) : undefined - const handleSSOClick = () => { + const handleSSOClick = async () => { + await beforeSignIn?.() const ssoUrl = `/sso${ resolvedCallbackURL ? `?callbackUrl=${encodeURIComponent(resolvedCallbackURL)}` : '' }` diff --git a/apps/tradinggoose/app/(auth)/login/login-form.tsx b/apps/tradinggoose/app/(auth)/login/login-form.tsx index 226ff2c56..f207976cf 100644 --- a/apps/tradinggoose/app/(auth)/login/login-form.tsx +++ b/apps/tradinggoose/app/(auth)/login/login-form.tsx @@ -1,8 +1,9 @@ 'use client' -import { useEffect, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { Eye, EyeOff } from 'lucide-react' import { useSearchParams } from 'next/navigation' +import { useMessages } from 'next-intl' import { Button } from '@/components/ui/button' import { Dialog, @@ -13,8 +14,7 @@ import { } from '@/components/ui/dialog' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' -import { normalizeAuthErrorCode } from '@/lib/auth/auth-error-copy' -import { handleAuthError } from '@/lib/auth/auth-error-handler' +import { isSessionRecoveryAuthError, normalizeAuthErrorCode } from '@/lib/auth/auth-error-copy' import { useAuthRedirectUrls } from '@/lib/auth/redirect-urls' import { client } from '@/lib/auth-client' import { quickValidateEmail } from '@/lib/email/validation' @@ -27,11 +27,12 @@ import { AuthWaitlistNote } from '@/app/(auth)/components/auth-waitlist-note' import { SocialLoginButtons } from '@/app/(auth)/components/social-login-buttons' import { SSOLoginButton } from '@/app/(auth)/components/sso-login-button' import { inter } from '@/app/fonts/inter' -import { useMessages } from 'next-intl' import { Link, useRouter } from '@/i18n/navigation' import { normalizeCallbackUrl } from '@/i18n/utils' +import { clearUserData } from '@/stores' const logger = createLogger('LoginForm') +const REAUTH_CLEANUP_TIMEOUT_MS = 4000 const validateEmailField = ( emailValue: string, @@ -122,6 +123,50 @@ export default function LoginPage({ const [email, setEmail] = useState('') const [emailErrors, setEmailErrors] = useState([]) const [showEmailValidationError, setShowEmailValidationError] = useState(false) + const isReauth = searchParams.get('reauth') === '1' + const shouldRunReauthCleanupRef = useRef(isReauth) + const reauthCleanupPromiseRef = useRef | null>(null) + + const runReauthCleanup = useCallback(() => { + if (reauthCleanupPromiseRef.current) { + return reauthCleanupPromiseRef.current + } + + const abortController = new AbortController() + let timeoutId: ReturnType | null = null + const signOutPromise = client + .signOut({ fetchOptions: { signal: abortController.signal } }) + .then(() => undefined) + .catch((error) => { + if (!abortController.signal.aborted) { + logger.warn('Reauth sign-out failed', { error }) + } + }) + const timeoutPromise = new Promise((resolve) => { + timeoutId = setTimeout(() => { + abortController.abort() + resolve() + }, REAUTH_CLEANUP_TIMEOUT_MS) + }) + + const cleanupPromise = Promise.race([signOutPromise, timeoutPromise]).finally(async () => { + if (timeoutId) { + clearTimeout(timeoutId) + } + await clearUserData() + shouldRunReauthCleanupRef.current = false + reauthCleanupPromiseRef.current = null + }) + + reauthCleanupPromiseRef.current = cleanupPromise + return cleanupPromise + }, []) + + const prepareAuthStart = useCallback(async () => { + if (shouldRunReauthCleanupRef.current || reauthCleanupPromiseRef.current) { + await runReauthCleanup() + } + }, [runReauthCleanup]) useEffect(() => { if (searchParams) { @@ -144,6 +189,13 @@ export default function LoginPage({ } }, [searchParams]) + useEffect(() => { + shouldRunReauthCleanupRef.current = isReauth + if (isReauth) { + void runReauthCleanup() + } + }, [isReauth, runReauthCleanup]) + useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.key === 'Enter' && forgotPasswordOpen) { @@ -181,6 +233,17 @@ export default function LoginPage({ setShowValidationError(false) } + const isSessionRecoveryError = (error: any) => + [ + error?.code, + error?.error, + error?.message, + error?.response?.data?.error, + error?.response?.data?.message, + ].some((value) => { + return isSessionRecoveryAuthError(value) + }) + const resolveLoginErrorMessage = (error: any) => { const rawMessage = error?.message ?? @@ -227,9 +290,6 @@ export default function LoginPage({ if (authErrorCode === 'EMAIL_PASSWORD_DISABLED') { return loginCopy.errors.emailPasswordDisabled } - if (authErrorCode === 'FAILED_TO_CREATE_SESSION') { - return loginCopy.errors.failedToCreateSession - } if (authErrorCode === 'TOO_MANY_ATTEMPTS' || searchable.includes('too many attempts')) { return loginCopy.errors.tooManyAttempts } @@ -278,6 +338,9 @@ export default function LoginPage({ } try { + await prepareAuthStart() + + let requiresReauthCleanup = false const result = await client.signIn.email( { email, @@ -287,24 +350,18 @@ export default function LoginPage({ { onError: (ctx) => { console.error('Login error:', ctx.error) + if (isSessionRecoveryError(ctx.error)) { + requiresReauthCleanup = true + return + } + const errorMessage: string[] = [] const resolvedMessage = resolveLoginErrorMessage(ctx.error) - const status = - (ctx.error as any)?.status ?? - (ctx.error as any)?.statusCode ?? - (ctx.error as any)?.response?.status - if (resolvedMessage === null) { return } - // If the backend rejected the request due to an invalid/expired auth state, hard reset auth. - if (status === 401) { - handleAuthError('login-unauthorized').catch(() => {}) - errorMessage.push(loginCopy.errors.sessionExpired) - } - if (resolvedMessage) { errorMessage.push(resolvedMessage) } @@ -320,6 +377,15 @@ export default function LoginPage({ ) if (!result || result.error) { + if (requiresReauthCleanup || isSessionRecoveryError(result?.error)) { + shouldRunReauthCleanupRef.current = true + void runReauthCleanup() + setPasswordErrors([loginCopy.errors.unableToSignInNow]) + setShowValidationError(true) + setIsLoading(false) + return + } + const message = resolveLoginErrorMessage(result?.error) ?? loginCopy.errors.unableToSignInNow @@ -336,6 +402,13 @@ export default function LoginPage({ router.push('/verify') return } + if (isSessionRecoveryError(err)) { + shouldRunReauthCleanupRef.current = true + void runReauthCleanup() + setPasswordErrors([loginCopy.errors.unableToSignInNow]) + setShowValidationError(true) + return + } console.error('Uncaught login error:', err) } finally { @@ -551,8 +624,15 @@ export default function LoginPage({ githubAvailable={githubAvailable} isProduction={isProduction} callbackURL={callbackUrl} + beforeSignIn={prepareAuthStart} > - {ssoEnabled && } + {ssoEnabled && ( + + )} )} diff --git a/apps/tradinggoose/app/(auth)/signup/signup-form.tsx b/apps/tradinggoose/app/(auth)/signup/signup-form.tsx index 9e6f5325d..5e5461118 100644 --- a/apps/tradinggoose/app/(auth)/signup/signup-form.tsx +++ b/apps/tradinggoose/app/(auth)/signup/signup-form.tsx @@ -7,7 +7,7 @@ import { useLocale } from 'next-intl' import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' -import { client, useSession } from '@/lib/auth-client' +import { client } from '@/lib/auth-client' import { quickValidateEmail } from '@/lib/email/validation' import { getEnv, isTruthy } from '@/lib/env' import { createLogger } from '@/lib/logs/console/logger' @@ -90,7 +90,6 @@ function SignupFormContent({ const signupCopy = copy.auth.signup const defaultCallbackPath = '/workspace' const searchParams = useSearchParams() - const { refetch: refetchSession } = useSession() const [isLoading, setIsLoading] = useState(false) const [, setMounted] = useState(false) const [showPassword, setShowPassword] = useState(false) @@ -348,21 +347,6 @@ function SignupFormContent({ return } - try { - await refetchSession() - const localeResponse = await fetch('/api/users/me/settings', { - method: 'PATCH', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ preferredLocale: locale }), - }) - if (!localeResponse.ok) { - throw new Error('Failed to persist preferred locale after signup') - } - logger.info('Session refreshed after successful signup') - } catch (sessionError) { - logger.error('Failed to refresh session or persist locale after signup:', sessionError) - } - if (typeof window !== 'undefined') { sessionStorage.setItem('verificationEmail', emailValue) if (isInviteFlow && redirectUrl) { diff --git a/apps/tradinggoose/app/(auth)/sso/sso-form.tsx b/apps/tradinggoose/app/(auth)/sso/sso-form.tsx index 399d41d26..590098f1e 100644 --- a/apps/tradinggoose/app/(auth)/sso/sso-form.tsx +++ b/apps/tradinggoose/app/(auth)/sso/sso-form.tsx @@ -2,22 +2,22 @@ import { useEffect, useState } from 'react' import { useSearchParams } from 'next/navigation' +import { useMessages } from 'next-intl' import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' +import { resolveSsoAuthErrorMessage } from '@/lib/auth/auth-error-copy' import { useAuthRedirectUrls } from '@/lib/auth/redirect-urls' import { client } from '@/lib/auth-client' import { quickValidateEmail } from '@/lib/email/validation' -import { normalizeAuthErrorCode } from '@/lib/auth/auth-error-copy' import { createLogger } from '@/lib/logs/console/logger' import { getAuthRegistrationHref, type RegistrationMode } from '@/lib/registration/shared' import { cn } from '@/lib/utils' -import { Link } from '@/i18n/navigation' -import { useMessages } from 'next-intl' -import { normalizeCallbackUrl } from '@/i18n/utils' import { AuthPageHeader } from '@/app/(auth)/components/auth-page-header' import { AuthWaitlistNote } from '@/app/(auth)/components/auth-waitlist-note' import { inter } from '@/app/fonts/inter' +import { Link } from '@/i18n/navigation' +import { normalizeCallbackUrl } from '@/i18n/utils' const logger = createLogger('SSOForm') @@ -81,24 +81,8 @@ export default function SSOForm({ registrationMode }: { registrationMode: Regist if (emailParam) { setEmail(emailParam) } - - const error = searchParams.get('error') - if (error) { - const errorMessages: Record = { - account_not_found: ssoCopy.errors.accountNotFound, - sso_failed: ssoCopy.errors.ssoFailed, - invalid_provider: ssoCopy.errors.providerNotConfigured, - } - setEmailErrors([errorMessages[error] || ssoCopy.errors.ssoFailed]) - setShowEmailValidationError(true) - } } - }, [ - searchParams, - ssoCopy.errors.accountNotFound, - ssoCopy.errors.providerNotConfigured, - ssoCopy.errors.ssoFailed, - ]) + }, [searchParams]) const handleEmailChange = (e: React.ChangeEvent) => { const newEmail = e.target.value @@ -136,32 +120,14 @@ export default function SSOForm({ registrationMode }: { registrationMode: Regist await client.signIn.sso({ email: emailValue, callbackURL: authRedirectUrls.providerCallbackPath(callbackUrl), - errorCallbackURL: authRedirectUrls.providerErrorPath( - `/sso?error=sso_failed&callbackUrl=${callbackUrlParam}` - ), + errorCallbackURL: authRedirectUrls.providerErrorPath(callbackUrl), }) } catch (err) { logger.error('SSO sign-in failed', { error: err, email: emailValue }) - const authErrorCode = err instanceof Error ? normalizeAuthErrorCode(err.message) : null - - let errorMessage = ssoCopy.errors.failed - if (err instanceof Error) { - if (authErrorCode === 'NO_PROVIDER_FOUND' || authErrorCode === 'INVALID_PROVIDER') { - errorMessage = ssoCopy.errors.providerNotConfigured - } else if (authErrorCode === 'INVALID_EMAIL_DOMAIN') { - errorMessage = ssoCopy.errors.invalidEmailDomain - } else if (authErrorCode === 'NETWORK_ERROR') { - errorMessage = ssoCopy.errors.network - } else if (authErrorCode === 'RATE_LIMIT' || authErrorCode === 'TOO_MANY_REQUESTS') { - errorMessage = ssoCopy.errors.rateLimit - } else if (authErrorCode === 'SSO_DISABLED') { - errorMessage = ssoCopy.errors.ssoDisabled - } else { - errorMessage = ssoCopy.errors.failed - } - } + const errorMessage = + err instanceof Error ? resolveSsoAuthErrorMessage(copy, err.message) : null - setEmailErrors([errorMessage]) + setEmailErrors([errorMessage ?? ssoCopy.errors.failed]) setShowEmailValidationError(true) setIsLoading(false) } @@ -252,14 +218,14 @@ export default function SSOForm({ registrationMode }: { registrationMode: Regist )}
{commonCopy.termsLeadSigningIn}{' '} {commonCopy.termsOfService} {' '} @@ -268,7 +234,7 @@ export default function SSOForm({ registrationMode }: { registrationMode: Regist href='/privacy' target='_blank' rel='noopener noreferrer' - className='hover:text-primary underline underline-offset-4' + className='underline underline-offset-4 hover:text-primary' > {commonCopy.privacyPolicy} diff --git a/apps/tradinggoose/app/(auth)/verify/use-verification.test.tsx b/apps/tradinggoose/app/(auth)/verify/use-verification.test.tsx index a14cc5538..ff5e5037d 100644 --- a/apps/tradinggoose/app/(auth)/verify/use-verification.test.tsx +++ b/apps/tradinggoose/app/(auth)/verify/use-verification.test.tsx @@ -13,6 +13,7 @@ const mockPush = vi.hoisted(() => vi.fn()) const mockEmailOtpSignIn = vi.hoisted(() => vi.fn()) const mockSendVerificationOtp = vi.hoisted(() => vi.fn()) const mockRefetchSession = vi.hoisted(() => vi.fn()) +const mockFetch = vi.hoisted(() => vi.fn()) const testState = vi.hoisted(() => ({ searchParams: new URLSearchParams(), })) @@ -122,8 +123,11 @@ describe('useVerification', () => { mockEmailOtpSignIn.mockReset() mockSendVerificationOtp.mockReset() mockRefetchSession.mockReset() + mockFetch.mockReset() + mockFetch.mockResolvedValue(new Response('{}', { status: 200 })) testState.searchParams = new URLSearchParams() window.history.replaceState({}, '', '/') + global.fetch = mockFetch as typeof fetch }) afterEach(() => { @@ -172,6 +176,15 @@ describe('useVerification', () => { email: 'ada@example.com', otp: '123456', }) + expect(mockFetch).toHaveBeenCalledWith( + '/api/users/me/settings', + expect.objectContaining({ + method: 'PATCH', + credentials: 'include', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ preferredLocale: 'zh' }), + }) + ) expect(mockPush).toHaveBeenCalledWith('/workspace') }) @@ -198,4 +211,34 @@ describe('useVerification', () => { expect(mockPush).toHaveBeenCalledWith('/workspace/ws-1/dashboard') }) + + it('persists locale before redirecting when verification is disabled', async () => { + mockRefetchSession.mockResolvedValue(undefined) + + await act(async () => { + root.render( + + {}} + /> + + ) + }) + + await act(async () => {}) + + expect(mockRefetchSession).toHaveBeenCalled() + expect(mockFetch).toHaveBeenCalledWith( + '/api/users/me/settings', + expect.objectContaining({ + method: 'PATCH', + credentials: 'include', + body: JSON.stringify({ preferredLocale: 'es' }), + }) + ) + expect(mockPush).toHaveBeenCalledWith('/workspace') + }) }) diff --git a/apps/tradinggoose/app/(auth)/verify/use-verification.ts b/apps/tradinggoose/app/(auth)/verify/use-verification.ts index bfed0ca5a..769bdb22a 100644 --- a/apps/tradinggoose/app/(auth)/verify/use-verification.ts +++ b/apps/tradinggoose/app/(auth)/verify/use-verification.ts @@ -2,11 +2,12 @@ import { useEffect, useState } from 'react' import { useSearchParams } from 'next/navigation' +import { useLocale } from 'next-intl' import { useRouter } from '@/i18n/navigation' import { normalizeAuthErrorCode } from '@/lib/auth/auth-error-copy' import { client, useSession } from '@/lib/auth-client' import { createLogger } from '@/lib/logs/console/logger' -import { normalizeCallbackUrl } from '@/i18n/utils' +import { normalizeCallbackUrl, type LocaleCode } from '@/i18n/utils' import type { Messages } from 'next-intl' const logger = createLogger('useVerification') @@ -95,6 +96,7 @@ export function useVerification({ copy, }: UseVerificationParams): UseVerificationReturn { const router = useRouter() + const locale = useLocale() as LocaleCode const searchParams = useSearchParams() const { refetch: refetchSession } = useSession() const [otp, setOtp] = useState('') @@ -162,6 +164,23 @@ export function useVerification({ const isOtpComplete = otp.length === 6 + async function persistPreferredLocale() { + try { + const response = await fetch('/api/users/me/settings', { + method: 'PATCH', + credentials: 'include', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ preferredLocale: locale }), + }) + + if (!response.ok) { + throw new Error('Failed to persist preferred locale after verification') + } + } catch (error) { + logger.warn('Failed to persist preferred locale after verification', { error, locale }) + } + } + async function verifyCode() { if (!isOtpComplete || !email) return @@ -185,6 +204,8 @@ export function useVerification({ logger.warn('Failed to refetch session after verification', e) } + await persistPreferredLocale() + if (typeof window !== 'undefined') { sessionStorage.removeItem('verificationEmail') @@ -278,6 +299,8 @@ export function useVerification({ logger.warn('Failed to refetch session during verification skip:', error) } + await persistPreferredLocale() + if (isInviteFlow && redirectUrl) { router.push(redirectUrl) } else { diff --git a/apps/tradinggoose/app/(landing)/components/feature/components/workflow-preview/workflow-preview.test.tsx b/apps/tradinggoose/app/(landing)/components/feature/components/workflow-preview/workflow-preview.test.tsx index b2a184a9f..4a1ae7333 100644 --- a/apps/tradinggoose/app/(landing)/components/feature/components/workflow-preview/workflow-preview.test.tsx +++ b/apps/tradinggoose/app/(landing)/components/feature/components/workflow-preview/workflow-preview.test.tsx @@ -6,6 +6,7 @@ import { act } from 'react' import { NextIntlClientProvider } from 'next-intl' import { createRoot, type Root } from 'react-dom/client' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { TooltipProvider } from '@/components/ui/tooltip' import { getPublicCopy } from '@/i18n/public-copy' vi.mock('@xyflow/react', () => ({ @@ -73,9 +74,11 @@ vi.mock('@/components/ui/dropdown-menu', () => ({ vi.mock('@/components/listing-selector/listing/row', () => ({ getListingDisplaySymbol: (listing: { base?: string | null; name?: string | null }) => listing?.base || listing?.name || 'Listing', - ListingDisplayRow: ({ listing }: { listing?: { base?: string | null; name?: string | null } }) => ( - {listing?.base || listing?.name || 'Listing'} - ), + ListingDisplayRow: ({ + listing, + }: { + listing?: { base?: string | null; name?: string | null } + }) => {listing?.base || listing?.name || 'Listing'}, })) vi.mock('@/components/listing-selector/selector/resolve-request', () => ({ requestListingResolution: vi.fn().mockResolvedValue(null), @@ -107,11 +110,8 @@ vi.mock('@/widgets/widgets/components/widget-header-control', () => ({ widgetHeaderMenuTextClassName: '', })) -import { - buildTradingAgentWorkflowDemos, - type WorkflowPreviewDemo, -} from './workflow-preview-demos' import { WorkflowPreview } from './workflow-preview' +import { buildTradingAgentWorkflowDemos, type WorkflowPreviewDemo } from './workflow-preview-demos' describe('WorkflowPreview', () => { let container: HTMLDivElement @@ -145,12 +145,11 @@ describe('WorkflowPreview', () => { await act(async () => { root.render( - - - + + + + + ) }) @@ -247,15 +246,16 @@ describe('WorkflowPreview', () => { await act(async () => { root.render( - - - + + + + + ) }) - expect(container.textContent).toContain(copy.workspace.widgets.workflowEditor.summary.objectItem) + expect(container.textContent).toContain( + copy.workspace.widgets.workflowEditor.summary.objectItem + ) }) }) diff --git a/apps/tradinggoose/app/(landing)/components/nav/nav.test.tsx b/apps/tradinggoose/app/(landing)/components/nav/nav.test.tsx index 81efaf923..26f570df4 100644 --- a/apps/tradinggoose/app/(landing)/components/nav/nav.test.tsx +++ b/apps/tradinggoose/app/(landing)/components/nav/nav.test.tsx @@ -17,7 +17,14 @@ const mockReplace = vi.fn() const mockRefresh = vi.fn() const mockReplaceLocaleDocument = vi.fn() const mockUpdateSetting = vi.fn() -let mockSessionUserId: string | null = null +const mockSetTheme = vi.fn() +let mockSessionUser: { + id: string + email: string + name?: string | null + image?: string | null + updatedAt?: Date +} | null = null let mockPathname = '/' let mockSearchParams = '' const flush = () => new Promise((resolve) => setTimeout(resolve, 0)) @@ -83,16 +90,75 @@ vi.mock('@/lib/branding/branding', () => ({ vi.mock('@/lib/auth-client', () => ({ useSession: () => ({ - data: mockSessionUserId ? { user: { id: mockSessionUserId } } : null, + data: mockSessionUser ? { user: mockSessionUser } : null, isPending: false, error: null, refetch: vi.fn(), }), + signOut: vi.fn(), })) vi.mock('@/stores/settings/general/store', () => ({ - useGeneralStore: (selector: (state: { updateSetting: typeof mockUpdateSetting }) => unknown) => - selector({ updateSetting: mockUpdateSetting }), + useGeneralStore: ( + selector: (state: { + theme: 'system' + setTheme: typeof mockSetTheme + updateSetting: typeof mockUpdateSetting + isLoading: boolean + isThemeLoading: boolean + }) => unknown + ) => + selector({ + theme: 'system', + setTheme: mockSetTheme, + updateSetting: mockUpdateSetting, + isLoading: false, + isThemeLoading: false, + }), +})) + +vi.mock('@/hooks/queries/organization', () => ({ + useOrganizations: () => ({ + data: { + activeOrganization: null, + billingData: { data: { billingEnabled: false } }, + }, + }), + useOrganizationBilling: () => ({ data: null }), +})) + +vi.mock('@/hooks/queries/subscription', () => ({ + useSubscriptionData: () => ({ + data: { billingEnabled: false }, + isLoading: false, + }), +})) + +vi.mock('@/lib/billing/billing-portal', () => ({ + openBillingPortal: vi.fn(), +})) + +vi.mock('@/lib/environment', () => ({ + isHosted: false, +})) + +vi.mock('@/stores', () => ({ + clearUserData: vi.fn(), +})) + +vi.mock('@/global-navbar/settings-modal/components/help/help-modal', () => ({ + HelpModal: () => null, +})) + +vi.mock('@/global-navbar/settings-modal/settings-dialog', () => ({ + SettingsDialog: ({ + open, + section, + }: { + open: boolean + section: string + onOpenChange: (open: boolean) => void + }) => (open ?
{section}
: null), })) describe('landing nav registration mode', () => { @@ -108,7 +174,8 @@ describe('landing nav registration mode', () => { vi.clearAllMocks() vi.mocked(getRegistrationModeForRender).mockReset() mockUpdateSetting.mockResolvedValue(undefined) - mockSessionUserId = null + mockSetTheme.mockResolvedValue(undefined) + mockSessionUser = null mockPathname = '/' mockSearchParams = '' container = document.createElement('div') @@ -129,7 +196,11 @@ describe('landing nav registration mode', () => { vi.mocked(getRegistrationModeForRender).mockResolvedValue('waitlist') await act(async () => { - root.render(await PublicNav()) + root.render( + + {await PublicNav()} + + ) }) expect(getRegistrationModeForRender).toHaveBeenCalledTimes(1) @@ -141,7 +212,11 @@ describe('landing nav registration mode', () => { it('reuses an already resolved registration mode when provided', async () => { await act(async () => { - root.render(await PublicNav({ registrationMode: 'disabled' })) + root.render( + + {await PublicNav({ registrationMode: 'disabled' })} + + ) }) expect(getRegistrationModeForRender).not.toHaveBeenCalled() @@ -165,6 +240,112 @@ describe('landing nav registration mode', () => { expect(container.textContent).toContain(getPublicCopy('en').registration.open.primary) }) + it('routes ordinary login navigation without reauth cleanup', async () => { + await act(async () => { + root.render( + +
- - + + {isAuthenticated ? ( + + ) : null} + ) } diff --git a/apps/tradinggoose/app/(landing)/components/structured-data.tsx b/apps/tradinggoose/app/(landing)/components/structured-data.tsx index e7c39f175..250f10986 100644 --- a/apps/tradinggoose/app/(landing)/components/structured-data.tsx +++ b/apps/tradinggoose/app/(landing)/components/structured-data.tsx @@ -2,13 +2,15 @@ import { getLocale } from 'next-intl/server' import { getPublicBillingCatalog } from '@/lib/billing/catalog' import { buildHostedPricingNarrative } from '@/lib/billing/public-catalog' import { getPublicCopy } from '@/i18n/public-copy' -import { type LocaleCode, localizeSiteUrl, SITE_BASE_URL } from '@/i18n/utils' +import { type LocaleCode, localizeSiteUrl } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' const STRUCTURED_DATA_MODIFIED_AT = '2026-04-04T00:00:00+00:00' -const siteEntityUrl = (id: string) => `${SITE_BASE_URL}/#${id}` -const siteAssetUrl = (pathname: string) => `${SITE_BASE_URL}${pathname}` -function buildStructuredOffers(catalog: Awaited>) { +function buildStructuredOffers( + catalog: Awaited>, + siteEntityUrl: (id: string) => string +) { if (!catalog.billingEnabled) { return [] } @@ -86,6 +88,9 @@ export default async function StructuredData() { const billingCatalog = await getPublicBillingCatalog() const locale = (await getLocale()) as LocaleCode const copy = getPublicCopy(locale) + const siteBaseUrl = getBaseUrl() + const siteEntityUrl = (id: string) => `${siteBaseUrl}/#${id}` + const siteAssetUrl = (pathname: string) => `${siteBaseUrl}${pathname}` const siteHomeUrl = localizeSiteUrl(locale, '/') const pricingNarrative = billingCatalog.billingEnabled ? buildHostedPricingNarrative(billingCatalog) @@ -211,7 +216,7 @@ export default async function StructuredData() { applicationSubCategory: 'Trading Platform', operatingSystem: 'Web, Windows, macOS, Linux', softwareVersion: '2026.04.04', - offers: buildStructuredOffers(billingCatalog), + offers: buildStructuredOffers(billingCatalog, siteEntityUrl), featureList: [ 'Visual workflow canvas for trading strategies', 'Custom indicator editor (PineTS)', diff --git a/apps/tradinggoose/app/[locale]/(auth)/auth-entry-pages.test.tsx b/apps/tradinggoose/app/[locale]/(auth)/auth-entry-pages.test.tsx new file mode 100644 index 000000000..4169a0682 --- /dev/null +++ b/apps/tradinggoose/app/[locale]/(auth)/auth-entry-pages.test.tsx @@ -0,0 +1,177 @@ +import type React from 'react' +import { renderToStaticMarkup } from 'react-dom/server' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { getAuthErrorCallbackPath } from '@/lib/auth/auth-error-copy' + +const mockGetLocale = vi.fn() +const mockGetSession = vi.fn() +const mockRedirect = vi.fn((url: string) => { + throw new Error(`redirect:${url}`) +}) +const mockGetBrandConfig = vi.fn() +const mockGetOAuthProviderStatus = vi.fn() +const mockGetRegistrationModeForRender = vi.fn() + +vi.mock('next-intl/server', () => ({ + getLocale: () => mockGetLocale(), +})) + +vi.mock('@/lib/auth', () => ({ + getSession: (...args: unknown[]) => mockGetSession(...args), +})) + +vi.mock('@/lib/branding/branding', () => ({ + getBrandConfig: () => mockGetBrandConfig(), +})) + +vi.mock('@/i18n/navigation', () => ({ + Link: ({ + children, + href, + ...props + }: React.AnchorHTMLAttributes & { + children?: React.ReactNode + href: string + }) => ( + + {children} + + ), + redirect: ({ href, locale }: { href: string; locale?: string }) => { + const localizedPath = locale && href.startsWith('/') ? `/${locale}${href}` : href + return mockRedirect(localizedPath) + }, +})) + +vi.mock('@/app/(auth)/components/oauth-provider-checker', () => ({ + getOAuthProviderStatus: () => mockGetOAuthProviderStatus(), +})) + +vi.mock('@/lib/registration/service', () => ({ + getRegistrationModeForRender: () => mockGetRegistrationModeForRender(), +})) + +vi.mock('@/app/(auth)/components/auth-page-header', () => ({ + AuthPageHeader: () => null, +})) + +vi.mock('@/components/ui/button', () => ({ + Button: ({ children }: { children?: React.ReactNode }) => , +})) + +vi.mock('@/app/(auth)/login/login-form', () => ({ + default: ({ registrationMode }: { registrationMode: string }) => ( +
+ ), +})) + +vi.mock('@/app/(auth)/signup/signup-form', () => ({ + default: ({ registrationMode }: { registrationMode: string }) => ( +
+ ), +})) + +describe('localized auth entry pages', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + mockGetLocale.mockResolvedValue('es') + mockGetSession.mockResolvedValue(null) + mockGetOAuthProviderStatus.mockResolvedValue({ + githubAvailable: false, + googleAvailable: false, + isProduction: false, + }) + mockGetRegistrationModeForRender.mockResolvedValue('open') + mockGetBrandConfig.mockReturnValue({ supportEmail: 'support@tradinggoose.ai' }) + }) + + it('redirects login to the localized workspace when a session is present', async () => { + mockGetSession.mockResolvedValue({ + user: { + id: 'user-1', + }, + }) + + const LoginPage = (await import('./login/page')).default + + await expect(LoginPage()).rejects.toThrow('redirect:/es/workspace') + expect(mockGetSession).toHaveBeenCalledWith() + expect(mockGetOAuthProviderStatus).not.toHaveBeenCalled() + expect(mockGetRegistrationModeForRender).not.toHaveBeenCalled() + }) + + it('renders login reauth routes without redirecting an existing session first', async () => { + mockGetSession.mockResolvedValue({ + user: { + id: 'user-1', + }, + }) + + const LoginPage = (await import('./login/page')).default + + const result = await LoginPage({ searchParams: Promise.resolve({ reauth: '1' }) }) + const markup = renderToStaticMarkup(result) + + expect(markup).toContain('data-testid="login-form"') + expect(mockGetSession).not.toHaveBeenCalled() + expect(mockRedirect).not.toHaveBeenCalled() + }) + + it('renders login when the session check is empty', async () => { + const LoginPage = (await import('./login/page')).default + + const result = await LoginPage() + const markup = renderToStaticMarkup(result) + + expect(markup).toContain('data-testid="login-form"') + expect(markup).toContain('data-registration-mode="open"') + expect(mockRedirect).not.toHaveBeenCalled() + }) + + it('redirects signup to the localized workspace when a session is present', async () => { + mockGetSession.mockResolvedValue({ + user: { + id: 'user-1', + }, + }) + + const SignupPage = (await import('./signup/page')).default + + await expect(SignupPage({ searchParams: Promise.resolve({}) })).rejects.toThrow( + 'redirect:/es/workspace' + ) + expect(mockGetSession).toHaveBeenCalledWith() + expect(mockGetOAuthProviderStatus).not.toHaveBeenCalled() + expect(mockGetRegistrationModeForRender).not.toHaveBeenCalled() + }) + + it('renders signup when the session check is empty', async () => { + const SignupPage = (await import('./signup/page')).default + + const result = await SignupPage({ searchParams: Promise.resolve({}) }) + const markup = renderToStaticMarkup(result) + + expect(markup).toContain('data-testid="signup-form"') + expect(markup).toContain('data-registration-mode="open"') + expect(mockRedirect).not.toHaveBeenCalled() + }) + + it('keeps provider error callback state in the per-flow error URL', async () => { + const callbackPath = getAuthErrorCallbackPath('/invite/invitation-1?token=workspace-token') + const ErrorPage = (await import('./error/[[...callback]]/page')).default + + const result = await ErrorPage({ + params: Promise.resolve({ + callback: callbackPath?.split('/').filter(Boolean).slice(1), + }), + searchParams: Promise.resolve({ error: 'UNABLE_TO_CREATE_SESSION' }), + }) + const markup = renderToStaticMarkup(result) + + expect(markup).toContain( + '/login?reauth=1&callbackUrl=%2Finvite%2Finvitation-1%3Ftoken%3Dworkspace-token' + ) + expect(markup).not.toContain('document.cookie') + }) +}) diff --git a/apps/tradinggoose/app/[locale]/(auth)/error/page.tsx b/apps/tradinggoose/app/[locale]/(auth)/error/[[...callback]]/page.tsx similarity index 81% rename from apps/tradinggoose/app/[locale]/(auth)/error/page.tsx rename to apps/tradinggoose/app/[locale]/(auth)/error/[[...callback]]/page.tsx index 816bf46d4..ff5476244 100644 --- a/apps/tradinggoose/app/[locale]/(auth)/error/page.tsx +++ b/apps/tradinggoose/app/[locale]/(auth)/error/[[...callback]]/page.tsx @@ -1,33 +1,34 @@ import { getLocale } from 'next-intl/server' import { Button } from '@/components/ui/button' -import { getAuthErrorContent } from '@/lib/auth/auth-error-copy' +import { getAuthErrorContent, normalizeAuthErrorCallbackSegments } from '@/lib/auth/auth-error-copy' import { getBrandConfig } from '@/lib/branding/branding' import { AuthPageHeader } from '@/app/(auth)/components/auth-page-header' import { inter } from '@/app/fonts/inter' import { Link } from '@/i18n/navigation' import { getPublicCopy } from '@/i18n/public-copy' -import { type LocaleCode } from '@/i18n/utils' - -export const dynamic = 'force-dynamic' +import type { LocaleCode } from '@/i18n/utils' function getSingleSearchParam(value: string | string[] | undefined) { return Array.isArray(value) ? value[0] : value } export default async function AuthErrorPage({ + params, searchParams, }: { + params?: Promise<{ callback?: string[] }> searchParams?: Promise<{ error?: string | string[] error_description?: string | string[] }> }) { - const resolvedSearchParams = (await searchParams) ?? {} - const error = getSingleSearchParam(resolvedSearchParams.error) - const errorDescription = getSingleSearchParam(resolvedSearchParams.error_description) + const [resolvedParams, resolvedSearchParams] = await Promise.all([params, searchParams]) + const error = getSingleSearchParam(resolvedSearchParams?.error) + const errorDescription = getSingleSearchParam(resolvedSearchParams?.error_description) + const callbackUrl = normalizeAuthErrorCallbackSegments(resolvedParams?.callback) const locale = (await getLocale()) as LocaleCode const copy = getPublicCopy(locale) - const { code, content } = getAuthErrorContent(copy, error, errorDescription) + const { code, content } = getAuthErrorContent(copy, error, errorDescription, callbackUrl) const brand = getBrandConfig() const supportEmail = brand.supportEmail const errorCopy = copy.auth.error @@ -47,9 +48,7 @@ export default async function AuthErrorPage({ > {errorCopy.codeLabel}

- - {code} - + {code}
) : null} diff --git a/apps/tradinggoose/app/[locale]/(auth)/login/page.tsx b/apps/tradinggoose/app/[locale]/(auth)/login/page.tsx index 216d56b48..448bc8ee7 100644 --- a/apps/tradinggoose/app/[locale]/(auth)/login/page.tsx +++ b/apps/tradinggoose/app/[locale]/(auth)/login/page.tsx @@ -1,11 +1,29 @@ +import { getLocale } from 'next-intl/server' import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker' import LoginForm from '@/app/(auth)/login/login-form' +import { getSession } from '@/lib/auth' import { getRegistrationModeForRender } from '@/lib/registration/service' +import { redirect } from '@/i18n/navigation' // Force dynamic rendering to avoid prerender errors with search params export const dynamic = 'force-dynamic' -export default async function LoginPage() { +export default async function LoginPage({ + searchParams, +}: { + searchParams?: Promise<{ reauth?: string }> +} = {}) { + const query = await searchParams + const isReauth = query?.reauth === '1' + const [locale, session] = await Promise.all([ + getLocale(), + isReauth ? Promise.resolve(null) : getSession(), + ]) + + if (session?.user?.id) { + redirect({ href: '/workspace', locale }) + } + const [{ githubAvailable, googleAvailable, isProduction }, registrationMode] = await Promise.all([ getOAuthProviderStatus(), getRegistrationModeForRender(), diff --git a/apps/tradinggoose/app/[locale]/(auth)/signup/page.tsx b/apps/tradinggoose/app/[locale]/(auth)/signup/page.tsx index 6e11f3db2..b0f8bdf78 100644 --- a/apps/tradinggoose/app/[locale]/(auth)/signup/page.tsx +++ b/apps/tradinggoose/app/[locale]/(auth)/signup/page.tsx @@ -1,11 +1,11 @@ import { getLocale } from 'next-intl/server' -import { Link } from '@/i18n/navigation' +import { Link, redirect } from '@/i18n/navigation' import { getPublicCopy } from '@/i18n/public-copy' -import { type LocaleCode } from '@/i18n/utils' import { AuthPageHeader } from '@/app/(auth)/components/auth-page-header' import { getOAuthProviderStatus } from '@/app/(auth)/components/oauth-provider-checker' import SignupForm from '@/app/(auth)/signup/signup-form' import { Button } from '@/components/ui/button' +import { getSession } from '@/lib/auth' import { getRegistrationModeForRender } from '@/lib/registration/service' export const dynamic = 'force-dynamic' @@ -15,12 +15,18 @@ export default async function SignupPage({ }: { searchParams?: Promise<{ invite_flow?: string }> }) { - const [providers, locale] = await Promise.all([ - Promise.all([getOAuthProviderStatus(), getRegistrationModeForRender()]), + const [locale, session] = await Promise.all([ getLocale(), + getSession(), ]) + + if (session?.user?.id) { + redirect({ href: '/workspace', locale }) + } + + const providers = await Promise.all([getOAuthProviderStatus(), getRegistrationModeForRender()]) const [{ githubAvailable, googleAvailable, isProduction }, registrationMode] = providers - const copy = getPublicCopy(locale as LocaleCode) + const copy = getPublicCopy(locale) const commonCopy = copy.auth.common const disabledCopy = copy.auth.disabled const resolvedSearchParams = (await searchParams) ?? {} diff --git a/apps/tradinggoose/app/[locale]/(landing)/blog/[slug]/page.tsx b/apps/tradinggoose/app/[locale]/(landing)/blog/[slug]/page.tsx index 895f0eeda..0d12eae0f 100644 --- a/apps/tradinggoose/app/[locale]/(landing)/blog/[slug]/page.tsx +++ b/apps/tradinggoose/app/[locale]/(landing)/blog/[slug]/page.tsx @@ -12,9 +12,9 @@ import { buildLocalizedAlternates, getOpenGraphLocale, localizeSiteUrl, - SITE_BASE_URL, type LocaleCode, } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' import { getPostBySlug } from '@/app/(landing)/blog/lib/posts' import { formatBlogDate } from '@/app/(landing)/blog/lib/heading-slugs' import BreadcrumbNav from '@/app/(landing)/blog/components/breadcrumb-nav' @@ -74,6 +74,7 @@ export default async function PostPage({ params }: PostPageProps) { if (!post) notFound() const locale = (await getLocale()) as LocaleCode const copy = getPublicCopy(locale) + const siteBaseUrl = getBaseUrl() const { title, date, image, authors, tags, toc, content, readingTime } = post const postPath = `/blog/${slug}` @@ -102,7 +103,7 @@ export default async function PostPage({ params }: PostPageProps) { ], })), }), - publisher: { '@id': `${SITE_BASE_URL}/#organization` }, + publisher: { '@id': `${siteBaseUrl}/#organization` }, mainEntityOfPage: { '@type': 'WebPage', '@id': localizeSiteUrl(locale, postPath) }, ...(tags?.length && { keywords: tags.join(', '), articleSection: tags[0] }), inLanguage: locale, diff --git a/apps/tradinggoose/app/[locale]/(landing)/layout.tsx b/apps/tradinggoose/app/[locale]/(landing)/layout.tsx index 6af3cd044..f58e9bf0b 100644 --- a/apps/tradinggoose/app/[locale]/(landing)/layout.tsx +++ b/apps/tradinggoose/app/[locale]/(landing)/layout.tsx @@ -1,18 +1,20 @@ import type { Metadata } from 'next' import Background from '@/app/(landing)/components/background/background' -import { SITE_BASE_URL } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' -export const metadata: Metadata = { - metadataBase: new URL(SITE_BASE_URL), - manifest: '/manifest.webmanifest', - icons: { - icon: '/favicon.ico', - apple: '/apple-icon.png', - }, - other: { - 'msapplication-TileColor': '#000000', - 'theme-color': '#000000', - }, +export function generateMetadata(): Metadata { + return { + metadataBase: new URL(getBaseUrl()), + manifest: '/manifest.webmanifest', + icons: { + icon: '/favicon.ico', + apple: '/apple-icon.png', + }, + other: { + 'msapplication-TileColor': '#000000', + 'theme-color': '#000000', + }, + } } export default function LandingLayout({ children }: { children: React.ReactNode }) { diff --git a/apps/tradinggoose/app/[locale]/[...notFound]/page.tsx b/apps/tradinggoose/app/[locale]/[...notFound]/page.tsx new file mode 100644 index 000000000..68abc79dd --- /dev/null +++ b/apps/tradinggoose/app/[locale]/[...notFound]/page.tsx @@ -0,0 +1,5 @@ +import { notFound } from 'next/navigation' + +export default function LocalizedNotFoundRoute() { + notFound() +} diff --git a/apps/tradinggoose/app/[locale]/admin/layout.test.tsx b/apps/tradinggoose/app/[locale]/admin/layout.test.tsx index bfab86a80..6e86f0004 100644 --- a/apps/tradinggoose/app/[locale]/admin/layout.test.tsx +++ b/apps/tradinggoose/app/[locale]/admin/layout.test.tsx @@ -1,10 +1,12 @@ import type React from 'react' import { renderToStaticMarkup } from 'react-dom/server' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { CANONICAL_CALLBACK_PATH_HEADER } from '@/i18n/utils' let capturedGlobalNavbarProps: | { isSystemAdmin?: boolean + workspaceUser?: { id: string; email: string | null } | null navigationMode?: 'workspace' | 'admin' } | undefined @@ -12,12 +14,39 @@ let capturedGlobalNavbarProps: const mockNotFound = vi.fn(() => { throw new Error('notFound') }) +const mockRedirect = vi.fn((url: string) => { + throw new Error(`redirect:${url}`) +}) const mockGetSystemAdminAccess = vi.fn() +const mockHeaders = vi.fn() vi.mock('next/navigation', () => ({ notFound: () => mockNotFound(), })) +vi.mock('next/headers', () => ({ + headers: () => mockHeaders(), +})) + +vi.mock('@/i18n/navigation', () => ({ + redirect: ({ + href, + locale, + }: { + href: string | { pathname: string; query?: Record } + locale?: string + }) => { + const canonicalPath = + typeof href === 'string' + ? href + : `${href.pathname}${href.query ? `?${new URLSearchParams(href.query).toString()}` : ''}` + const localizedPath = + locale && canonicalPath.startsWith('/') ? `/${locale}${canonicalPath}` : canonicalPath + + return mockRedirect(localizedPath) + }, +})) + vi.mock('@/lib/admin/access', () => ({ getSystemAdminAccess: (...args: unknown[]) => mockGetSystemAdminAccess(...args), })) @@ -26,13 +55,15 @@ vi.mock('@/global-navbar', () => ({ GlobalNavbar: ({ children, isSystemAdmin, + workspaceUser, navigationMode, }: { children: React.ReactNode isSystemAdmin?: boolean + workspaceUser?: { id: string; email: string | null } | null navigationMode?: 'workspace' | 'admin' }) => { - capturedGlobalNavbarProps = { isSystemAdmin, navigationMode } + capturedGlobalNavbarProps = { isSystemAdmin, workspaceUser, navigationMode } return
{children}
}, })) @@ -42,11 +73,19 @@ describe('Admin layout', () => { vi.clearAllMocks() vi.resetModules() capturedGlobalNavbarProps = undefined + mockHeaders.mockResolvedValue(new Headers()) + + mockRedirect.mockImplementation((url: string) => { + throw new Error(`redirect:${url}`) + }) }) it('renders admin content inside the admin navbar', async () => { mockGetSystemAdminAccess.mockResolvedValue({ + isAuthenticated: true, isSystemAdmin: false, + userId: 'admin-user-1', + user: { email: 'admin@example.com' }, canBootstrapSystemAdmin: true, }) @@ -59,12 +98,71 @@ describe('Admin layout', () => { expect(renderToStaticMarkup(result)).toContain('admin content') expect(capturedGlobalNavbarProps).toEqual({ isSystemAdmin: false, + workspaceUser: { + id: 'admin-user-1', + email: 'admin@example.com', + }, navigationMode: 'admin', }) + expect(mockGetSystemAdminAccess).toHaveBeenCalledWith(expect.any(Headers)) + }) + + it('redirects signed-out admin entry to login with the current callback target', async () => { + mockHeaders.mockResolvedValue( + new Headers([[CANONICAL_CALLBACK_PATH_HEADER, '/admin/billing?from=nav']]) + ) + mockGetSystemAdminAccess.mockResolvedValue({ + isAuthenticated: false, + isSystemAdmin: false, + canBootstrapSystemAdmin: false, + }) + + const AdminLayout = (await import('./layout')).default + + await expect( + AdminLayout({ + children:
admin content
, + params: Promise.resolve({ locale: 'es' }), + }) + ).rejects.toThrow('redirect:/es/login?callbackUrl=%2Fadmin%2Fbilling%3Ffrom%3Dnav') + + expect(mockRedirect).toHaveBeenCalledWith( + '/es/login?callbackUrl=%2Fadmin%2Fbilling%3Ffrom%3Dnav' + ) + expect(mockNotFound).not.toHaveBeenCalled() + }) + + it('routes invalid admin session cookies through reauth cleanup', async () => { + mockHeaders.mockResolvedValue( + new Headers([ + [CANONICAL_CALLBACK_PATH_HEADER, '/admin/billing?from=nav'], + ['cookie', 'better-auth.session_token=stale'], + ]) + ) + mockGetSystemAdminAccess.mockResolvedValue({ + isAuthenticated: false, + isSystemAdmin: false, + canBootstrapSystemAdmin: false, + }) + + const AdminLayout = (await import('./layout')).default + + await expect( + AdminLayout({ + children:
admin content
, + params: Promise.resolve({ locale: 'es' }), + }) + ).rejects.toThrow('redirect:/es/login?reauth=1&callbackUrl=%2Fadmin%2Fbilling%3Ffrom%3Dnav') + + expect(mockRedirect).toHaveBeenCalledWith( + '/es/login?reauth=1&callbackUrl=%2Fadmin%2Fbilling%3Ffrom%3Dnav' + ) + expect(mockNotFound).not.toHaveBeenCalled() }) it('calls notFound when the user cannot access admin routes', async () => { mockGetSystemAdminAccess.mockResolvedValue({ + isAuthenticated: true, isSystemAdmin: false, canBootstrapSystemAdmin: false, }) diff --git a/apps/tradinggoose/app/[locale]/admin/layout.tsx b/apps/tradinggoose/app/[locale]/admin/layout.tsx index d546c5f25..c9a8781e5 100644 --- a/apps/tradinggoose/app/[locale]/admin/layout.tsx +++ b/apps/tradinggoose/app/[locale]/admin/layout.tsx @@ -1,10 +1,13 @@ import type React from 'react' +import { getSessionCookie } from 'better-auth/cookies' +import { headers } from 'next/headers' import { notFound } from 'next/navigation' import { NextIntlClientProvider } from 'next-intl' import { getSystemAdminAccess } from '@/lib/admin/access' import { GlobalNavbar } from '@/global-navbar' +import { redirect } from '@/i18n/navigation' import { getClientMessages } from '@/i18n/public-copy' -import type { LocaleCode } from '@/i18n/utils' +import { type LocaleCode, requireCanonicalCallbackPath } from '@/i18n/utils' export default async function AdminLayout({ children, @@ -13,8 +16,22 @@ export default async function AdminLayout({ children: React.ReactNode params: Promise<{ locale: string }> }) { - const [{ locale: routeLocale }, access] = await Promise.all([params, getSystemAdminAccess()]) + const [{ locale: routeLocale }, requestHeaders] = await Promise.all([params, headers()]) const locale = routeLocale as LocaleCode + const access = await getSystemAdminAccess(requestHeaders) + + if (!access.isAuthenticated) { + return redirect({ + href: { + pathname: '/login', + query: { + ...(getSessionCookie(requestHeaders) ? { reauth: '1' } : {}), + callbackUrl: requireCanonicalCallbackPath(requestHeaders, 'admin'), + }, + }, + locale, + }) + } if (!access.isSystemAdmin && !access.canBootstrapSystemAdmin) { notFound() @@ -22,7 +39,18 @@ export default async function AdminLayout({ return ( - + {children} diff --git a/apps/tradinggoose/app/[locale]/changelog/page.tsx b/apps/tradinggoose/app/[locale]/changelog/page.tsx index a30f3f79b..5c8de468f 100644 --- a/apps/tradinggoose/app/[locale]/changelog/page.tsx +++ b/apps/tradinggoose/app/[locale]/changelog/page.tsx @@ -7,8 +7,8 @@ import { getOpenGraphLocale, type LocaleCode, localizeSiteUrl, - SITE_BASE_URL, } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' export async function generateMetadata(): Promise { const locale = (await getLocale()) as LocaleCode @@ -36,6 +36,7 @@ export async function generateMetadata(): Promise { export default async function ChangelogPage() { const locale = (await getLocale()) as LocaleCode const copy = getPublicCopy(locale) + const siteBaseUrl = getBaseUrl() const changelogStructuredData = { '@context': 'https://schema.org', '@graph': [ @@ -46,10 +47,10 @@ export default async function ChangelogPage() { url: localizeSiteUrl(locale, '/changelog'), mainEntityOfPage: localizeSiteUrl(locale, '/changelog'), inLanguage: locale, - author: { '@id': `${SITE_BASE_URL}/#organization` }, - publisher: { '@id': `${SITE_BASE_URL}/#organization` }, - about: { '@id': `${SITE_BASE_URL}/#software` }, - isPartOf: { '@id': `${SITE_BASE_URL}/#website` }, + author: { '@id': `${siteBaseUrl}/#organization` }, + publisher: { '@id': `${siteBaseUrl}/#organization` }, + about: { '@id': `${siteBaseUrl}/#software` }, + isPartOf: { '@id': `${siteBaseUrl}/#website` }, }, { '@type': 'BreadcrumbList', diff --git a/apps/tradinggoose/app/[locale]/layout.tsx b/apps/tradinggoose/app/[locale]/layout.tsx index 76bd0c153..b35dfe463 100644 --- a/apps/tradinggoose/app/[locale]/layout.tsx +++ b/apps/tradinggoose/app/[locale]/layout.tsx @@ -10,9 +10,9 @@ import { type AppLocale, routing } from '@/i18n/routing' import 'monaco-editor/min/vs/editor/editor.main.css' import '@/app/globals.css' +import { AppBootstrap } from '@/app/app-bootstrap' import { TooltipProvider } from '@/components/ui/tooltip' import { SessionProvider } from '@/lib/session/session-context' -import { ProviderModelsBootstrap } from '@/app/provider-models-bootstrap' import { QueryProvider } from '@/app/query-provider' import { ThemeProvider } from '@/app/theme-provider' import { ZoomPrevention } from '@/app/zoom-prevention' @@ -26,6 +26,8 @@ export const viewport: Viewport = { ], } +export const dynamic = 'force-dynamic' + export async function generateMetadata({ params, }: { @@ -37,10 +39,6 @@ export async function generateMetadata({ ) } -export function generateStaticParams() { - return routing.locales.map((locale) => ({ locale })) -} - export default async function RootLayout({ children, params, @@ -74,7 +72,7 @@ export default async function RootLayout({ locale={locale} messages={getClientMessages(locale)} > - + {children} diff --git a/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.test.tsx b/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.test.tsx index 4ce4d0070..7fcfec243 100644 --- a/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.test.tsx +++ b/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.test.tsx @@ -23,9 +23,7 @@ vi.mock('@/i18n/navigation', () => ({ ? href : `${href.pathname}${href.query ? `?${new URLSearchParams(href.query).toString()}` : ''}` const localizedPath = - locale && locale !== 'en' && canonicalPath.startsWith('/') - ? `/${locale}${canonicalPath}` - : canonicalPath + locale && canonicalPath.startsWith('/') ? `/${locale}${canonicalPath}` : canonicalPath return mockRedirect(localizedPath) }, @@ -44,8 +42,18 @@ vi.mock('@/lib/permissions/utils', () => ({ })) vi.mock('@/app/workspace/[workspaceId]/providers/providers', () => ({ - default: ({ children, workspaceId }: { children: React.ReactNode; workspaceId?: string }) => ( -
{children}
+ default: ({ + children, + workspaceId, + userId, + }: { + children: React.ReactNode + workspaceId: string + userId: string + }) => ( +
+ {children} +
), })) @@ -68,6 +76,33 @@ describe('Workspace layout access guard', () => { const WorkspaceLayout = (await import('./layout')).default + await expect( + WorkspaceLayout({ + children:
workspace
, + params: Promise.resolve({ locale: 'es', workspaceId: 'ws-1' }), + }) + ).rejects.toThrow( + 'redirect:/es/login?callbackUrl=%2Fworkspace%2Fws-1%2Ffiles%3FlayoutId%3Dlayout-1' + ) + + expect(mockRedirect).toHaveBeenCalledWith( + '/es/login?callbackUrl=%2Fworkspace%2Fws-1%2Ffiles%3FlayoutId%3Dlayout-1' + ) + expect(mockGetSession).toHaveBeenCalledWith(expect.any(Headers)) + expect(mockCheckWorkspaceAccess).not.toHaveBeenCalled() + }) + + it('routes invalid session cookies through reauth cleanup', async () => { + mockHeaders.mockResolvedValue( + new Headers([ + [CANONICAL_CALLBACK_PATH_HEADER, '/workspace/ws-1/files?layoutId=layout-1'], + ['cookie', 'better-auth.session_token=stale'], + ]) + ) + mockGetSession.mockResolvedValue(null) + + const WorkspaceLayout = (await import('./layout')).default + await expect( WorkspaceLayout({ children:
workspace
, @@ -80,11 +115,10 @@ describe('Workspace layout access guard', () => { expect(mockRedirect).toHaveBeenCalledWith( '/es/login?reauth=1&callbackUrl=%2Fworkspace%2Fws-1%2Ffiles%3FlayoutId%3Dlayout-1' ) - expect(mockGetSession).toHaveBeenCalledWith(expect.any(Headers), { disableCookieCache: true }) expect(mockCheckWorkspaceAccess).not.toHaveBeenCalled() }) - it('redirects to /workspace when the user cannot access the workspace', async () => { + it('redirects to the localized workspace root when the user cannot access the workspace', async () => { mockGetSession.mockResolvedValue({ user: { id: 'user-1', @@ -104,10 +138,10 @@ describe('Workspace layout access guard', () => { children:
workspace
, params: Promise.resolve({ locale: 'en', workspaceId: 'ws-1' }), }) - ).rejects.toThrow('redirect:/workspace') + ).rejects.toThrow('redirect:/en/workspace') expect(mockCheckWorkspaceAccess).toHaveBeenCalledWith('ws-1', 'user-1') - expect(mockRedirect).toHaveBeenCalledWith('/workspace') + expect(mockRedirect).toHaveBeenCalledWith('/en/workspace') }) it('renders the workspace route when access is valid', async () => { @@ -131,8 +165,11 @@ describe('Workspace layout access guard', () => { params: Promise.resolve({ locale: 'en', workspaceId: 'ws-1' }), }) - expect(renderToStaticMarkup(result)).toContain('data-workspace-id="ws-1"') - expect(renderToStaticMarkup(result)).toContain('workspace') + const markup = renderToStaticMarkup(result) + + expect(markup).toContain('workspace') + expect(markup).toContain('data-workspace-id="ws-1"') + expect(markup).toContain('data-user-id="user-1"') expect(mockRedirect).not.toHaveBeenCalled() }) }) diff --git a/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.tsx b/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.tsx index e1e5e1475..f7832c191 100644 --- a/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.tsx +++ b/apps/tradinggoose/app/[locale]/workspace/[workspaceId]/layout.tsx @@ -1,9 +1,10 @@ +import { getSessionCookie } from 'better-auth/cookies' import { headers } from 'next/headers' import { getSession } from '@/lib/auth' import { checkWorkspaceAccess } from '@/lib/permissions/utils' import Providers from '@/app/workspace/[workspaceId]/providers/providers' import { redirect } from '@/i18n/navigation' -import { CANONICAL_CALLBACK_PATH_HEADER, type LocaleCode } from '@/i18n/utils' +import { type LocaleCode, requireCanonicalCallbackPath } from '@/i18n/utils' export default async function WorkspaceLayout({ children, @@ -15,21 +16,16 @@ export default async function WorkspaceLayout({ const { locale: routeLocale, workspaceId } = await params const locale = routeLocale as LocaleCode const requestHeaders = await headers() - const session = await getSession(requestHeaders, { disableCookieCache: true }) + const session = await getSession(requestHeaders) const userId = session?.user?.id if (!userId) { - const callbackUrl = requestHeaders.get(CANONICAL_CALLBACK_PATH_HEADER) - if (!callbackUrl) { - throw new Error('Missing canonical callback path for workspace reauth redirect') - } - return redirect({ href: { pathname: '/login', query: { - reauth: '1', - callbackUrl, + ...(getSessionCookie(requestHeaders) ? { reauth: '1' } : {}), + callbackUrl: requireCanonicalCallbackPath(requestHeaders, 'workspace'), }, }, locale, @@ -43,7 +39,7 @@ export default async function WorkspaceLayout({ } return ( - +
{children}
diff --git a/apps/tradinggoose/app/[locale]/workspace/layout.tsx b/apps/tradinggoose/app/[locale]/workspace/layout.tsx index 39a0817b7..589a72de0 100644 --- a/apps/tradinggoose/app/[locale]/workspace/layout.tsx +++ b/apps/tradinggoose/app/[locale]/workspace/layout.tsx @@ -19,7 +19,19 @@ export default async function WorkspaceRootLayout({ return ( - {children} + + {children} + ) diff --git a/apps/tradinggoose/app/[locale]/workspace/page.test.tsx b/apps/tradinggoose/app/[locale]/workspace/page.test.tsx new file mode 100644 index 000000000..d302326ec --- /dev/null +++ b/apps/tradinggoose/app/[locale]/workspace/page.test.tsx @@ -0,0 +1,165 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { CANONICAL_CALLBACK_PATH_HEADER } from '@/i18n/utils' + +const mockRedirect = vi.fn((url: string) => { + throw new Error(`redirect:${url}`) +}) +const mockGetSession = vi.fn() +const mockHeaders = vi.fn() +const mockGetUserWorkspaces = vi.fn() +const mockReadWorkflowAccessContext = vi.fn() + +function mockLocalizedRedirect({ + href, + locale, +}: { + href: string | { pathname: string; query?: Record } + locale?: string +}) { + const canonicalPath = + typeof href === 'string' + ? href + : `${href.pathname}${href.query ? `?${new URLSearchParams(href.query).toString()}` : ''}` + const localizedPath = + locale && canonicalPath.startsWith('/') ? `/${locale}${canonicalPath}` : canonicalPath + return mockRedirect(localizedPath) +} + +vi.mock('@/i18n/navigation', () => ({ + redirect: mockLocalizedRedirect, +})) + +vi.mock('next/headers', () => ({ + headers: () => mockHeaders(), +})) + +vi.mock('@/lib/auth', () => ({ + getSession: (...args: unknown[]) => mockGetSession(...args), +})) + +vi.mock('@/lib/workspaces/service', () => ({ + getUserWorkspaces: (...args: unknown[]) => mockGetUserWorkspaces(...args), +})) + +vi.mock('@/lib/workflows/utils', () => ({ + readWorkflowAccessContext: (...args: unknown[]) => mockReadWorkflowAccessContext(...args), +})) + +async function renderWorkspacePage( + locale = 'en', + searchParams: { callbackUrl?: string; redirect_workflow?: string } = {} +) { + const WorkspacePage = (await import('./page')).default + return WorkspacePage({ + params: Promise.resolve({ locale }), + searchParams: Promise.resolve(searchParams), + }) +} + +describe('Workspace root page access guard', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + mockHeaders.mockResolvedValue(new Headers()) + + mockRedirect.mockImplementation((url: string) => { + throw new Error(`redirect:${url}`) + }) + mockGetSession.mockResolvedValue({ + user: { + id: 'user-1', + name: 'Ada Lovelace', + }, + }) + mockGetUserWorkspaces.mockResolvedValue([{ id: 'workspace-1' }]) + mockReadWorkflowAccessContext.mockResolvedValue(null) + }) + + it('redirects signed-out users to login with the current callback target', async () => { + mockHeaders.mockResolvedValue( + new Headers([[CANONICAL_CALLBACK_PATH_HEADER, '/workspace?redirect_workflow=workflow-1']]) + ) + mockGetSession.mockResolvedValue(null) + + await expect(renderWorkspacePage('zh')).rejects.toThrow( + 'redirect:/zh/login?callbackUrl=%2Fworkspace%3Fredirect_workflow%3Dworkflow-1' + ) + + expect(mockRedirect).toHaveBeenCalledWith( + '/zh/login?callbackUrl=%2Fworkspace%3Fredirect_workflow%3Dworkflow-1' + ) + expect(mockGetSession).toHaveBeenCalledWith(expect.any(Headers)) + }) + + it('routes invalid session cookies through reauth cleanup', async () => { + mockHeaders.mockResolvedValue( + new Headers([ + [CANONICAL_CALLBACK_PATH_HEADER, '/workspace?redirect_workflow=workflow-1'], + ['cookie', 'better-auth.session_token=stale'], + ]) + ) + mockGetSession.mockResolvedValue(null) + + await expect(renderWorkspacePage('zh')).rejects.toThrow( + 'redirect:/zh/login?reauth=1&callbackUrl=%2Fworkspace%3Fredirect_workflow%3Dworkflow-1' + ) + + expect(mockRedirect).toHaveBeenCalledWith( + '/zh/login?reauth=1&callbackUrl=%2Fworkspace%3Fredirect_workflow%3Dworkflow-1' + ) + expect(mockGetSession).toHaveBeenCalledWith(expect.any(Headers)) + }) + + it('redirects authenticated users to the requested workflow workspace', async () => { + mockReadWorkflowAccessContext.mockResolvedValue({ + workflow: { + workspaceId: 'workspace-from-workflow', + }, + isOwner: false, + isWorkspaceOwner: false, + workspacePermission: 'read', + }) + await expect(renderWorkspacePage('en', { redirect_workflow: 'workflow-1' })).rejects.toThrow( + 'redirect:/en/workspace/workspace-from-workflow/dashboard' + ) + + expect(mockReadWorkflowAccessContext).toHaveBeenCalledWith('workflow-1', 'user-1') + expect(mockGetUserWorkspaces).not.toHaveBeenCalled() + }) + + it('redirects authenticated users to same-origin absolute callback URLs', async () => { + mockHeaders.mockResolvedValue(new Headers([['host', 'preview.local:3000']])) + + await expect( + renderWorkspacePage('en', { + callbackUrl: 'http://preview.local:3000/workspace/workspace-2/dashboard?layoutId=layout-1', + }) + ).rejects.toThrow('redirect:/en/workspace/workspace-2/dashboard?layoutId=layout-1') + + expect(mockGetUserWorkspaces).not.toHaveBeenCalled() + }) + + it('redirects authenticated users to their first workspace dashboard', async () => { + await expect(renderWorkspacePage('es')).rejects.toThrow( + 'redirect:/es/workspace/workspace-1/dashboard' + ) + + expect(mockGetUserWorkspaces).toHaveBeenCalledWith({ + userId: 'user-1', + userName: 'Ada Lovelace', + }) + }) + + it('bootstraps a workspace on the server when the user has none and redirects to it', async () => { + mockGetUserWorkspaces.mockResolvedValue([{ id: 'workspace-bootstrapped' }]) + + await expect(renderWorkspacePage('en')).rejects.toThrow( + 'redirect:/en/workspace/workspace-bootstrapped/dashboard' + ) + + expect(mockGetUserWorkspaces).toHaveBeenCalledWith({ + userId: 'user-1', + userName: 'Ada Lovelace', + }) + }) +}) diff --git a/apps/tradinggoose/app/[locale]/workspace/page.tsx b/apps/tradinggoose/app/[locale]/workspace/page.tsx index ef3ada300..d4782dbbc 100644 --- a/apps/tradinggoose/app/[locale]/workspace/page.tsx +++ b/apps/tradinggoose/app/[locale]/workspace/page.tsx @@ -1,151 +1,100 @@ -'use client' - -import { useEffect } from 'react' -import { useTranslations } from 'next-intl' -import { LoadingAgent } from '@/components/ui/loading-agent' -import { useSession } from '@/lib/auth-client' -import { createLogger } from '@/lib/logs/console/logger' -import { usePathname, useRouter } from '@/i18n/navigation' -import { normalizeCallbackUrl } from '@/i18n/utils' - -const logger = createLogger('WorkspacePage') - -export default function WorkspacePage() { - const router = useRouter() - const pathname = usePathname() - const tWorkspace = useTranslations('workspace') - const { data: session, isPending, error: sessionError } = useSession() - - useEffect(() => { - const redirectToFirstWorkspace = async () => { - if (isPending) { - return - } - - if (sessionError || !session?.user) { - logger.info('User not authenticated, redirecting to home', { - hasSessionError: Boolean(sessionError), - }) - router.replace('/') - return - } - - try { - const urlParams = new URLSearchParams(window.location.search) - const callbackUrl = normalizeCallbackUrl(urlParams.get('callbackUrl'), window.location.origin) - const redirectWorkflowId = urlParams.get('redirect_workflow') - - if (callbackUrl) { - const callbackPath = new URL(callbackUrl, window.location.origin).pathname - - if (callbackPath !== pathname) { - logger.info('Redirecting to callback URL from workspace root', { callbackUrl }) - router.replace(callbackUrl) - return - } - } - - if (redirectWorkflowId) { - try { - const workflowResponse = await fetch(`/api/workflows/${redirectWorkflowId}`) - if (workflowResponse.ok) { - const workflowData = await workflowResponse.json() - const workspaceId = workflowData.data?.workspaceId - - if (workspaceId) { - logger.info( - `Redirecting workflow ${redirectWorkflowId} to workspace ${workspaceId} dashboard` - ) - router.replace(`/workspace/${workspaceId}/dashboard`) - return - } - } - } catch (error) { - logger.error('Error fetching workflow for redirect:', error) - } - } - - const response = await fetch('/api/workspaces', { - credentials: 'include', - }) - - if (response.status === 401 || response.status === 403) { - logger.info('Unauthorized to fetch workspaces, redirecting to home', { - status: response.status, - }) - router.replace('/') - return - } - - if (!response.ok) { - let errorBody = '' - try { - errorBody = await response.text() - } catch {} - - logger.error('Failed to fetch workspaces for redirect', { - status: response.status, - body: errorBody, - }) - router.replace('/') - return - } - - const data = await response.json() - const workspaces = data.workspaces || [] - - if (workspaces.length === 0) { - logger.warn('No workspaces found for user, creating default workspace') - - try { - const createResponse = await fetch('/api/workspaces', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ name: tWorkspace('defaults.newWorkspaceName') }), - }) +import { getSessionCookie } from 'better-auth/cookies' +import { headers } from 'next/headers' +import { getSession } from '@/lib/auth' +import { readWorkflowAccessContext } from '@/lib/workflows/utils' +import { getUserWorkspaces } from '@/lib/workspaces/service' +import { redirect } from '@/i18n/navigation' +import { type LocaleCode, normalizeCallbackUrl, requireCanonicalCallbackPath } from '@/i18n/utils' + +type WorkspaceSearchParams = Promise<{ + callbackUrl?: string | string[] + redirect_workflow?: string | string[] +}> + +function getSearchParam( + searchParams: Awaited, + key: keyof Awaited +) { + const value = searchParams[key] + return Array.isArray(value) ? value[0] : value +} - if (createResponse.ok) { - const createData = await createResponse.json() - const newWorkspace = createData.workspace +function normalizeRequestCallbackUrl(href: string | null | undefined, headers: Headers) { + const internalCallback = normalizeCallbackUrl(href) + if (internalCallback) { + return internalCallback + } - if (newWorkspace?.id) { - logger.info( - `Created default workspace ${newWorkspace.id}, redirecting to dashboard` - ) - router.replace(`/workspace/${newWorkspace.id}/dashboard`) - return - } - } + const host = (headers.get('x-forwarded-host') ?? headers.get('host'))?.split(',', 1)[0]?.trim() + const forwardedProtocol = headers.get('x-forwarded-proto')?.split(',', 1)[0]?.trim() + const protocols = forwardedProtocol ? [forwardedProtocol] : ['http', 'https'] - logger.error('Failed to create default workspace') - } catch (createError) { - logger.error('Error creating default workspace:', createError) - } + for (const protocol of host ? protocols : []) { + const callback = normalizeCallbackUrl(href, `${protocol}://${host}`) + if (callback) { + return callback + } + } - router.replace('/') - return - } + return null +} - const firstWorkspace = workspaces[0] - logger.info(`Redirecting to workspace ${firstWorkspace.id} dashboard`) - router.replace(`/workspace/${firstWorkspace.id}/dashboard`) - } catch (error) { - logger.error('Error fetching workspaces for redirect:', error) - router.replace('/') - } +export default async function WorkspacePage({ + params, + searchParams, +}: { + params: Promise<{ locale: string }> + searchParams: WorkspaceSearchParams +}) { + const [{ locale: routeLocale }, query, requestHeaders] = await Promise.all([ + params, + searchParams, + headers(), + ]) + const locale = routeLocale as LocaleCode + const session = await getSession(requestHeaders) + const userId = session?.user?.id + + if (!userId) { + return redirect({ + href: { + pathname: '/login', + query: { + ...(getSessionCookie(requestHeaders) ? { reauth: '1' } : {}), + callbackUrl: requireCanonicalCallbackPath(requestHeaders, 'workspace'), + }, + }, + locale, + }) + } + + const callbackUrl = normalizeRequestCallbackUrl( + getSearchParam(query, 'callbackUrl'), + requestHeaders + ) + if (callbackUrl && callbackUrl.split(/[?#]/, 1)[0] !== '/workspace') { + return redirect({ href: callbackUrl, locale }) + } + + const redirectWorkflowId = getSearchParam(query, 'redirect_workflow') + if (redirectWorkflowId) { + const access = await readWorkflowAccessContext(redirectWorkflowId, userId) + if ( + access?.workflow.workspaceId && + (access.isOwner || access.isWorkspaceOwner || access.workspacePermission) + ) { + return redirect({ href: `/workspace/${access.workflow.workspaceId}/dashboard`, locale }) } + } - void redirectToFirstWorkspace() - }, [isPending, pathname, router, session, sessionError, tWorkspace]) + const [workspace] = await getUserWorkspaces({ + userId, + userName: session.user.name, + }) - return ( -
-
- - {tWorkspace('entry.loading')} -
-
- ) + if (!workspace) { + throw new Error('Expected workspace bootstrap to return a workspace') + } + + return redirect({ href: `/workspace/${workspace.id}/dashboard`, locale }) } diff --git a/apps/tradinggoose/app/api/auth/socket-token/route.test.ts b/apps/tradinggoose/app/api/auth/socket-token/route.test.ts new file mode 100644 index 000000000..845ce1443 --- /dev/null +++ b/apps/tradinggoose/app/api/auth/socket-token/route.test.ts @@ -0,0 +1,74 @@ +/** + * @vitest-environment node + */ + +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockGenerateOneTimeToken, mockGetSession, mockHeaders } = vi.hoisted(() => ({ + mockGenerateOneTimeToken: vi.fn(), + mockGetSession: vi.fn(), + mockHeaders: vi.fn(), +})) + +vi.mock('next/headers', () => ({ + headers: () => mockHeaders(), +})) + +vi.mock('@/lib/auth', () => ({ + auth: { + api: { + generateOneTimeToken: (...args: unknown[]) => mockGenerateOneTimeToken(...args), + }, + }, + getSession: (...args: unknown[]) => mockGetSession(...args), +})) + +describe('socket token route', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + mockHeaders.mockResolvedValue(new Headers([['cookie', 'better-auth.session_token=token']])) + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGenerateOneTimeToken.mockResolvedValue({ token: 'socket-token' }) + }) + + it('uses the canonical app session before issuing a socket token', async () => { + const { POST } = await import('./route') + + const response = await POST() + + expect(response.status).toBe(200) + expect(await response.json()).toEqual({ token: 'socket-token' }) + expect(mockGetSession).toHaveBeenCalledWith(expect.any(Headers)) + expect(mockGenerateOneTimeToken).toHaveBeenCalledWith({ headers: expect.any(Headers) }) + }) + + it('rejects socket token requests without an app session', async () => { + mockGetSession.mockResolvedValue(null) + + const { POST } = await import('./route') + const response = await POST() + + expect(response.status).toBe(401) + expect(await response.json()).toEqual({ error: 'Authentication required' }) + expect(mockGenerateOneTimeToken).not.toHaveBeenCalled() + }) + + it('does not issue a token when canonical session lookup rejects stale cookie data', async () => { + const staleCookieHeaders = new Headers([ + [ + 'cookie', + 'better-auth.session_token=revoked; better-auth.session_data=stale-session-payload', + ], + ]) + mockHeaders.mockResolvedValue(staleCookieHeaders) + mockGetSession.mockResolvedValue(null) + + const { POST } = await import('./route') + const response = await POST() + + expect(response.status).toBe(401) + expect(mockGetSession).toHaveBeenCalledWith(staleCookieHeaders) + expect(mockGenerateOneTimeToken).not.toHaveBeenCalled() + }) +}) diff --git a/apps/tradinggoose/app/api/auth/socket-token/route.ts b/apps/tradinggoose/app/api/auth/socket-token/route.ts index 5d8b4f146..4af204928 100644 --- a/apps/tradinggoose/app/api/auth/socket-token/route.ts +++ b/apps/tradinggoose/app/api/auth/socket-token/route.ts @@ -5,7 +5,7 @@ import { auth, getSession } from '@/lib/auth' export async function POST() { try { const hdrs = await headers() - const session = await getSession(hdrs, { disableCookieCache: true }) + const session = await getSession(hdrs) if (!session?.user?.id) { return NextResponse.json({ error: 'Authentication required' }, { status: 401 }) diff --git a/apps/tradinggoose/app/api/auth/sso/register/route.test.ts b/apps/tradinggoose/app/api/auth/sso/register/route.test.ts index 044bf3805..9484d2042 100644 --- a/apps/tradinggoose/app/api/auth/sso/register/route.test.ts +++ b/apps/tradinggoose/app/api/auth/sso/register/route.test.ts @@ -47,7 +47,9 @@ vi.mock('@/lib/env', () => ({ env: { SSO_ENABLED: true, }, - getEnv: vi.fn(() => undefined), + getEnv: vi.fn((key: string) => + key === 'NEXT_PUBLIC_APP_URL' ? 'http://localhost:3000' : undefined + ), isTruthy: (value: string | boolean | number | undefined) => typeof value === 'string' ? value.toLowerCase() === 'true' || value === '1' : Boolean(value), })) diff --git a/apps/tradinggoose/app/api/auth/trello/authorize/route.test.ts b/apps/tradinggoose/app/api/auth/trello/authorize/route.test.ts index 7c9063985..459361d4d 100644 --- a/apps/tradinggoose/app/api/auth/trello/authorize/route.test.ts +++ b/apps/tradinggoose/app/api/auth/trello/authorize/route.test.ts @@ -72,9 +72,7 @@ describe('Trello authorize route', () => { const returnURL = new URL(authorizeURL.searchParams.get('return_url')!) expect(returnURL.pathname).toBe('/api/auth/trello/callback') - expect(returnURL.searchParams.get('callbackURL')).toBe( - 'http://localhost:3000/workspace/ws-1/integrations' - ) + expect(returnURL.searchParams.get('callbackURL')).toBe('/workspace/ws-1/integrations') expect(returnURL.searchParams.get('state')).toBe('trello-state') const setCookie = response.headers.get('set-cookie') diff --git a/apps/tradinggoose/app/api/auth/trello/authorize/route.ts b/apps/tradinggoose/app/api/auth/trello/authorize/route.ts index 93b7b7112..a3472ccc0 100644 --- a/apps/tradinggoose/app/api/auth/trello/authorize/route.ts +++ b/apps/tradinggoose/app/api/auth/trello/authorize/route.ts @@ -8,50 +8,43 @@ import { TRELLO_OAUTH_STATE_COOKIE, } from '@/lib/trello/auth' import { getBaseUrl } from '@/lib/urls/utils' +import { normalizeCallbackUrl } from '@/i18n/utils' export const dynamic = 'force-dynamic' const logger = createLogger('TrelloAuthorizeAPI') -function getSafeCallbackURL(request: NextRequest) { +function getCallbackPath(request: NextRequest) { const appUrl = new URL(getBaseUrl()) - const rawCallbackURL = request.nextUrl.searchParams.get('callbackURL') || '/' - - try { - const callbackURL = new URL(rawCallbackURL, appUrl.origin) - if (callbackURL.origin !== appUrl.origin) { - return appUrl.origin - } - - return callbackURL.toString() - } catch { - return appUrl.origin - } + return normalizeCallbackUrl(request.nextUrl.searchParams.get('callbackURL'), appUrl.origin) } -function redirectWithError(callbackURL: string, error: string) { - const redirectURL = new URL(callbackURL) +function redirectWithError(callbackPath: string, error: string) { + const redirectURL = new URL(callbackPath, getBaseUrl()) redirectURL.searchParams.set('error', error) return NextResponse.redirect(redirectURL) } export async function GET(request: NextRequest) { - const callbackURL = getSafeCallbackURL(request) + const callbackPath = getCallbackPath(request) + if (!callbackPath) { + return NextResponse.json({ error: 'invalid_callback_url' }, { status: 400 }) + } try { const session = await getSession(request.headers) if (!session?.user?.id) { - return redirectWithError(callbackURL, 'user_not_authenticated') + return redirectWithError(callbackPath, 'user_not_authenticated') } const apiKey = await getTrelloApiKey() if (!apiKey) { - return redirectWithError(callbackURL, 'trello_not_configured') + return redirectWithError(callbackPath, 'trello_not_configured') } const state = createTrelloOAuthState() const returnURL = new URL('/api/auth/trello/callback', getBaseUrl()) - returnURL.searchParams.set('callbackURL', callbackURL) + returnURL.searchParams.set('callbackURL', callbackPath) returnURL.searchParams.set('state', state) const authorizeURL = new URL('https://trello.com/1/authorize') @@ -68,6 +61,6 @@ export async function GET(request: NextRequest) { return response } catch (error) { logger.error('Failed to start Trello authorization', { error }) - return redirectWithError(callbackURL, 'trello_authorization_failed') + return redirectWithError(callbackPath, 'trello_authorization_failed') } } diff --git a/apps/tradinggoose/app/api/auth/trello/callback/route.ts b/apps/tradinggoose/app/api/auth/trello/callback/route.ts index 28dd2f0f0..0bbd638bb 100644 --- a/apps/tradinggoose/app/api/auth/trello/callback/route.ts +++ b/apps/tradinggoose/app/api/auth/trello/callback/route.ts @@ -1,5 +1,6 @@ import { type NextRequest, NextResponse } from 'next/server' import { getBaseUrl } from '@/lib/urls/utils' +import { normalizeCallbackUrl } from '@/i18n/utils' export const dynamic = 'force-dynamic' @@ -13,20 +14,11 @@ const INLINE_SCRIPT_ESCAPES: Record = { '\u2029': '\\u2029', } -function getSafeCallbackURL(request: NextRequest) { - const fallback = new URL('/', getBaseUrl()) +function getCallbackURL(request: NextRequest) { + const appUrl = new URL(getBaseUrl()) const rawCallbackURL = request.nextUrl.searchParams.get('callbackURL') - - if (!rawCallbackURL) { - return fallback - } - - try { - const callbackURL = new URL(rawCallbackURL, fallback.origin) - return callbackURL.origin === fallback.origin ? callbackURL : fallback - } catch { - return fallback - } + const callbackPath = normalizeCallbackUrl(rawCallbackURL, appUrl.origin) + return callbackPath ? new URL(callbackPath, appUrl.origin) : null } function serializeForInlineScript(value: string) { @@ -116,7 +108,11 @@ function renderTrelloCallbackPage({ callbackURL, state }: { callbackURL: URL; st } export async function GET(request: NextRequest) { - const callbackURL = getSafeCallbackURL(request) + const callbackURL = getCallbackURL(request) + if (!callbackURL) { + return NextResponse.json({ error: 'invalid_callback_url' }, { status: 400 }) + } + const state = request.nextUrl.searchParams.get('state')?.trim() || '' return new NextResponse(renderTrelloCallbackPage({ callbackURL, state }), { diff --git a/apps/tradinggoose/app/api/chat/utils.ts b/apps/tradinggoose/app/api/chat/utils.ts index 9eda57bed..9769785f5 100644 --- a/apps/tradinggoose/app/api/chat/utils.ts +++ b/apps/tradinggoose/app/api/chat/utils.ts @@ -4,7 +4,7 @@ import { and, eq, gte, isNull, or } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { isDev } from '@/lib/environment' import { createLogger } from '@/lib/logs/console/logger' -import { hasAdminPermission } from '@/lib/permissions/utils' +import { hasWorkspaceAdminAccess } from '@/lib/permissions/utils' import { decryptSecret } from '@/lib/utils-server' import { CHAT_ERROR_CODES } from '@/app/chat/constants' @@ -31,7 +31,7 @@ export async function checkWorkflowAccessForChatCreation( } if (workflowRecord.workspaceId) { - const hasAdmin = await hasAdminPermission(userId, workflowRecord.workspaceId) + const hasAdmin = await hasWorkspaceAdminAccess(userId, workflowRecord.workspaceId) if (hasAdmin) { return { hasAccess: true, workflow: workflowRecord } } @@ -69,7 +69,7 @@ export async function checkChatAccess( } if (workflowWorkspaceId) { - const hasAdmin = await hasAdminPermission(userId, workflowWorkspaceId) + const hasAdmin = await hasWorkspaceAdminAccess(userId, workflowWorkspaceId) if (hasAdmin) { return { hasAccess: true, chat: chatRecord, workspaceId: workflowWorkspaceId } } diff --git a/apps/tradinggoose/app/api/copilot/chat/route.ts b/apps/tradinggoose/app/api/copilot/chat/route.ts index 266c1726d..ae9472678 100644 --- a/apps/tradinggoose/app/api/copilot/chat/route.ts +++ b/apps/tradinggoose/app/api/copilot/chat/route.ts @@ -17,6 +17,7 @@ import { createRequestTracker, createUnauthorizedResponse, } from '@/lib/copilot/auth' +import { mirrorLocalCopilotCompletionUsageReports } from '@/lib/copilot/completion-usage-billing' import { normalizeFunctionCallArguments } from '@/lib/copilot/function-call-args' import { mapSessionToApiResponse, @@ -264,6 +265,7 @@ async function persistChatMessages( function generateAndPersistTitle(params: { reviewSessionId: string message: string + userId: string model: string provider?: ProviderId requestId: string @@ -271,6 +273,7 @@ function generateAndPersistTitle(params: { }): void { requestCopilotTitle({ message: params.message, + userId: params.userId, model: params.model, provider: params.provider, }) @@ -942,6 +945,10 @@ export async function POST(req: NextRequest) { enqueueTurnState('in_progress', 'streaming') const forwardClientEvent = (event: Record) => { + if (event.type === 'billing.completion_usage') { + return + } + if (event.type === 'awaiting_tools') { latestTurnStatus = 'in_progress' enqueueTurnState('in_progress', 'waiting_for_tools') @@ -990,6 +997,12 @@ export async function POST(req: NextRequest) { } const event = JSON.parse(jsonStr) + if (event.type === 'billing.completion_usage') { + await mirrorLocalCopilotCompletionUsageReports({ + userId: authenticatedUserId, + reports: [event.report], + }) + } switch (event.type) { case 'tool_result': @@ -1023,6 +1036,7 @@ export async function POST(req: NextRequest) { generateAndPersistTitle({ reviewSessionId: actualReviewSessionId!, message, + userId: authenticatedUserId, model, provider: runtimeProvider, requestId: tracker.requestId, @@ -1127,6 +1141,12 @@ export async function POST(req: NextRequest) { try { const jsonStr = buffer.slice(6) const event = JSON.parse(jsonStr) + if (event.type === 'billing.completion_usage') { + await mirrorLocalCopilotCompletionUsageReports({ + userId: authenticatedUserId, + reports: [event.report], + }) + } if (event.type === 'tool_result') { streamCapture.captureToolResult(event as Record) } @@ -1304,6 +1324,13 @@ export async function POST(req: NextRequest) { } }) : undefined + await mirrorLocalCopilotCompletionUsageReports({ + userId: authenticatedUserId, + reports: Array.isArray(responseData.completionUsageReports) + ? responseData.completionUsageReports + : [], + }) + responseData.completionUsageReports = undefined if (currentSession && (responseData.content || contentBlocks?.length)) { await persistChatMessages({ @@ -1324,6 +1351,7 @@ export async function POST(req: NextRequest) { generateAndPersistTitle({ reviewSessionId: actualReviewSessionId, message, + userId: authenticatedUserId, model: providerConfig?.model ?? model, provider: providerConfig?.provider, requestId: tracker.requestId, diff --git a/apps/tradinggoose/app/api/copilot/tools/mark-complete/route.ts b/apps/tradinggoose/app/api/copilot/tools/mark-complete/route.ts index 3ff6e0fe7..cb7a7741a 100644 --- a/apps/tradinggoose/app/api/copilot/tools/mark-complete/route.ts +++ b/apps/tradinggoose/app/api/copilot/tools/mark-complete/route.ts @@ -7,6 +7,7 @@ import { createRequestTracker, createUnauthorizedResponse, } from '@/lib/copilot/auth' +import { mirrorLocalCopilotCompletionUsageReports } from '@/lib/copilot/completion-usage-billing' import { createLogger } from '@/lib/logs/console/logger' import { encodeSSE, SSE_HEADERS } from '@/lib/utils' import { getCopilotApiUrl, proxyCopilotRequest } from '@/app/api/copilot/proxy' @@ -22,7 +23,11 @@ const MarkCompleteSchema = z.object({ data: z.any().optional(), }) -function createTurnStateStream(body: ReadableStream, abortUpstream: () => void) { +function createTurnStateStream( + body: ReadableStream, + abortUpstream: () => void, + userId: string +) { let reader: ReadableStreamDefaultReader | null = null return new ReadableStream({ @@ -44,7 +49,15 @@ function createTurnStateStream(body: ReadableStream, abortUpstream: ) } - const forwardEvent = (event: Record) => { + const forwardEvent = async (event: Record) => { + if (event.type === 'billing.completion_usage') { + await mirrorLocalCopilotCompletionUsageReports({ + userId, + reports: [event.report], + }) + return + } + if (event.type === 'awaiting_tools') { enqueueTurnState('in_progress', 'waiting_for_tools') } else if (event.type === 'response.completed') { @@ -80,7 +93,7 @@ function createTurnStateStream(body: ReadableStream, abortUpstream: } const event = JSON.parse(payload) as Record - forwardEvent(event) + await forwardEvent(event) } } @@ -92,7 +105,7 @@ function createTurnStateStream(body: ReadableStream, abortUpstream: } const event = JSON.parse(payload) as Record - forwardEvent(event) + await forwardEvent(event) } } catch (error) { controller.error(error) @@ -182,7 +195,7 @@ export async function POST(req: NextRequest) { toolCallId: parsed.id, toolName: parsed.name, }) - return new NextResponse(createTurnStateStream(agentRes.body, abortUpstream), { + return new NextResponse(createTurnStateStream(agentRes.body, abortUpstream, userId), { status: agentRes.status, headers: { ...SSE_HEADERS, @@ -211,6 +224,12 @@ export async function POST(req: NextRequest) { }) if (agentRes.ok) { + await mirrorLocalCopilotCompletionUsageReports({ + userId, + reports: Array.isArray(agentJson?.completionUsageReports) + ? agentJson.completionUsageReports + : [], + }) return NextResponse.json({ success: true }) } diff --git a/apps/tradinggoose/app/api/copilot/usage/route.test.ts b/apps/tradinggoose/app/api/copilot/usage/route.test.ts index 4026201b4..81b2bd188 100644 --- a/apps/tradinggoose/app/api/copilot/usage/route.test.ts +++ b/apps/tradinggoose/app/api/copilot/usage/route.test.ts @@ -7,7 +7,6 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' describe('Copilot Usage API - Context', () => { const mockCheckInternalApiKey = vi.fn() - const mockIsHosted = vi.fn() const mockProxyCopilotRequest = vi.fn() const mockIsBillingEnabledForRuntime = vi.fn() const mockGetPersonalEffectiveSubscription = vi.fn() @@ -18,8 +17,9 @@ describe('Copilot Usage API - Context', () => { const mockMarkMessageAsProcessed = vi.fn() const mockCalculateCost = vi.fn() const mockReserveCopilotUsage = vi.fn() - const mockAdjustCopilotUsageReservation = vi.fn() + const mockCommitCopilotUsageReservation = vi.fn() const mockReleaseCopilotUsageReservation = vi.fn() + const mockIsHosted = vi.fn() const createTier = (copilotCostMultiplier: number) => ({ id: `tier-${copilotCostMultiplier}`, @@ -57,7 +57,6 @@ describe('Copilot Usage API - Context', () => { vi.resetModules() mockProxyCopilotRequest.mockReset() mockCheckInternalApiKey.mockReset() - mockIsHosted.mockReset() mockIsBillingEnabledForRuntime.mockReset() mockGetPersonalEffectiveSubscription.mockReset() mockGetTierCopilotCostMultiplier.mockReset() @@ -67,13 +66,16 @@ describe('Copilot Usage API - Context', () => { mockMarkMessageAsProcessed.mockReset() mockCalculateCost.mockReset() mockReserveCopilotUsage.mockReset() - mockAdjustCopilotUsageReservation.mockReset() + mockCommitCopilotUsageReservation.mockReset() mockReleaseCopilotUsageReservation.mockReset() + mockIsHosted.mockReset() mockIsBillingEnabledForRuntime.mockResolvedValue(false) + mockIsHosted.mockReturnValue(true) mockGetPersonalEffectiveSubscription.mockResolvedValue(null) mockGetTierCopilotCostMultiplier.mockImplementation( - (tier: { copilotCostMultiplier?: number } | null | undefined) => tier?.copilotCostMultiplier ?? 1 + (tier: { copilotCostMultiplier?: number } | null | undefined) => + tier?.copilotCostMultiplier ?? 1 ) mockAccrueUserUsageCost.mockResolvedValue(true) mockResolveWorkflowBillingContext.mockResolvedValue({ @@ -98,18 +100,7 @@ describe('Copilot Usage API - Context', () => { scopeType: 'user', scopeId: 'user-1', }) - mockAdjustCopilotUsageReservation.mockResolvedValue({ - allowed: true, - status: 200, - reservationId: 'reservation-1', - reservedUsd: 3, - currentUsage: 8, - limit: 10, - remaining: 0, - activeReservedUsd: 3, - scopeType: 'user', - scopeId: 'user-1', - }) + mockCommitCopilotUsageReservation.mockImplementation(async ({ operation }) => operation()) mockReleaseCopilotUsageReservation.mockResolvedValue({ released: true, reservationId: 'reservation-1', @@ -119,7 +110,6 @@ describe('Copilot Usage API - Context', () => { }) mockCheckInternalApiKey.mockReturnValue({ success: false }) - mockIsHosted.mockReturnValue(true) vi.doMock('@tradinggoose/db', () => ({ db: {}, @@ -144,10 +134,6 @@ describe('Copilot Usage API - Context', () => { checkInternalApiKey: (...args: any[]) => mockCheckInternalApiKey(...args), })) - vi.doMock('@/lib/environment', () => ({ - isHosted: mockIsHosted(), - })) - vi.doMock('@/app/api/copilot/proxy', () => ({ proxyCopilotRequest: (...args: any[]) => mockProxyCopilotRequest(...args), getCopilotApiUrl: vi.fn(() => 'https://copilot.example.test/api/get-context-usage'), @@ -198,6 +184,10 @@ describe('Copilot Usage API - Context', () => { calculateCost: (...args: any[]) => mockCalculateCost(...args), })) + vi.doMock('@/lib/environment', () => ({ + isHosted: mockIsHosted(), + })) + vi.doMock('@/lib/billing/usage-accrual', () => ({ accrueUserUsageCost: (...args: any[]) => mockAccrueUserUsageCost(...args), })) @@ -208,8 +198,7 @@ describe('Copilot Usage API - Context', () => { vi.doMock('@/lib/copilot/usage-reservations', () => ({ reserveCopilotUsage: (...args: any[]) => mockReserveCopilotUsage(...args), - adjustCopilotUsageReservation: (...args: any[]) => - mockAdjustCopilotUsageReservation(...args), + commitCopilotUsageReservation: (...args: any[]) => mockCommitCopilotUsageReservation(...args), releaseCopilotUsageReservation: (...args: any[]) => mockReleaseCopilotUsageReservation(...args), })) @@ -237,6 +226,7 @@ describe('Copilot Usage API - Context', () => { kind: 'context', conversationId: 'conversation-1', model: 'gpt-5.4', + workspaceId: 'workspace-1', }), }) @@ -263,6 +253,7 @@ describe('Copilot Usage API - Context', () => { apiKey: 'test-copilot-key', }, userId: 'user-1', + workspaceId: 'workspace-1', }, }) expect(mockGetPersonalEffectiveSubscription).not.toHaveBeenCalled() @@ -270,283 +261,85 @@ describe('Copilot Usage API - Context', () => { expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() }) - it('does not bill context usage for hosted browser-session requests even when bill is requested', async () => { - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockIsHosted.mockReturnValue(true) - mockProxyCopilotRequest.mockResolvedValue( - new Response( - JSON.stringify({ - tokensUsed: 100, - model: 'gpt-5.4', - }), - { - status: 200, - headers: { 'Content-Type': 'application/json' }, - } + it.each([true, false])( + 'returns display-only context usage for hosted=%s browser sessions', + async (hosted) => { + mockIsHosted.mockReturnValue(hosted) + mockIsBillingEnabledForRuntime.mockResolvedValue(true) + mockProxyCopilotRequest.mockResolvedValue( + new Response( + JSON.stringify({ + tokensUsed: 100, + percentage: 0.1, + model: 'gpt-5.4', + contextWindow: 128000, + }), + { + status: 200, + headers: { 'Content-Type': 'application/json' }, + } + ) ) - ) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - kind: 'context', - conversationId: 'conversation-browser-bill', - model: 'gpt-5.4', - bill: true, - assistantMessageId: 'assistant-message-browser', - }), - }) - - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) - - expect(response.status).toBe(200) - await expect(response.json()).resolves.toEqual({ - tokensUsed: 100, - model: 'gpt-5.4', - }) - expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() - expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() - }) - - it('records local context billing for self-hosted browser-session requests', async () => { - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockIsHosted.mockReturnValue(false) - mockGetPersonalEffectiveSubscription.mockResolvedValue({ - id: 'subscription-personal', - tier: createTier(2), - }) - mockProxyCopilotRequest.mockResolvedValue( - new Response( - JSON.stringify({ - tokensUsed: 100, - model: 'gpt-5.4', - }), - { - status: 200, - headers: { 'Content-Type': 'application/json' }, - } - ) - ) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - kind: 'context', - conversationId: 'conversation-self-host-bill', - model: 'gpt-5.4', - bill: true, - assistantMessageId: 'assistant-message-self-host', - }), - }) - - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) - - expect(response.status).toBe(200) - await expect(response.json()).resolves.toEqual({ - tokensUsed: 100, - model: 'gpt-5.4', - billing: { - billed: true, - duplicate: false, - tokens: 100, - model: 'gpt-5.4', - cost: 3, - }, - }) - expect(mockAccrueUserUsageCost).toHaveBeenCalledWith({ - userId: 'user-1', - workflowId: undefined, - cost: 3, - extraUpdates: expect.any(Object), - reason: 'copilot_context_usage', - }) - expect(mockMarkMessageAsProcessed).toHaveBeenCalledWith( - 'copilot-billing:assistant-message-self-host', - 60 * 60 * 24 * 30 - ) - }) - - it('returns exact personal billing metadata for committed context usage', async () => { - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockGetPersonalEffectiveSubscription.mockResolvedValue({ - id: 'subscription-personal', - tier: createTier(2), - }) - mockProxyCopilotRequest.mockResolvedValue( - new Response( - JSON.stringify({ - tokensUsed: 100, + const request = new NextRequest('http://localhost:3000/api/copilot/usage', { + method: 'POST', + body: JSON.stringify({ + kind: 'context', + conversationId: `conversation-${hosted ? 'hosted' : 'self-hosted'}`, model: 'gpt-5.4', }), - { - status: 200, - headers: { 'Content-Type': 'application/json' }, - } - ) - ) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - action: 'commit', - kind: 'context', - conversationId: 'conversation-2', - model: 'gpt-5.4', - userId: 'user-1', - assistantMessageId: 'assistant-message-1', - reservationId: 'reservation-1', - }), - }) + }) - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) + const { POST } = await import('@/app/api/copilot/usage/route') + const response = await POST(request) - expect(response.status).toBe(200) - await expect(response.json()).resolves.toEqual({ - tokensUsed: 100, - model: 'gpt-5.4', - billing: { - billed: true, - duplicate: false, - tokens: 100, + expect(response.status).toBe(200) + await expect(response.json()).resolves.toEqual({ + tokensUsed: 100, + percentage: 0.1, model: 'gpt-5.4', - cost: 3, - }, - }) - expect(mockGetPersonalEffectiveSubscription).toHaveBeenCalledWith('user-1') - expect(mockResolveWorkflowBillingContext).not.toHaveBeenCalled() - expect(mockAccrueUserUsageCost).toHaveBeenCalledWith({ - userId: 'user-1', - workflowId: undefined, - cost: 3, - extraUpdates: expect.any(Object), - reason: 'copilot_context_usage', - }) - expect(mockMarkMessageAsProcessed).toHaveBeenCalledWith( - 'copilot-billing:assistant-message-1', - 60 * 60 * 24 * 30 - ) - expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ - reservationId: 'reservation-1', - }) - }) - - it('commits workflow context usage with the workflow subscription tier', async () => { - mockIsBillingEnabledForRuntime.mockResolvedValue(true) + contextWindow: 128000, + }) + expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() + expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() + } + ) + + it('rejects context usage inspection without a browser session even with internal auth', async () => { mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockProxyCopilotRequest.mockResolvedValue( - new Response( - JSON.stringify({ - tokensUsed: 100, - model: 'gpt-5.4', - }), - { - status: 200, - headers: { 'Content-Type': 'application/json' }, - } - ) - ) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - action: 'commit', - kind: 'context', - conversationId: 'conversation-3', - model: 'gpt-5.4', - userId: 'user-1', - workflowId: 'workflow-1', - assistantMessageId: 'assistant-message-2', - reservationId: 'reservation-1', - }), - }) - - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) - - expect(response.status).toBe(200) - await expect(response.json()).resolves.toMatchObject({ - billing: { - billed: true, - cost: 4.5, - }, - }) - expect(mockResolveWorkflowBillingContext).toHaveBeenCalledWith({ - workflowId: 'workflow-1', - actorUserId: 'user-1', - }) - expect(mockGetPersonalEffectiveSubscription).not.toHaveBeenCalled() - expect(mockAccrueUserUsageCost).toHaveBeenCalledWith({ - userId: 'user-1', - workflowId: 'workflow-1', - cost: 4.5, - extraUpdates: expect.any(Object), - reason: 'copilot_context_usage', - }) - expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ - reservationId: 'reservation-1', - }) - }) - - it('returns 500 for committed context billing when Studio cannot resolve a tier', async () => { - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockGetPersonalEffectiveSubscription.mockResolvedValue(null) - mockProxyCopilotRequest.mockResolvedValue( - new Response( - JSON.stringify({ - tokensUsed: 100, - model: 'gpt-5.4', - }), - { - status: 200, - headers: { 'Content-Type': 'application/json' }, - } - ) - ) + vi.doMock('@/lib/auth', () => ({ + getSession: vi.fn().mockResolvedValue(null), + })) const request = new NextRequest('http://localhost:3000/api/copilot/usage', { method: 'POST', body: JSON.stringify({ - action: 'commit', kind: 'context', - conversationId: 'conversation-4', + conversationId: 'conversation-1', model: 'gpt-5.4', userId: 'user-1', - assistantMessageId: 'assistant-message-3', - reservationId: 'reservation-1', }), }) const { POST } = await import('@/app/api/copilot/usage/route') const response = await POST(request) - expect(response.status).toBe(500) - expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() - expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() - expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ - reservationId: 'reservation-1', - }) + expect(response.status).toBe(401) + expect(mockProxyCopilotRequest).not.toHaveBeenCalled() }) - it('releases the reservation when committed context usage throws before billing completes', async () => { + it('rejects context usage commit requests because context usage is inspection-only', async () => { mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockProxyCopilotRequest.mockRejectedValue(new Error('copilot unavailable')) const request = new NextRequest('http://localhost:3000/api/copilot/usage', { method: 'POST', body: JSON.stringify({ action: 'commit', kind: 'context', - conversationId: 'conversation-5', + conversationId: 'conversation-2', model: 'gpt-5.4', userId: 'user-1', - assistantMessageId: 'assistant-message-4', + assistantMessageId: 'assistant-message-1', reservationId: 'reservation-1', }), }) @@ -554,12 +347,10 @@ describe('Copilot Usage API - Context', () => { const { POST } = await import('@/app/api/copilot/usage/route') const response = await POST(request) - expect(response.status).toBe(500) + expect(response.status).toBe(400) + expect(mockProxyCopilotRequest).not.toHaveBeenCalled() expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() - expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() - expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ - reservationId: 'reservation-1', - }) + expect(mockReleaseCopilotUsageReservation).not.toHaveBeenCalled() }) it('reserves shared usage budget through the internal reserve action', async () => { @@ -693,89 +484,6 @@ describe('Copilot Usage API - Context', () => { expect(mockGetPersonalEffectiveSubscription).not.toHaveBeenCalled() }) - it('adjusts shared usage budget through the internal adjust action using Studio pricing', async () => { - mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockIsBillingEnabledForRuntime.mockResolvedValue(true) - mockGetPersonalEffectiveSubscription.mockResolvedValue({ - id: 'subscription-personal', - tier: createTier(2), - }) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - action: 'adjust', - reservationId: 'reservation-1', - userId: 'user-1', - model: 'openai/gpt-5.4', - estimatedPromptTokens: 100, - reservedCompletionTokens: 25, - reason: 'copilot_turn_model_call', - }), - }) - - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) - - expect(response.status).toBe(200) - await expect(response.json()).resolves.toEqual({ - allowed: true, - status: 200, - reservationId: 'reservation-1', - reservedUsd: 3, - currentUsage: 8, - limit: 10, - remaining: 0, - activeReservedUsd: 3, - scopeType: 'user', - scopeId: 'user-1', - }) - expect(mockAdjustCopilotUsageReservation).toHaveBeenCalledWith({ - reservationId: 'reservation-1', - userId: 'user-1', - workflowId: undefined, - requestedUsd: 3, - reason: 'copilot_turn_model_call', - }) - }) - - it('no-ops adjust requests when billing is disabled', async () => { - mockCheckInternalApiKey.mockReturnValue({ success: true }) - mockIsBillingEnabledForRuntime.mockResolvedValue(false) - - const request = new NextRequest('http://localhost:3000/api/copilot/usage', { - method: 'POST', - body: JSON.stringify({ - action: 'adjust', - reservationId: 'reservation-1', - userId: 'user-1', - model: 'openai/gpt-5.4', - estimatedPromptTokens: 100, - reservedCompletionTokens: 25, - reason: 'copilot_turn_model_call', - }), - }) - - const { POST } = await import('@/app/api/copilot/usage/route') - const response = await POST(request) - - expect(response.status).toBe(200) - await expect(response.json()).resolves.toEqual({ - allowed: true, - status: 200, - reservationId: 'reservation-1', - reservedUsd: 0, - currentUsage: 0, - limit: Number.MAX_SAFE_INTEGER, - remaining: Number.MAX_SAFE_INTEGER, - activeReservedUsd: 0, - scopeType: 'user', - scopeId: 'user-1', - }) - expect(mockAdjustCopilotUsageReservation).not.toHaveBeenCalled() - expect(mockGetPersonalEffectiveSubscription).not.toHaveBeenCalled() - }) - it('releases reservations through the internal release action', async () => { mockCheckInternalApiKey.mockReturnValue({ success: true }) mockIsBillingEnabledForRuntime.mockResolvedValue(true) @@ -866,8 +574,9 @@ describe('Copilot Usage API - Completion', () => { const mockHasProcessedMessage = vi.fn() const mockMarkMessageAsProcessed = vi.fn() const mockCalculateCost = vi.fn() - const mockAdjustCopilotUsageReservation = vi.fn() + const mockCommitCopilotUsageReservation = vi.fn() const mockReleaseCopilotUsageReservation = vi.fn() + const mockIsHosted = vi.fn() const createTier = (copilotCostMultiplier: number) => ({ id: `tier-${copilotCostMultiplier}`, @@ -912,17 +621,20 @@ describe('Copilot Usage API - Completion', () => { mockHasProcessedMessage.mockReset() mockMarkMessageAsProcessed.mockReset() mockCalculateCost.mockReset() - mockAdjustCopilotUsageReservation.mockReset() + mockCommitCopilotUsageReservation.mockReset() mockReleaseCopilotUsageReservation.mockReset() + mockIsHosted.mockReset() mockCheckInternalApiKey.mockReturnValue({ success: true }) mockIsBillingEnabledForRuntime.mockResolvedValue(true) + mockIsHosted.mockReturnValue(true) mockGetPersonalEffectiveSubscription.mockResolvedValue({ id: 'subscription-personal', tier: createTier(2), }) mockGetTierCopilotCostMultiplier.mockImplementation( - (tier: { copilotCostMultiplier?: number } | null | undefined) => tier?.copilotCostMultiplier ?? 1 + (tier: { copilotCostMultiplier?: number } | null | undefined) => + tier?.copilotCostMultiplier ?? 1 ) mockAccrueUserUsageCost.mockResolvedValue(true) mockResolveWorkflowBillingContext.mockResolvedValue({ @@ -935,6 +647,15 @@ describe('Copilot Usage API - Completion', () => { mockHasProcessedMessage.mockResolvedValue(false) mockMarkMessageAsProcessed.mockResolvedValue(undefined) mockCalculateCost.mockReturnValue({ total: 1.5 }) + mockCommitCopilotUsageReservation.mockImplementation(async ({ reservationId, operation }) => { + try { + return await operation() + } finally { + if (reservationId) { + await mockReleaseCopilotUsageReservation({ reservationId }) + } + } + }) mockReleaseCopilotUsageReservation.mockResolvedValue({ released: true, reservationId: 'reservation-1', @@ -954,6 +675,10 @@ describe('Copilot Usage API - Completion', () => { checkInternalApiKey: (...args: any[]) => mockCheckInternalApiKey(...args), })) + vi.doMock('@/lib/environment', () => ({ + isHosted: mockIsHosted(), + })) + vi.doMock('@/lib/billing/settings', () => ({ isBillingEnabledForRuntime: (...args: any[]) => mockIsBillingEnabledForRuntime(...args), })) @@ -977,8 +702,7 @@ describe('Copilot Usage API - Completion', () => { vi.doMock('@/lib/copilot/usage-reservations', () => ({ reserveCopilotUsage: vi.fn(), - adjustCopilotUsageReservation: (...args: any[]) => - mockAdjustCopilotUsageReservation(...args), + commitCopilotUsageReservation: (...args: any[]) => mockCommitCopilotUsageReservation(...args), releaseCopilotUsageReservation: (...args: any[]) => mockReleaseCopilotUsageReservation(...args), })) @@ -1001,15 +725,15 @@ describe('Copilot Usage API - Completion', () => { })) }) - it('records internal completion billing with canonical dotted Claude model ids', async () => { + it('records internal completion billing with canonical provider model ids', async () => { const request = new NextRequest('http://localhost:3000/api/copilot/usage', { method: 'POST', body: JSON.stringify({ action: 'commit', kind: 'completion', userId: 'user-1', - model: 'claude-sonnet-4.6', - remoteModel: 'anthropic/claude-sonnet-4.6', + model: 'anthropic/claude-sonnet-4.6', + remoteModel: 'claude-4.6-sonnet-20260217', completionId: 'completion-1', reservationId: 'reservation-1', usage: { @@ -1031,13 +755,11 @@ describe('Copilot Usage API - Completion', () => { billed: true, duplicate: false, tokens: 125, - model: 'claude-sonnet-4.6', + model: 'anthropic/claude-sonnet-4.6', cost: 3, }, }) - expect(mockHasProcessedMessage).toHaveBeenCalledWith( - 'copilot-completion-billing:completion-1' - ) + expect(mockHasProcessedMessage).toHaveBeenCalledWith('copilot-completion-billing:completion-1') expect(mockAccrueUserUsageCost).toHaveBeenCalledWith({ userId: 'user-1', workflowId: undefined, @@ -1049,12 +771,152 @@ describe('Copilot Usage API - Completion', () => { 'copilot-completion-billing:completion-1', 60 * 60 * 24 * 30 ) - expect(mockCalculateCost).toHaveBeenCalledWith('claude-sonnet-4.6', 100, 25, false) + expect(mockCalculateCost).toHaveBeenCalledWith('anthropic/claude-sonnet-4.6', 100, 25, false) + expect(mockCommitCopilotUsageReservation).toHaveBeenCalledWith({ + userId: 'user-1', + workflowId: undefined, + reservationId: 'reservation-1', + operation: expect.any(Function), + }) expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ reservationId: 'reservation-1', }) }) + it('mirrors hosted Copilot completion reports into self-hosted Studio usage', async () => { + mockIsHosted.mockReturnValue(false) + mockIsBillingEnabledForRuntime.mockResolvedValue(true) + mockGetPersonalEffectiveSubscription.mockResolvedValue({ + id: 'subscription-personal', + tier: createTier(2), + }) + + const { mirrorLocalCopilotCompletionUsageReports } = await import( + '@/lib/copilot/completion-usage-billing' + ) + await mirrorLocalCopilotCompletionUsageReports({ + userId: 'user-1', + reports: [ + { + kind: 'completion', + model: 'gpt-5.4', + remoteModel: 'openai/gpt-5.4', + completionId: 'local-completion-1', + usage: { + prompt_tokens: 100, + completion_tokens: 25, + total_tokens: 125, + }, + }, + ], + }) + + expect(mockAccrueUserUsageCost).toHaveBeenCalledWith({ + userId: 'user-1', + workflowId: undefined, + cost: 3, + extraUpdates: expect.any(Object), + reason: 'copilot_completion_usage', + }) + expect(mockMarkMessageAsProcessed).toHaveBeenCalledWith( + 'copilot-completion-billing:local-completion-1', + 60 * 60 * 24 * 30 + ) + expect(mockCommitCopilotUsageReservation).toHaveBeenCalledWith({ + userId: 'user-1', + workflowId: undefined, + operation: expect.any(Function), + }) + }) + + it('ignores invalid self-hosted Copilot completion mirror reports', async () => { + mockIsHosted.mockReturnValue(false) + mockIsBillingEnabledForRuntime.mockResolvedValue(true) + + const { mirrorLocalCopilotCompletionUsageReports } = await import( + '@/lib/copilot/completion-usage-billing' + ) + await mirrorLocalCopilotCompletionUsageReports({ + userId: 'user-1', + reports: [ + { + kind: 'completion', + model: 'gpt-5.4', + usage: { + prompt_tokens: 100, + completion_tokens: 25, + total_tokens: 125, + }, + }, + ], + }) + + expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() + expect(mockHasProcessedMessage).not.toHaveBeenCalled() + expect(mockCommitCopilotUsageReservation).not.toHaveBeenCalled() + }) + + it('isolates self-hosted Copilot completion mirror billing failures', async () => { + mockIsHosted.mockReturnValue(false) + mockIsBillingEnabledForRuntime.mockResolvedValue(true) + mockGetPersonalEffectiveSubscription.mockResolvedValue({ + id: 'subscription-personal', + tier: createTier(2), + }) + mockCalculateCost.mockImplementation(() => { + throw new Error('pricing unavailable') + }) + + const { mirrorLocalCopilotCompletionUsageReports } = await import( + '@/lib/copilot/completion-usage-billing' + ) + await mirrorLocalCopilotCompletionUsageReports({ + userId: 'user-1', + reports: [ + { + kind: 'completion', + model: 'gpt-5.4', + completionId: 'local-completion-2', + usage: { + prompt_tokens: 100, + completion_tokens: 25, + total_tokens: 125, + }, + }, + ], + }) + + expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() + expect(mockCommitCopilotUsageReservation).toHaveBeenCalledWith({ + userId: 'user-1', + workflowId: undefined, + operation: expect.any(Function), + }) + expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() + }) + + it('does not mirror hosted Copilot completion reports on hosted Studio', async () => { + mockIsHosted.mockReturnValue(true) + + const { mirrorLocalCopilotCompletionUsageReports } = await import( + '@/lib/copilot/completion-usage-billing' + ) + await mirrorLocalCopilotCompletionUsageReports({ + userId: 'user-1', + reports: [ + { + kind: 'completion', + model: 'gpt-5.4', + completionId: 'hosted-completion-1', + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + }, + ], + }) + + expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() + expect(mockCommitCopilotUsageReservation).not.toHaveBeenCalled() + }) + it('does not double-bill duplicate completion ids', async () => { mockHasProcessedMessage.mockResolvedValue(true) @@ -1089,6 +951,12 @@ describe('Copilot Usage API - Completion', () => { }) expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() expect(mockMarkMessageAsProcessed).not.toHaveBeenCalled() + expect(mockCommitCopilotUsageReservation).toHaveBeenCalledWith({ + userId: 'user-1', + workflowId: undefined, + reservationId: 'reservation-1', + operation: expect.any(Function), + }) expect(mockReleaseCopilotUsageReservation).toHaveBeenCalledWith({ reservationId: 'reservation-1', }) @@ -1125,6 +993,31 @@ describe('Copilot Usage API - Completion', () => { }) }) + it('does not release reservations for malformed completion commits', async () => { + const request = new NextRequest('http://localhost:3000/api/copilot/usage', { + method: 'POST', + body: JSON.stringify({ + action: 'commit', + kind: 'completion', + userId: 'user-1', + reservationId: 'reservation-1', + usage: { + prompt_tokens: 100, + completion_tokens: 25, + total_tokens: 125, + }, + }), + headers: { 'Content-Type': 'application/json' }, + }) + + const { POST } = await import('@/app/api/copilot/usage/route') + const response = await POST(request) + + expect(response.status).toBe(400) + expect(mockAccrueUserUsageCost).not.toHaveBeenCalled() + expect(mockReleaseCopilotUsageReservation).not.toHaveBeenCalled() + }) + it('releases the reservation when completion billing is disabled', async () => { mockIsBillingEnabledForRuntime.mockResolvedValue(false) diff --git a/apps/tradinggoose/app/api/copilot/usage/route.ts b/apps/tradinggoose/app/api/copilot/usage/route.ts index c808a23fa..4ec4441c0 100644 --- a/apps/tradinggoose/app/api/copilot/usage/route.ts +++ b/apps/tradinggoose/app/api/copilot/usage/route.ts @@ -1,29 +1,23 @@ -import { sql } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' -import { getPersonalEffectiveSubscription } from '@/lib/billing/core/subscription' import { isBillingEnabledForRuntime } from '@/lib/billing/settings' -import { getTierCopilotCostMultiplier } from '@/lib/billing/tiers' -import { accrueUserUsageCost } from '@/lib/billing/usage-accrual' -import { resolveWorkflowBillingContext } from '@/lib/billing/workspace-billing' import { - adjustCopilotUsageReservation, - releaseCopilotUsageReservation, - reserveCopilotUsage, -} from '@/lib/copilot/usage-reservations' + calculateCopilotReservationUsdFromEstimate, + recordCopilotCompletionUsage, +} from '@/lib/copilot/completion-usage-billing' import { COPILOT_RUNTIME_MODELS } from '@/lib/copilot/runtime-models' import { COPILOT_RUNTIME_PROVIDER_IDS } from '@/lib/copilot/runtime-provider' import { buildCopilotRuntimeProviderConfig } from '@/lib/copilot/runtime-provider.server' +import { + commitCopilotUsageReservation, + releaseCopilotUsageReservation, + reserveCopilotUsage, +} from '@/lib/copilot/usage-reservations' import { checkInternalApiKey } from '@/lib/copilot/utils' -import { isHosted } from '@/lib/environment' import { createLogger } from '@/lib/logs/console/logger' -import { hasProcessedMessage, markMessageAsProcessed } from '@/lib/redis' import { getCopilotApiUrl, proxyCopilotRequest } from '@/app/api/copilot/proxy' -import { calculateCost } from '@/providers/ai/utils' -const BILLING_EVENT_TTL_SECONDS = 60 * 60 * 24 * 30 // 30 days -const DEFAULT_ESTIMATED_RESERVATION_USD = 1 const BILLING_DISABLED_RESERVATION_ID = 'billing-disabled' const logger = createLogger('CopilotUsageAPI') @@ -32,11 +26,8 @@ const ContextUsageRequestSchema = z.object({ conversationId: z.string(), model: z.enum(COPILOT_RUNTIME_MODELS), workflowId: z.string().optional(), + workspaceId: z.string().optional(), provider: z.enum(COPILOT_RUNTIME_PROVIDER_IDS).optional(), - bill: z.boolean().optional(), - assistantMessageId: z.string().optional(), - billingModel: z.string().optional(), - userId: z.string().optional(), }) const UsageEstimateSchema = z.object({ @@ -53,39 +44,20 @@ const ReserveUsageUsdRequestSchema = z.object({ reason: z.string().min(1).optional(), }) -const ReserveUsageEstimatedRequestSchema = z.object({ - action: z.literal('reserve'), - userId: z.string().min(1, 'userId is required'), - workflowId: z.string().min(1).optional(), - reason: z.string().min(1).optional(), -}).merge(UsageEstimateSchema) +const ReserveUsageEstimatedRequestSchema = z + .object({ + action: z.literal('reserve'), + userId: z.string().min(1, 'userId is required'), + workflowId: z.string().min(1).optional(), + reason: z.string().min(1).optional(), + }) + .merge(UsageEstimateSchema) const ReserveUsageRequestSchema = z.union([ ReserveUsageUsdRequestSchema, ReserveUsageEstimatedRequestSchema, ]) -const AdjustUsageRequestSchema = z.object({ - action: z.literal('adjust'), - reservationId: z.string().min(1, 'reservationId is required'), - userId: z.string().min(1, 'userId is required'), - workflowId: z.string().min(1).optional(), - reason: z.string().min(1).optional(), -}).merge(UsageEstimateSchema) - -const ContextCommitRequestSchema = z.object({ - action: z.literal('commit'), - kind: z.literal('context'), - conversationId: z.string(), - model: z.enum(COPILOT_RUNTIME_MODELS), - workflowId: z.string().optional(), - provider: z.enum(COPILOT_RUNTIME_PROVIDER_IDS).optional(), - assistantMessageId: z.string().min(1, 'assistantMessageId is required'), - billingModel: z.string().optional(), - userId: z.string().min(1, 'userId is required'), - reservationId: z.string().min(1).optional(), -}) - const CompletionCommitRequestSchema = z.object({ action: z.literal('commit'), kind: z.literal('completion'), @@ -103,323 +75,15 @@ const ReleaseUsageRequestSchema = z.object({ reservationId: z.string().min(1, 'reservationId is required'), }) -interface TokenMetrics { - promptTokens: number - completionTokens: number - totalTokens: number -} - -type UsageBillingResult = - | { - billed: true - duplicate: false - cost: number - tokens: number - model: string - } - | { - billed: false - duplicate: true - } - | { - billed: false - duplicate?: false - reason: 'billing_disabled' | 'no_token_metrics' | 'zero_cost' | 'ledger_not_found' - } - -function readNumber(value: unknown): number | undefined { - if (typeof value === 'number' && Number.isFinite(value)) { - return value - } - if (typeof value === 'string') { - const parsed = Number.parseFloat(value) - return Number.isFinite(parsed) ? parsed : undefined - } - return undefined -} - -function pickNumber(source: any, keys: string[]): number | undefined { - if (!source || typeof source !== 'object') return undefined - for (const key of keys) { - const candidate = readNumber(source[key]) - if (candidate !== undefined) { - return candidate - } - } - return undefined -} - -function extractTokenMetrics(usage: any): TokenMetrics | null { - const sources = [usage, usage?.tokenUsage, usage?.tokens, usage?.usageDetails] - - let promptTokens: number | undefined - let completionTokens: number | undefined - let totalTokens: number | undefined - - for (const src of sources) { - if (promptTokens === undefined) { - promptTokens = pickNumber(src, [ - 'prompt_tokens', - 'promptTokens', - 'input_tokens', - 'inputTokens', - 'prompt', - ]) - } - if (completionTokens === undefined) { - completionTokens = pickNumber(src, [ - 'completion_tokens', - 'completionTokens', - 'output_tokens', - 'outputTokens', - 'completion', - ]) - } - if (totalTokens === undefined) { - totalTokens = pickNumber(src, [ - 'total_tokens', - 'totalTokens', - 'tokens', - 'token_count', - 'total', - ]) - } - } - - if (totalTokens === undefined) { - totalTokens = readNumber(usage?.tokensUsed) ?? readNumber(usage?.usage) - } - - if (completionTokens === undefined) { - completionTokens = 0 - } - - if (totalTokens !== undefined && promptTokens === undefined) { - promptTokens = totalTokens - completionTokens - } - - if (promptTokens === undefined || totalTokens === undefined) { - return null - } - - const normalizedPrompt = Math.max(0, Math.round(promptTokens)) - const normalizedCompletion = Math.max(0, Math.round(completionTokens ?? 0)) - const normalizedTotal = Math.max( - 0, - Math.round(totalTokens ?? normalizedPrompt + normalizedCompletion) - ) - - if (normalizedTotal <= 0 || (normalizedPrompt === 0 && normalizedCompletion === 0)) { - return null - } - - return { - promptTokens: normalizedPrompt, - completionTokens: normalizedCompletion, - totalTokens: normalizedTotal, - } -} - -function normalizeModelForBilling(model: string): string { - const base = model.includes('/') ? model.split('/').pop() || model : model - return base.toLowerCase() -} - -async function recordBilledUsage(params: { - userId: string - workflowId?: string - usage: any - billingModel: string - remoteModel?: string | null - billingKeyPrefix: 'copilot-billing' | 'copilot-completion-billing' - billingKeyId?: string | null - reason: 'copilot_context_usage' | 'copilot_completion_usage' -}): Promise { - const { userId, workflowId, usage, billingModel, remoteModel, billingKeyPrefix, billingKeyId, reason } = - params - - const metrics = extractTokenMetrics(usage) - if (!metrics) { - logger.info('Skipping copilot billing - no token metrics available', { - billingKeyPrefix, - billingKeyId, - reason, - }) - return { billed: false, reason: 'no_token_metrics' } - } - - const billingKey = billingKeyId ? `${billingKeyPrefix}:${billingKeyId}` : null - if (billingKey && (await hasProcessedMessage(billingKey))) { - logger.info('Copilot billing already processed', { billingKey, reason }) - return { billed: false, duplicate: true } - } - - const { costUsd: costToAdd, normalizedModel, billingContext } = await calculateCopilotCostUsd({ - userId, - workflowId, - billingModel, - remoteModel, - promptTokens: metrics.promptTokens, - completionTokens: metrics.completionTokens, - }) - if (costToAdd <= 0) { - logger.info('Skipping copilot billing - calculated cost is zero', { - userId, - workflowId, - billingKeyId, - model: normalizedModel, - reason, - }) - return { billed: false, reason: 'zero_cost' } - } - - const extraUpdates: Record = { - totalCopilotCost: sql`total_copilot_cost + ${costToAdd}`, - currentPeriodCopilotCost: sql`current_period_copilot_cost + ${costToAdd}`, - totalCopilotCalls: sql`total_copilot_calls + 1`, - } - - if (metrics.totalTokens > 0) { - extraUpdates.totalCopilotTokens = sql`total_copilot_tokens + ${metrics.totalTokens}` - } - - const didAccrue = await accrueUserUsageCost({ - userId, - workflowId, - cost: costToAdd, - extraUpdates, - reason, - }) - - if (!didAccrue) { - logger.warn('Copilot billing skipped - ledger record not found', { - userId, - workflowId, - billingKeyId, - reason, - }) - return { billed: false, reason: 'ledger_not_found' } - } - - if (billingKey) { - await markMessageAsProcessed(billingKey, BILLING_EVENT_TTL_SECONDS) - } - - logger.info('Copilot billing recorded', { - userId, - billingUserId: billingContext?.billingUserId ?? userId, - workflowId, - billingKeyId, - cost: costToAdd, - tokens: metrics.totalTokens, - model: normalizedModel, - reason, - }) - - return { - billed: true, - duplicate: false, - cost: costToAdd, - tokens: metrics.totalTokens, - model: normalizedModel, - } -} - -async function resolveEffectiveCopilotTier(params: { - userId: string - workflowId?: string -}): Promise<{ - effectiveTier: any - billingContext: Awaited> | null -}> { - const billingContext = params.workflowId - ? await resolveWorkflowBillingContext({ - workflowId: params.workflowId, - actorUserId: params.userId, - }) - : null - const effectiveTier = params.workflowId - ? billingContext?.subscription?.tier ?? null - : (await getPersonalEffectiveSubscription(params.userId))?.tier ?? null - - if (!effectiveTier) { - throw new Error( - params.workflowId - ? `No active workflow subscription tier found for billed copilot usage on workflow ${params.workflowId}` - : `No active personal subscription tier found for billed copilot usage for user ${params.userId}` - ) - } - - return { - effectiveTier, - billingContext, - } -} - -async function calculateCopilotCostUsd(params: { - userId: string - workflowId?: string - billingModel: string - remoteModel?: string | null - promptTokens: number - completionTokens: number - fallbackUsd?: number -}): Promise<{ - costUsd: number - normalizedModel: string - billingContext: Awaited> | null -}> { - const modelToUse = - typeof params.remoteModel === 'string' && params.remoteModel.length > 0 - ? params.remoteModel - : params.billingModel - const normalizedModel = normalizeModelForBilling(modelToUse) - const costResult = calculateCost( - normalizedModel, - params.promptTokens, - params.completionTokens, - false - ) - const { effectiveTier, billingContext } = await resolveEffectiveCopilotTier({ - userId: params.userId, - workflowId: params.workflowId, - }) - const rawCostUsd = Number(costResult.total || 0) * getTierCopilotCostMultiplier(effectiveTier) - - return { - costUsd: rawCostUsd > 0 ? rawCostUsd : params.fallbackUsd ?? 0, - normalizedModel, - billingContext, - } -} - -async function calculateReservationUsdFromEstimate(params: { - userId: string - workflowId?: string - model: string - estimatedPromptTokens: number - reservedCompletionTokens: number -}): Promise { - const { costUsd } = await calculateCopilotCostUsd({ - userId: params.userId, - workflowId: params.workflowId, - billingModel: params.model, - promptTokens: params.estimatedPromptTokens, - completionTokens: params.reservedCompletionTokens, - fallbackUsd: DEFAULT_ESTIMATED_RESERVATION_USD, - }) - - return costUsd -} - async function fetchContextUsageFromCopilot(params: { conversationId: string model: z.infer['model'] workflowId?: string + workspaceId?: string provider?: z.infer['provider'] userId: string }) { - const { conversationId, model, workflowId, provider, userId } = params + const { conversationId, model, workflowId, workspaceId, provider, userId } = params const { providerConfig } = await buildCopilotRuntimeProviderConfig({ model, provider, @@ -430,6 +94,7 @@ async function fetchContextUsageFromCopilot(params: { model, userId, ...(workflowId ? { workflowId } : {}), + ...(workspaceId ? { workspaceId } : {}), provider: providerConfig, } @@ -445,14 +110,11 @@ async function fetchContextUsageFromCopilot(params: { } async function handleContextUsage( - req: NextRequest, payload: z.infer ): Promise { - const { conversationId, model, workflowId, provider, bill, assistantMessageId, billingModel } = - payload - const internalAuth = checkInternalApiKey(req) - const session = !internalAuth.success ? await getSession() : null - const userId = internalAuth.success ? payload.userId : session?.user?.id + const { conversationId, model, workflowId, workspaceId, provider } = payload + const session = await getSession() + const userId = session?.user?.id if (!userId) { logger.warn('[Usage API] No session/user ID for context usage') @@ -463,6 +125,7 @@ async function handleContextUsage( conversationId, model, workflowId, + workspaceId, provider, userId, }) @@ -480,76 +143,10 @@ async function handleContextUsage( } const data = await simAgentResponse.json() - - const shouldBill = Boolean(bill && assistantMessageId && !internalAuth.success && !isHosted) - if (!shouldBill) { - return NextResponse.json(data) - } - - if (!(await isBillingEnabledForRuntime())) { - return NextResponse.json({ - ...data, - billing: { billed: false, reason: 'billing_disabled' }, - }) - } - - try { - const billing = await recordBilledUsage({ - userId, - workflowId, - usage: data, - billingModel: billingModel || model, - remoteModel: data?.model, - billingKeyPrefix: 'copilot-billing', - billingKeyId: assistantMessageId, - reason: 'copilot_context_usage', - }) - return NextResponse.json({ - ...data, - billing, - }) - } catch (billingError) { - logger.error('Failed to bill copilot context usage', { - error: billingError, - conversationId, - assistantMessageId, - }) - return NextResponse.json({ - ...data, - billing: { billed: false, reason: 'ledger_not_found' }, - }) - } + return NextResponse.json(data) } -async function releaseCommittedReservation(reservationId?: string): Promise { - if (!reservationId) return - if (reservationId === BILLING_DISABLED_RESERVATION_ID) { - return - } - - await releaseCopilotUsageReservation({ reservationId }).catch((error) => { - logger.warn('Failed to release copilot usage reservation after commit', { - reservationId, - error: error instanceof Error ? error.message : String(error), - }) - }) -} - -async function withCommittedReservationRelease( - reservationId: string | undefined, - operation: () => Promise -): Promise { - try { - return await operation() - } finally { - await releaseCommittedReservation(reservationId) - } -} - -function buildBillingDisabledReservation(params: { - userId: string - reservationId?: string -}) { +function buildBillingDisabledReservation(params: { userId: string; reservationId?: string }) { return { allowed: true, status: 200, @@ -580,7 +177,7 @@ async function handleReserveUsage( const requestedUsd = 'requestedUsd' in payload ? payload.requestedUsd - : await calculateReservationUsdFromEstimate({ + : await calculateCopilotReservationUsdFromEstimate({ userId: payload.userId, workflowId: payload.workflowId, model: payload.model, @@ -598,133 +195,35 @@ async function handleReserveUsage( return NextResponse.json(result, { status: result.status }) } -async function handleAdjustUsage( - req: NextRequest, - payload: z.infer +async function handleCompletionCommit( + payload: z.infer ): Promise { - const auth = checkInternalApiKey(req) - if (!auth.success) { - return new NextResponse(null, { status: 401 }) - } - - if (!(await isBillingEnabledForRuntime())) { - return NextResponse.json( - buildBillingDisabledReservation({ - userId: payload.userId, - reservationId: payload.reservationId, - }) - ) - } - - const requestedUsd = await calculateReservationUsdFromEstimate({ - userId: payload.userId, - workflowId: payload.workflowId, - model: payload.model, - estimatedPromptTokens: payload.estimatedPromptTokens, - reservedCompletionTokens: payload.reservedCompletionTokens, - }) - - const result = await adjustCopilotUsageReservation({ - reservationId: payload.reservationId, + return await commitCopilotUsageReservation({ userId: payload.userId, workflowId: payload.workflowId, - requestedUsd, - reason: payload.reason, - }) - - return NextResponse.json(result, { status: result.status }) -} - -async function handleContextCommit( - req: NextRequest, - payload: z.infer -): Promise { - const auth = checkInternalApiKey(req) - if (!auth.success) { - return new NextResponse(null, { status: 401 }) - } - - return withCommittedReservationRelease(payload.reservationId, async () => { - const simAgentResponse = await fetchContextUsageFromCopilot({ - conversationId: payload.conversationId, - model: payload.model, - workflowId: payload.workflowId, - provider: payload.provider, - userId: payload.userId, - }) - - if (!simAgentResponse.ok) { - const errorText = await simAgentResponse.text().catch(() => '') - logger.warn('[Usage API] TradingGoose agent request failed during commit', { - status: simAgentResponse.status, - error: errorText, - reservationId: payload.reservationId, - }) - return NextResponse.json( - { error: 'Failed to fetch context usage from copilot' }, - { status: simAgentResponse.status } - ) - } - - const data = await simAgentResponse.json() + reservationId: + payload.reservationId === BILLING_DISABLED_RESERVATION_ID ? undefined : payload.reservationId, + operation: async () => { + if (!(await isBillingEnabledForRuntime())) { + return NextResponse.json({ + success: true, + billing: { billed: false, reason: 'billing_disabled' }, + }) + } - if (!(await isBillingEnabledForRuntime())) { - return NextResponse.json({ - ...data, - billing: { billed: false, reason: 'billing_disabled' }, + const billing = await recordCopilotCompletionUsage({ + userId: payload.userId, + workflowId: payload.workflowId, + usage: payload.usage, + billingModel: payload.model, + billingKeyId: payload.completionId, }) - } - - const billing = await recordBilledUsage({ - userId: payload.userId, - workflowId: payload.workflowId, - usage: data, - billingModel: payload.billingModel || payload.model, - remoteModel: data?.model, - billingKeyPrefix: 'copilot-billing', - billingKeyId: payload.assistantMessageId, - reason: 'copilot_context_usage', - }) - - return NextResponse.json({ - ...data, - billing, - }) - }) -} - -async function handleCompletionCommit( - req: NextRequest, - payload: z.infer -): Promise { - const auth = checkInternalApiKey(req) - if (!auth.success) { - return new NextResponse(null, { status: 401 }) - } - return withCommittedReservationRelease(payload.reservationId, async () => { - if (!(await isBillingEnabledForRuntime())) { return NextResponse.json({ success: true, - billing: { billed: false, reason: 'billing_disabled' }, + billing, }) - } - - const billing = await recordBilledUsage({ - userId: payload.userId, - workflowId: payload.workflowId, - usage: payload.usage, - billingModel: payload.model, - remoteModel: payload.remoteModel, - billingKeyPrefix: 'copilot-completion-billing', - billingKeyId: payload.completionId, - reason: 'copilot_completion_usage', - }) - - return NextResponse.json({ - success: true, - billing, - }) + }, }) } @@ -753,7 +252,7 @@ async function handleReleaseUsage( /** * POST /api/copilot/usage - * Unified copilot usage endpoint for context inspection/billing and raw completion billing. + * Unified copilot usage endpoint for context inspection, reservation control, and completion billing. */ export async function POST(req: NextRequest) { try { @@ -769,7 +268,8 @@ export async function POST(req: NextRequest) { return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 }) } - const action = body && typeof body === 'object' ? (body as Record).action : null + const action = + body && typeof body === 'object' ? (body as Record).action : null if (action === 'reserve') { const parsed = ReserveUsageRequestSchema.safeParse(body) if (!parsed.success) { @@ -785,48 +285,25 @@ export async function POST(req: NextRequest) { return await handleReserveUsage(req, parsed.data) } - if (action === 'adjust') { - const parsed = AdjustUsageRequestSchema.safeParse(body) - if (!parsed.success) { - logger.warn('Invalid copilot usage adjust request', { errors: parsed.error.errors }) - return NextResponse.json( - { - error: 'Invalid request body', - details: parsed.error.errors, - }, - { status: 400 } - ) + if (action === 'commit') { + const auth = checkInternalApiKey(req) + if (!auth.success) { + return new NextResponse(null, { status: 401 }) } - return await handleAdjustUsage(req, parsed.data) - } - if (action === 'commit') { - const kind = body && typeof body === 'object' ? (body as Record).kind : null - const parsed = - kind === 'context' - ? ContextCommitRequestSchema.safeParse(body) - : kind === 'completion' - ? CompletionCommitRequestSchema.safeParse(body) - : null - - if (!parsed || !parsed.success) { - logger.warn('Invalid copilot usage commit request', { - errors: parsed && !parsed.success ? parsed.error.errors : [{ message: 'Invalid commit kind' }], - }) + const parsed = CompletionCommitRequestSchema.safeParse(body) + if (!parsed.success) { + logger.warn('Invalid copilot usage commit request', { errors: parsed.error.errors }) return NextResponse.json( { error: 'Invalid request body', - details: parsed && !parsed.success ? parsed.error.errors : [{ message: 'Invalid commit kind' }], + details: parsed.error.errors, }, { status: 400 } ) } - if (parsed.data.kind === 'context') { - return await handleContextCommit(req, parsed.data) - } - - return await handleCompletionCommit(req, parsed.data) + return await handleCompletionCommit(parsed.data) } if (action === 'release') { @@ -857,7 +334,7 @@ export async function POST(req: NextRequest) { ) } - return await handleContextUsage(req, parsed.data) + return await handleContextUsage(parsed.data) } catch (error) { logger.error('Failed to process copilot usage request', { error }) return NextResponse.json({ error: 'Internal server error' }, { status: 500 }) diff --git a/apps/tradinggoose/app/api/files/serve/[...path]/route.test.ts b/apps/tradinggoose/app/api/files/serve/[...path]/route.test.ts index 42718c52a..8e29854ed 100644 --- a/apps/tradinggoose/app/api/files/serve/[...path]/route.test.ts +++ b/apps/tradinggoose/app/api/files/serve/[...path]/route.test.ts @@ -60,7 +60,6 @@ describe('File Serve API Route', () => { }, getEnv: vi.fn((key: string) => { if (key === 'NEXT_PUBLIC_APP_URL') return 'https://app.tradinggoose.ai' - if (key === 'NEXT_PUBLIC_IS_PREVIEW_DEVELOPMENT') return 'false' return undefined }), })) diff --git a/apps/tradinggoose/app/api/indicators/custom/import/route.test.ts b/apps/tradinggoose/app/api/indicators/custom/import/route.test.ts index e9de87c8b..fdfe1d765 100644 --- a/apps/tradinggoose/app/api/indicators/custom/import/route.test.ts +++ b/apps/tradinggoose/app/api/indicators/custom/import/route.test.ts @@ -72,7 +72,6 @@ describe('Indicators import route', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: {}, }, @@ -93,7 +92,6 @@ describe('Indicators import route', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: {}, }, diff --git a/apps/tradinggoose/app/api/indicators/custom/route.ts b/apps/tradinggoose/app/api/indicators/custom/route.ts index e904862bc..43d4acd7d 100644 --- a/apps/tradinggoose/app/api/indicators/custom/route.ts +++ b/apps/tradinggoose/app/api/indicators/custom/route.ts @@ -36,7 +36,6 @@ const IndicatorSchema = z.object({ z.object({ id: z.string().optional(), name: z.string().min(1, 'Indicator name is required'), - color: z.string().optional(), pineCode: z.string().default(''), inputMeta: z.record(z.any()).optional(), }) diff --git a/apps/tradinggoose/app/api/schedules/execute/route.test.ts b/apps/tradinggoose/app/api/schedules/execute/route.test.ts index ce360a70c..bb942537b 100644 --- a/apps/tradinggoose/app/api/schedules/execute/route.test.ts +++ b/apps/tradinggoose/app/api/schedules/execute/route.test.ts @@ -89,7 +89,7 @@ describe('Scheduled Workflow Execution API Route', () => { { id: 'schedule-1', workflowId: 'workflow-1', - blockId: null, + blockId: 'schedule-trigger-1', cronExpression: null, lastRanAt: null, failedCount: 0, @@ -151,7 +151,7 @@ describe('Scheduled Workflow Execution API Route', () => { expect(data.error).toContain('Trigger.dev is required for scheduled executions') }) - it('should queue schedules through pending execution when enabled', async () => { + it('should queue configured schedules and remove orphan schedule rows', async () => { vi.doMock('@/lib/auth/internal', () => ({ verifyCronAuth: vi.fn().mockReturnValue(null), })) @@ -189,18 +189,29 @@ describe('Scheduled Workflow Execution API Route', () => { isPendingExecutionLimitError: vi.fn(() => false), })) + let deletedScheduleWhere: Record | undefined vi.doMock('@tradinggoose/db', () => { const scheduleRows = [ { id: 'schedule-1', workflowId: 'workflow-1', - blockId: null, + blockId: 'schedule-trigger-1', cronExpression: null, lastRanAt: null, failedCount: 0, timezone: 'UTC', nextRunAt: new Date('2024-01-01T00:00:00.000Z'), }, + { + id: 'schedule-missing-trigger', + workflowId: 'workflow-2', + blockId: null, + cronExpression: null, + lastRanAt: null, + failedCount: 1, + timezone: 'UTC', + nextRunAt: new Date('2024-01-01T00:00:00.000Z'), + }, ] const workflowRows = [ @@ -231,6 +242,12 @@ describe('Scheduled Workflow Execution API Route', () => { }), } }), + delete: vi.fn().mockImplementation(() => ({ + where: vi.fn().mockImplementation((condition) => { + deletedScheduleWhere = condition + return Promise.resolve([]) + }), + })), } return { @@ -247,6 +264,12 @@ describe('Scheduled Workflow Execution API Route', () => { expect(response.status).toBe(200) const data = await response.json() expect(data).toHaveProperty('executedCount', 1) + expect(deletedScheduleWhere).toEqual( + expect.objectContaining({ + type: 'eq', + value: 'schedule-missing-trigger', + }) + ) expect(enqueuePendingExecutionMock).toHaveBeenCalledWith( expect.objectContaining({ executionType: 'schedule', @@ -349,7 +372,7 @@ describe('Scheduled Workflow Execution API Route', () => { { id: 'schedule-1', workflowId: 'workflow-1', - blockId: null, + blockId: 'schedule-trigger-1', cronExpression: null, lastRanAt: null, failedCount: 0, @@ -359,7 +382,7 @@ describe('Scheduled Workflow Execution API Route', () => { { id: 'schedule-2', workflowId: 'workflow-2', - blockId: null, + blockId: 'schedule-trigger-2', cronExpression: null, lastRanAt: null, failedCount: 0, diff --git a/apps/tradinggoose/app/api/schedules/execute/route.ts b/apps/tradinggoose/app/api/schedules/execute/route.ts index c99e5fd0b..de99b4c1e 100644 --- a/apps/tradinggoose/app/api/schedules/execute/route.ts +++ b/apps/tradinggoose/app/api/schedules/execute/route.ts @@ -1,8 +1,8 @@ import { db, workflow, workflowSchedule } from '@tradinggoose/db' import { and, eq, lte, not } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' -import { verifyCronAuth } from '@/lib/auth/internal' import { getApiKeyOwnerUserId } from '@/lib/api-key/service' +import { verifyCronAuth } from '@/lib/auth/internal' import { enqueuePendingExecution, isPendingExecutionLimitError, @@ -40,94 +40,95 @@ export async function GET(request: NextRequest) { const queuedSchedules = await Promise.all( dueSchedules.map(async (schedule) => { try { + if (typeof schedule.blockId !== 'string' || schedule.blockId.length === 0) { + logger.warn( + `[${requestId}] Removing schedule ${schedule.id}: missing schedule trigger block.` + ) + await db.delete(workflowSchedule).where(eq(workflowSchedule.id, schedule.id)) + return null + } + const [workflowRecord] = await db .select({ - workspaceId: workflow.workspaceId, - pinnedApiKeyId: workflow.pinnedApiKeyId, + workspaceId: workflow.workspaceId, + pinnedApiKeyId: workflow.pinnedApiKeyId, + }) + .from(workflow) + .where(eq(workflow.id, schedule.workflowId)) + .limit(1) + + if (!workflowRecord) { + logger.warn( + `[${requestId}] Workflow ${schedule.workflowId} not found for schedule ${schedule.id}` + ) + return null + } + + const actorUserId = await getApiKeyOwnerUserId(workflowRecord.pinnedApiKeyId) + + if (!actorUserId) { + logger.warn( + `[${requestId}] Skipping schedule ${schedule.id}: pinned API key required to attribute usage.` + ) + return null + } + + const pendingExecutionId = `schedule_execution:${schedule.id}:${schedule.nextRunAt?.toISOString() ?? now.toISOString()}` + const payload = { + executionId: pendingExecutionId, + scheduleId: schedule.id, + workflowId: schedule.workflowId, + blockId: schedule.blockId, + cronExpression: schedule.cronExpression || undefined, + lastRanAt: schedule.lastRanAt?.toISOString(), + failedCount: schedule.failedCount || 0, + timezone: schedule.timezone, + now: now.toISOString(), + } + + const handle = await enqueuePendingExecution({ + executionType: 'schedule', + pendingExecutionId, + workflowId: schedule.workflowId, + workspaceId: workflowRecord.workspaceId, + userId: actorUserId, + source: 'schedule', + orderingKey: `schedule:${schedule.id}`, + requestId, + payload, }) - .from(workflow) - .where(eq(workflow.id, schedule.workflowId)) - .limit(1) - if (!workflowRecord) { - logger.warn( - `[${requestId}] Workflow ${schedule.workflowId} not found for schedule ${schedule.id}`, - ) - return null - } - - const actorUserId = await getApiKeyOwnerUserId( - workflowRecord.pinnedApiKeyId, - ) + if (!handle.inserted) return null - if (!actorUserId) { - logger.warn( - `[${requestId}] Skipping schedule ${schedule.id}: pinned API key required to attribute usage.`, + logger.info( + `[${requestId}] Queued schedule execution ${handle.pendingExecutionId} for workflow ${schedule.workflowId}` ) - return null - } - - const pendingExecutionId = `schedule_execution:${schedule.id}:${schedule.nextRunAt?.toISOString() ?? now.toISOString()}` - const payload = { - executionId: pendingExecutionId, - scheduleId: schedule.id, - workflowId: schedule.workflowId, - blockId: schedule.blockId || undefined, - cronExpression: schedule.cronExpression || undefined, - lastRanAt: schedule.lastRanAt?.toISOString(), - failedCount: schedule.failedCount || 0, - timezone: schedule.timezone, - now: now.toISOString(), - } - - const handle = await enqueuePendingExecution({ - executionType: 'schedule', - pendingExecutionId, - workflowId: schedule.workflowId, - workspaceId: workflowRecord.workspaceId, - userId: actorUserId, - source: 'schedule', - orderingKey: `schedule:${schedule.id}`, - requestId, - payload, - }) - - if (!handle.inserted) return null - - logger.info( - `[${requestId}] Queued schedule execution ${handle.pendingExecutionId} for workflow ${schedule.workflowId}`, - ) - return handle - } catch (error) { - if (isPendingExecutionLimitError(error)) { - logger.warn( - `[${requestId}] Pending backlog full for schedule ${schedule.id}`, - { + return handle + } catch (error) { + if (isPendingExecutionLimitError(error)) { + logger.warn(`[${requestId}] Pending backlog full for schedule ${schedule.id}`, { workflowId: schedule.workflowId, pendingCount: error.details.pendingCount, maxPendingCount: error.details.maxPendingCount, - }, + }) + return null + } + + if (error instanceof TriggerExecutionUnavailableError) { + throw error + } + + logger.error( + `[${requestId}] Failed to trigger schedule execution for workflow ${schedule.workflowId}`, + error ) return null } - - if (error instanceof TriggerExecutionUnavailableError) { - throw error - } - - logger.error( - `[${requestId}] Failed to trigger schedule execution for workflow ${schedule.workflowId}`, - error - ) - return null - } }) ) const queuedCount = queuedSchedules.filter((result) => result !== null).length - logger.info( - `[${requestId}] Queued ${queuedCount} schedule executions to pending execution`, - ) + logger.info(`[${requestId}] Queued ${queuedCount} schedule executions to pending execution`) return NextResponse.json({ message: 'Scheduled workflow executions processed', diff --git a/apps/tradinggoose/app/api/schedules/route.ts b/apps/tradinggoose/app/api/schedules/route.ts index f12fde77e..92da584ce 100644 --- a/apps/tradinggoose/app/api/schedules/route.ts +++ b/apps/tradinggoose/app/api/schedules/route.ts @@ -5,7 +5,6 @@ import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' -import { resolveTimezoneState } from '@/lib/timezone/timezone-resolver' import { getUserEntityPermissions } from '@/lib/permissions/utils' import { type BlockState, @@ -15,13 +14,14 @@ import { getSubBlockValue, validateCronExpression, } from '@/lib/schedules/utils' +import { resolveTimezoneState } from '@/lib/timezone/timezone-resolver' import { generateRequestId } from '@/lib/utils' const logger = createLogger('ScheduledAPI') const ScheduleRequestSchema = z.object({ workflowId: z.string(), - blockId: z.string().optional(), + blockId: z.string().min(1), state: z.object({ blocks: z.record(z.any()), edges: z.array(z.any()), @@ -212,68 +212,37 @@ export async function POST(req: NextRequest) { return NextResponse.json({ error: 'Not authorized to modify this workflow' }, { status: 403 }) } - // Find the target block - prioritize the specific blockId if provided - let targetBlock: BlockState | undefined - if (blockId) { - targetBlock = Object.values(state.blocks).find((block: any) => block.id === blockId) as - | BlockState - | undefined - } else { - targetBlock = Object.values(state.blocks).find( - (block: any) => block.type === 'schedule' - ) as BlockState | undefined - } + const targetBlock = Object.values(state.blocks).find((block: any) => block.id === blockId) as + | BlockState + | undefined if (!targetBlock) { - logger.warn(`[${requestId}] No schedule block found in workflow ${workflowId}`) - return NextResponse.json( - { error: 'No schedule block found in workflow' }, - { status: 400 } - ) + logger.warn(`[${requestId}] Schedule block ${blockId} not found in workflow ${workflowId}`) + return NextResponse.json({ error: 'Schedule block not found in workflow' }, { status: 400 }) } const scheduleType = getSubBlockValue(targetBlock, 'scheduleType') - const isScheduleBlock = targetBlock.type === 'schedule' + if (targetBlock.type !== 'schedule') { + return NextResponse.json({ error: 'Schedule block is required' }, { status: 400 }) + } const scheduleValues = getScheduleTimeValues(targetBlock) const hasValidConfig = hasValidScheduleConfig(scheduleType, scheduleValues, targetBlock) - // Debug logging to understand why validation fails - logger.info(`[${requestId}] Schedule validation debug:`, { - workflowId, - blockId, - blockType: targetBlock.type, - scheduleType, - hasValidConfig, - scheduleValues: { - minutesInterval: scheduleValues.minutesInterval, - dailyTime: scheduleValues.dailyTime, - cronExpression: scheduleValues.cronExpression, - }, - }) - if (!hasValidConfig) { logger.info( `[${requestId}] Removing schedule for workflow ${workflowId} - no valid configuration found` ) - // Build delete conditions - const deleteConditions = [eq(workflowSchedule.workflowId, workflowId)] - if (blockId) { - deleteConditions.push(eq(workflowSchedule.blockId, blockId)) - } - await db .delete(workflowSchedule) - .where(deleteConditions.length > 1 ? and(...deleteConditions) : deleteConditions[0]) + .where( + and(eq(workflowSchedule.workflowId, workflowId), eq(workflowSchedule.blockId, blockId)) + ) return NextResponse.json({ message: 'Schedule removed' }) } - if (isScheduleBlock) { - logger.info(`[${requestId}] Processing schedule trigger block for workflow ${workflowId}`) - } - logger.debug(`[${requestId}] Schedule type for workflow ${workflowId}: ${scheduleType}`) let cronExpression: string | null = null @@ -313,7 +282,12 @@ export async function POST(req: NextRequest) { } } - nextRunAt = calculateNextRunTime(defaultScheduleType, scheduleValues, undefined, utcOffsetMinutes) + nextRunAt = calculateNextRunTime( + defaultScheduleType, + scheduleValues, + undefined, + utcOffsetMinutes + ) logger.debug( `[${requestId}] Generated cron: ${cronExpression}, next run at: ${nextRunAt.toISOString()}` diff --git a/apps/tradinggoose/app/api/templates/[id]/route.ts b/apps/tradinggoose/app/api/templates/[id]/route.ts index 9d73df795..1b880f10f 100644 --- a/apps/tradinggoose/app/api/templates/[id]/route.ts +++ b/apps/tradinggoose/app/api/templates/[id]/route.ts @@ -5,7 +5,7 @@ import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' -import { hasAdminPermission } from '@/lib/permissions/utils' +import { hasWorkspaceAdminAccess } from '@/lib/permissions/utils' import { generateRequestId } from '@/lib/utils' const logger = createLogger('TemplateByIdAPI') @@ -121,7 +121,7 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{ const workspaceId = wfRows[0]?.workspaceId as string | null | undefined if (workspaceId) { - const hasAdmin = await hasAdminPermission(session.user.id, workspaceId) + const hasAdmin = await hasWorkspaceAdminAccess(session.user.id, workspaceId) if (hasAdmin) canUpdate = true } } @@ -196,7 +196,7 @@ export async function DELETE( const workspaceId = wfRows[0]?.workspaceId as string | null | undefined if (workspaceId) { - const hasAdmin = await hasAdminPermission(session.user.id, workspaceId) + const hasAdmin = await hasWorkspaceAdminAccess(session.user.id, workspaceId) if (hasAdmin) canDelete = true } } diff --git a/apps/tradinggoose/app/api/users/me/settings/route.ts b/apps/tradinggoose/app/api/users/me/settings/route.ts index 15babfe24..c28da522a 100644 --- a/apps/tradinggoose/app/api/users/me/settings/route.ts +++ b/apps/tradinggoose/app/api/users/me/settings/route.ts @@ -8,10 +8,15 @@ import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { generateRequestId } from '@/lib/utils' -import { defaultLocale, isLocaleCode, locales } from '@/i18n/utils' +import { + defaultLocale, + isLocaleCode, + LOCALE_COOKIE, + LOCALE_COOKIE_MAX_AGE, + locales, +} from '@/i18n/utils' const logger = createLogger('UserSettingsAPI') -const LOCALE_COOKIE = 'NEXT_LOCALE' const SettingsSchema = z.object({ theme: z.enum(['system', 'light', 'dark']).optional(), @@ -40,7 +45,7 @@ function withPreferredLocaleCookie(response: NextResponse, locale: string | null if (locale && isLocaleCode(locale)) { response.cookies.set(LOCALE_COOKIE, locale, { path: '/', - maxAge: 60 * 60 * 24 * 365, + maxAge: LOCALE_COOKIE_MAX_AGE, sameSite: 'lax', }) } @@ -60,8 +65,9 @@ export async function GET() { const session = await getSession() if (!session?.user?.id) { - logger.info(`[${requestId}] Returning default settings for unauthenticated user`) - return NextResponse.json({ data: defaultSettings }, { status: 200 }) + const preferredLocale = await getRuntimeLocale() + logger.info(`[${requestId}] Returning runtime settings for unauthenticated user`) + return NextResponse.json({ data: { ...defaultSettings, preferredLocale } }, { status: 200 }) } const userId = session.user.id diff --git a/apps/tradinggoose/app/api/webhooks/trigger/[path]/route.test.ts b/apps/tradinggoose/app/api/webhooks/trigger/[path]/route.test.ts index 9d87de34b..ce590f345 100644 --- a/apps/tradinggoose/app/api/webhooks/trigger/[path]/route.test.ts +++ b/apps/tradinggoose/app/api/webhooks/trigger/[path]/route.test.ts @@ -248,6 +248,7 @@ describe('Webhook Trigger API Route', () => { isActive: true, providerConfig: { requireAuth: false }, workflowId: 'test-workflow-id', + blockId: 'generic-trigger-id', rateLimitCount: 100, rateLimitPeriod: 60, }) @@ -282,6 +283,7 @@ describe('Webhook Trigger API Route', () => { isActive: true, providerConfig: { requireAuth: true, token: 'test-token-123' }, workflowId: 'test-workflow-id', + blockId: 'generic-trigger-id', }) globalMockData.workflows.push({ id: 'test-workflow-id', @@ -317,6 +319,7 @@ describe('Webhook Trigger API Route', () => { secretHeaderName: 'X-Custom-Auth', }, workflowId: 'test-workflow-id', + blockId: 'generic-trigger-id', }) globalMockData.workflows.push({ id: 'test-workflow-id', @@ -348,6 +351,7 @@ describe('Webhook Trigger API Route', () => { isActive: true, providerConfig: { requireAuth: true, token: 'case-test-token' }, workflowId: 'test-workflow-id', + blockId: 'generic-trigger-id', }) globalMockData.workflows.push({ id: 'test-workflow-id', @@ -392,6 +396,7 @@ describe('Webhook Trigger API Route', () => { secretHeaderName: 'X-Secret-Key', }, workflowId: 'test-workflow-id', + blockId: 'generic-trigger-id', }) globalMockData.workflows.push({ id: 'test-workflow-id', diff --git a/apps/tradinggoose/app/api/workflows/[id]/autolayout/route.ts b/apps/tradinggoose/app/api/workflows/[id]/autolayout/route.ts index d4d972763..45217e348 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/autolayout/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/autolayout/route.ts @@ -1,11 +1,10 @@ import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' -import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { generateRequestId } from '@/lib/utils' import { applyAutoLayout } from '@/lib/workflows/autolayout' import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/db-helpers' -import { readWorkflowAccessContext } from '@/lib/workflows/utils' +import { validateWorkflowPermissions } from '@/lib/workflows/utils' export const dynamic = 'force-dynamic' @@ -35,10 +34,12 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ const { id: workflowId } = await params try { - const session = await getSession() - if (!session?.user?.id) { - logger.warn(`[${requestId}] Unauthorized autolayout attempt for workflow ${workflowId}`) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + const { error, session } = await validateWorkflowPermissions(workflowId, requestId, 'write') + if (error || !session?.user?.id) { + return NextResponse.json( + { error: error?.message ?? 'Unauthorized' }, + { status: error?.status ?? 401 } + ) } const userId = session.user.id @@ -50,28 +51,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ userId, }) - const accessContext = await readWorkflowAccessContext(workflowId, userId) - const workflowData = accessContext?.workflow - - if (!workflowData) { - logger.warn(`[${requestId}] Workflow ${workflowId} not found for autolayout`) - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) - } - - const canUpdate = - accessContext?.isOwner || - (workflowData.workspaceId - ? accessContext?.workspacePermission === 'write' || - accessContext?.workspacePermission === 'admin' - : false) - - if (!canUpdate) { - logger.warn( - `[${requestId}] User ${userId} denied permission to autolayout workflow ${workflowId}` - ) - return NextResponse.json({ error: 'Access denied' }, { status: 403 }) - } - let currentWorkflowData: { blocks: Record; edges: any[] } | null if (layoutOptions.blocks && layoutOptions.edges) { diff --git a/apps/tradinggoose/app/api/workflows/[id]/duplicate/route.ts b/apps/tradinggoose/app/api/workflows/[id]/duplicate/route.ts index 26f84daa6..b61846eb3 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/duplicate/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/duplicate/route.ts @@ -25,7 +25,6 @@ const logger = createLogger('WorkflowDuplicateAPI') const DuplicateRequestSchema = z.object({ name: z.string().min(1, 'Name is required'), description: z.string().optional(), - color: z.string().optional(), workspaceId: z.string().min(1, 'Workspace ID is required'), folderId: z.string().nullable().optional(), }) @@ -79,7 +78,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: try { const body = await req.json() - const { name, description, color, workspaceId, folderId } = DuplicateRequestSchema.parse(body) + const { name, description, workspaceId, folderId } = DuplicateRequestSchema.parse(body) logger.info( `[${requestId}] Duplicating workflow ${sourceWorkflowId} for user ${session.user.id}` @@ -122,10 +121,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: const newWorkflowId = crypto.randomUUID() const now = new Date() - const resolvedColor = - typeof color === 'string' && color.trim().length > 0 - ? color.trim() - : getStableVibrantColor(newWorkflowId) + const resolvedColor = getStableVibrantColor(newWorkflowId) const duplicatedWorkflowState = regenerateWorkflowStateIds(sourceArtifacts.workflowState) const duplicatedVariables = remapVariableIds(sourceArtifacts.variables, newWorkflowId) diff --git a/apps/tradinggoose/app/api/workflows/[id]/execute/route.test.ts b/apps/tradinggoose/app/api/workflows/[id]/execute/route.test.ts index 50513b392..d885774e6 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/execute/route.test.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/execute/route.test.ts @@ -470,7 +470,13 @@ describe('/api/workflows/[id]/execute', () => { expect(createHttpResponseFromBlockMock).toHaveBeenCalledWith(responseResult) }) - it('rejects non-API execution control fields on the deployed execute adapter', async () => { + it.each([ + 'workflowTriggerType', + 'triggerType', + 'executionTarget', + 'startBlockId', + 'triggerBlockId', + ])('rejects %s on the deployed execute adapter', async (field) => { const { POST } = await import('./route') const response = await POST( new NextRequest('https://example.com/api/workflows/workflow-1/execute', { @@ -479,16 +485,14 @@ describe('/api/workflows/[id]/execute', () => { 'Content-Type': 'application/json', 'X-API-Key': 'key-1', }, - body: JSON.stringify({ - workflowTriggerType: 'chat', - }), + body: JSON.stringify({ [field]: 'chat' }), }), { params: Promise.resolve({ id: 'workflow-1' }) } ) expect(response.status).toBe(400) await expect(response.json()).resolves.toMatchObject({ - error: 'Field "workflowTriggerType" is not supported by the deployed API execute endpoint', + error: `Field "${field}" is not supported by the deployed API execute endpoint`, }) expect(enqueuePendingExecutionMock).not.toHaveBeenCalled() }) diff --git a/apps/tradinggoose/app/api/workflows/[id]/execute/route.ts b/apps/tradinggoose/app/api/workflows/[id]/execute/route.ts index d904867e0..e3930f4de 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/execute/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/execute/route.ts @@ -31,13 +31,16 @@ const API_EXECUTION_POLL_INTERVAL_MS = 1_000 const API_EXECUTION_WAIT_TIMEOUT_MS = 25_000 const UNSUPPORTED_API_EXECUTE_FIELDS = [ 'workflowTriggerType', + 'triggerType', + 'executionTarget', + 'startBlockId', + 'triggerBlockId', 'isSecureMode', 'useDraftState', 'isClientSession', 'workflowData', 'workflowStateOverride', 'workflowVariables', - 'startBlockId', 'executionId', ] as const diff --git a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/[webhookId]/route.ts b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/[webhookId]/route.ts index 096117770..20987c9c1 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/[webhookId]/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/[webhookId]/route.ts @@ -1,11 +1,11 @@ import { db } from '@tradinggoose/db' -import { permissions, workflow, workflowLogWebhook } from '@tradinggoose/db/schema' +import { workflowLogWebhook } from '@tradinggoose/db/schema' import { and, eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' -import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { encryptSecret } from '@/lib/utils-server' +import { validateWorkflowPermissions } from '@/lib/workflows/utils' import { cancelActiveWebhookDeliveries } from '../delivery-cancellation' const logger = createLogger('WorkflowLogWebhookUpdate') @@ -39,32 +39,15 @@ export async function PUT( { params }: { params: Promise<{ id: string; webhookId: string }> } ) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId, webhookId } = await params - const userId = session.user.id - - // Check if user has access to the workflow - const hasAccess = await db - .select({ id: workflow.id }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) + const { error, session } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error || !session?.user?.id) { + return NextResponse.json( + { error: error?.message ?? 'Unauthorized' }, + { status: error?.status ?? 401 } ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) } + const userId = session.user.id // Check if webhook exists and belongs to this workflow const existingWebhook = await db @@ -169,32 +152,15 @@ export async function DELETE( { params }: { params: Promise<{ id: string; webhookId: string }> } ) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId, webhookId } = await params - const userId = session.user.id - - // Check if user has access to the workflow - const hasAccess = await db - .select({ id: workflow.id }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) + const { error, session } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error || !session?.user?.id) { + return NextResponse.json( + { error: error?.message ?? 'Unauthorized' }, + { status: error?.status ?? 401 } ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) } + const userId = session.user.id await cancelActiveWebhookDeliveries(workflowId, webhookId) diff --git a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/route.ts b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/route.ts index cf2fad76d..7dd2c119c 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/route.ts @@ -1,12 +1,12 @@ import { db } from '@tradinggoose/db' -import { permissions, workflow, workflowLogWebhook } from '@tradinggoose/db/schema' +import { workflowLogWebhook } from '@tradinggoose/db/schema' import { and, eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { v4 as uuidv4 } from 'uuid' import { z } from 'zod' -import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { encryptSecret } from '@/lib/utils-server' +import { validateWorkflowPermissions } from '@/lib/workflows/utils' import { cancelActiveWebhookDeliveries } from './delivery-cancellation' const logger = createLogger('WorkflowLogWebhookAPI') @@ -31,30 +31,10 @@ const CreateWebhookSchema = z.object({ export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId } = await params - const userId = session.user.id - - const hasAccess = await db - .select({ id: workflow.id, userId: workflow.userId, workspaceId: workflow.workspaceId }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) - ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) + const { error } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error) { + return NextResponse.json({ error: error.message }, { status: error.status }) } const webhooks = await db @@ -83,34 +63,12 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId } = await params - const userId = session.user.id - - const hasAccess = await db - .select({ id: workflow.id, userId: workflow.userId, workspaceId: workflow.workspaceId }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) - ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) + const { error } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error) { + return NextResponse.json({ error: error.message }, { status: error.status }) } - const workflowRecord = hasAccess[0] - const body = await request.json() const validationResult = CreateWebhookSchema.safeParse(body) @@ -194,13 +152,7 @@ export async function DELETE( { params }: { params: Promise<{ id: string }> } ) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId } = await params - const userId = session.user.id const { searchParams } = new URL(request.url) const webhookId = searchParams.get('webhookId') @@ -208,22 +160,9 @@ export async function DELETE( return NextResponse.json({ error: 'webhookId is required' }, { status: 400 }) } - const hasAccess = await db - .select({ id: workflow.id }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) - ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) + const { error } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error) { + return NextResponse.json({ error: error.message }, { status: error.status }) } await cancelActiveWebhookDeliveries(workflowId, webhookId) diff --git a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/test/route.ts b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/test/route.ts index 6d195813a..66a1f5049 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/log-webhook/test/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/log-webhook/test/route.ts @@ -1,12 +1,12 @@ import { createHmac } from 'crypto' import { db } from '@tradinggoose/db' -import { permissions, workflow, workflowLogWebhook } from '@tradinggoose/db/schema' +import { workflowLogWebhook } from '@tradinggoose/db/schema' import { and, eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { v4 as uuidv4 } from 'uuid' -import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { decryptSecret } from '@/lib/utils-server' +import { validateWorkflowPermissions } from '@/lib/workflows/utils' const logger = createLogger('WorkflowLogWebhookTestAPI') @@ -19,13 +19,7 @@ function generateSignature(secret: string, timestamp: number, body: string): str export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { try { - const session = await getSession() - if (!session?.user?.id) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - const { id: workflowId } = await params - const userId = session.user.id const { searchParams } = new URL(request.url) const webhookId = searchParams.get('webhookId') @@ -33,22 +27,9 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ return NextResponse.json({ error: 'webhookId is required' }, { status: 400 }) } - const hasAccess = await db - .select({ id: workflow.id }) - .from(workflow) - .innerJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workflow.workspaceId), - eq(permissions.userId, userId) - ) - ) - .where(eq(workflow.id, workflowId)) - .limit(1) - - if (hasAccess.length === 0) { - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) + const { error } = await validateWorkflowPermissions(workflowId, workflowId, 'read') + if (error) { + return NextResponse.json({ error: error.message }, { status: error.status }) } const [webhook] = await db diff --git a/apps/tradinggoose/app/api/workflows/[id]/queue/route.test.ts b/apps/tradinggoose/app/api/workflows/[id]/queue/route.test.ts index a6e147bce..69107e314 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/queue/route.test.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/queue/route.test.ts @@ -51,6 +51,8 @@ vi.mock('@/lib/utils', () => ({ SSE_HEADERS: { 'Content-Type': 'text/event-stream' }, })) +vi.unmock('@/blocks/registry') + import { POST } from './route' describe('POST /api/workflows/[id]/queue', () => { @@ -142,7 +144,7 @@ describe('POST /api/workflows/[id]/queue', () => { it.each([ { name: 'unsupported trigger type', - body: JSON.stringify({ triggerType: 'webhook' }), + body: JSON.stringify({ triggerType: 'api-endpoint' }), error: 'Unsupported queued workflow trigger type', }, { @@ -150,6 +152,11 @@ describe('POST /api/workflows/[id]/queue', () => { body: JSON.stringify({ executionTarget: 'draft' }), error: 'Unsupported queued workflow execution target', }, + { + name: 'webhook without live trigger block', + body: JSON.stringify({ executionTarget: 'live', triggerType: 'webhook' }), + error: 'Webhook and schedule queued workflow executions require a live trigger block', + }, { name: 'malformed JSON', body: '{', @@ -284,10 +291,10 @@ describe('POST /api/workflows/[id]/queue', () => { expect(enqueuePendingExecutionMock).not.toHaveBeenCalled() }) - it('queues editor live executions with the canonical workflow payload', async () => { + it('queues editor live executions as manual runs with trigger source metadata', async () => { const workflowData = { blocks: { - 'trigger-1': { id: 'trigger-1', type: 'manual_trigger' }, + 'trigger-1': { id: 'trigger-1', type: 'schedule' }, }, edges: [], loops: {}, @@ -304,7 +311,7 @@ describe('POST /api/workflows/[id]/queue', () => { triggerType: 'manual', workflowData, workflowVariables: { risk: { value: 1 } }, - startBlockId: 'trigger-1', + triggerBlockId: 'trigger-1', }), headers: { 'Content-Type': 'application/json', @@ -323,10 +330,12 @@ describe('POST /api/workflows/[id]/queue', () => { payload: expect.objectContaining({ executionId: 'execution-1', input: { symbol: 'AAPL' }, + triggerType: 'manual', executionTarget: 'live', workflowData, workflowVariables: { risk: { value: 1 } }, - startBlockId: 'trigger-1', + triggerBlockId: 'trigger-1', + triggerData: { source: 'schedule' }, }), }) ) diff --git a/apps/tradinggoose/app/api/workflows/[id]/queue/route.ts b/apps/tradinggoose/app/api/workflows/[id]/queue/route.ts index ba987a865..4a3a9237d 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/queue/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/queue/route.ts @@ -11,10 +11,11 @@ import { TriggerExecutionUnavailableError } from '@/lib/trigger/settings' import { generateRequestId, SSE_HEADERS } from '@/lib/utils' import type { WorkflowExecutionBlueprint } from '@/lib/workflows/execution-runner' import { readWorkflowAccessContext } from '@/lib/workflows/utils' +import type { QueuedWorkflowTriggerType } from '@/services/queue' +import { resolveTriggerExecutionIdentity } from '@/triggers/resolution' const logger = createLogger('WorkflowQueueAPI') -type QueuedWorkflowTriggerType = 'api' | 'manual' | 'chat' type QueuedWorkflowExecutionTarget = 'deployed' | 'live' type QueueRequestBody = { @@ -24,7 +25,7 @@ type QueueRequestBody = { triggerType?: unknown workflowData?: WorkflowExecutionBlueprint['workflowData'] workflowVariables?: Record - startBlockId?: string + triggerBlockId?: string selectedOutputs?: string[] stream?: boolean workflowDepth?: number @@ -32,7 +33,9 @@ type QueueRequestBody = { function readQueuedWorkflowTriggerType(value: unknown): QueuedWorkflowTriggerType | null { if (value === undefined) return 'manual' - if (value === 'api' || value === 'manual' || value === 'chat') return value + if (['api', 'manual', 'chat', 'webhook', 'schedule'].includes(value as string)) { + return value as QueuedWorkflowTriggerType + } return null } @@ -58,10 +61,36 @@ function hasLiveWorkflowState(body: QueueRequestBody) { return ( body.workflowData !== undefined || body.workflowVariables !== undefined || - (typeof body.startBlockId === 'string' && body.startBlockId.length > 0) + (typeof body.triggerBlockId === 'string' && body.triggerBlockId.length > 0) ) } +function resolveQueuedTriggerData( + body: QueueRequestBody, + executionTarget: QueuedWorkflowExecutionTarget, + triggerType: QueuedWorkflowTriggerType +) { + if (executionTarget !== 'live' || typeof body.triggerBlockId !== 'string') { + return undefined + } + + const block = body.workflowData?.blocks?.[body.triggerBlockId] + if (!block) { + throw new Error('Queued workflow trigger block was not found in live workflow state') + } + + const identity = resolveTriggerExecutionIdentity(block) + const isManualEditorRun = triggerType === 'manual' + const triggerTypeMatchesBlock = isManualEditorRun + ? identity.triggerType !== 'chat' + : identity.triggerType === triggerType + if (!triggerTypeMatchesBlock) { + throw new Error('Queued workflow trigger type does not match the trigger block') + } + + return { source: identity.triggerSource } +} + export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { const requestId = generateRequestId() const { id: workflowId } = await params @@ -118,7 +147,17 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ { status: 400 } ) } - + if ( + (triggerType === 'webhook' || triggerType === 'schedule') && + (executionTarget !== 'live' || + typeof body.triggerBlockId !== 'string' || + body.triggerBlockId.length === 0) + ) { + return NextResponse.json( + { error: 'Webhook and schedule queued workflow executions require a live trigger block' }, + { status: 400 } + ) + } if ( !accessContext.isOwner && !accessContext.isWorkspaceOwner && @@ -133,6 +172,14 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ typeof body.executionId === 'string' && body.executionId.length > 0 ? body.executionId : `workflow_execution_${randomUUID()}` + let triggerData: { source: string } | undefined + try { + triggerData = resolveQueuedTriggerData(body, executionTarget, triggerType) + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : 'Queued workflow trigger block is not runnable' + return NextResponse.json({ error: errorMessage }, { status: 400 }) + } const handle = await enqueuePendingExecution({ executionType: 'workflow', pendingExecutionId, @@ -153,12 +200,13 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{ workflowVariables: executionTarget === 'live' ? body.workflowVariables : undefined, selectedOutputs: body.selectedOutputs, stream: body.stream === true, - startBlockId: + triggerBlockId: executionTarget === 'live' && - typeof body.startBlockId === 'string' && - body.startBlockId.length > 0 - ? body.startBlockId + typeof body.triggerBlockId === 'string' && + body.triggerBlockId.length > 0 + ? body.triggerBlockId : undefined, + ...(triggerData ? { triggerData } : {}), workflowDepth: typeof body.workflowDepth === 'number' ? body.workflowDepth : 0, metadata: { source, diff --git a/apps/tradinggoose/app/api/workflows/[id]/route.test.ts b/apps/tradinggoose/app/api/workflows/[id]/route.test.ts index 3fa71eb2d..bb5f160f4 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/route.test.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/route.test.ts @@ -775,6 +775,28 @@ describe('Workflow By ID API Route', () => { const data = await response.json() expect(data.error).toBe('Invalid request data') }) + + it('should reject generated workflow color updates', async () => { + vi.doMock('@/lib/auth', () => ({ + getSession: vi.fn().mockResolvedValue({ + user: { id: 'user-123' }, + }), + })) + + const req = new NextRequest('http://localhost:3000/api/workflows/workflow-123', { + method: 'PUT', + body: JSON.stringify({ color: '#3972F6' }), + }) + const params = Promise.resolve({ id: 'workflow-123' }) + + const { PUT } = await import('@/app/api/workflows/[id]/route') + const response = await PUT(req, { params }) + + expect(response.status).toBe(400) + const data = await response.json() + expect(data.error).toBe('Invalid request data') + expect(JSON.stringify(data.details)).toContain('color') + }) }) describe('Error handling', () => { diff --git a/apps/tradinggoose/app/api/workflows/[id]/route.ts b/apps/tradinggoose/app/api/workflows/[id]/route.ts index e7823b3e8..cc0088b93 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/route.ts @@ -16,12 +16,13 @@ import { createWorkflowSnapshot } from '@/lib/yjs/workflow-session' const logger = createLogger('WorkflowByIdAPI') -const UpdateWorkflowSchema = z.object({ - name: z.string().min(1, 'Name is required').optional(), - description: z.string().optional(), - color: z.string().optional(), - folderId: z.string().nullable().optional(), -}) +const UpdateWorkflowSchema = z + .object({ + name: z.string().min(1, 'Name is required').optional(), + description: z.string().optional(), + folderId: z.string().nullable().optional(), + }) + .strict() /** * GET /api/workflows/[id] @@ -310,7 +311,7 @@ export async function DELETE( /** * PUT /api/workflows/[id] - * Update workflow metadata (name, description, color, folderId) + * Update workflow metadata (name, description, folderId) */ export async function PUT(request: NextRequest, { params }: { params: Promise<{ id: string }> }) { const requestId = generateRequestId() @@ -371,7 +372,6 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{ const updateData: any = { updatedAt: new Date() } if (updates.name !== undefined) updateData.name = updates.name if (updates.description !== undefined) updateData.description = updates.description - if (updates.color !== undefined) updateData.color = updates.color if (updates.folderId !== undefined) updateData.folderId = updates.folderId // Update the workflow diff --git a/apps/tradinggoose/app/api/workflows/[id]/state/route.test.ts b/apps/tradinggoose/app/api/workflows/[id]/state/route.test.ts index 3f16d1a48..5de1bb481 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/state/route.test.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/state/route.test.ts @@ -85,8 +85,9 @@ describe('Workflow State API Route', () => { })) vi.doMock('@/lib/workflows/utils', () => ({ - readWorkflowAccessContext: vi.fn().mockResolvedValue({ - isOwner: true, + validateWorkflowPermissions: vi.fn().mockResolvedValue({ + error: null, + session: { user: { id: 'user-id' } }, workflow: { id: 'workflow-id', workspaceId: 'workspace-id', diff --git a/apps/tradinggoose/app/api/workflows/[id]/state/route.ts b/apps/tradinggoose/app/api/workflows/[id]/state/route.ts index 4e9b4510b..1219f2d28 100644 --- a/apps/tradinggoose/app/api/workflows/[id]/state/route.ts +++ b/apps/tradinggoose/app/api/workflows/[id]/state/route.ts @@ -3,7 +3,6 @@ import { workflow } from '@tradinggoose/db/schema' import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' -import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { generateRequestId } from '@/lib/utils' import { extractAndPersistCustomTools } from '@/lib/workflows/custom-tools-persistence' @@ -12,7 +11,7 @@ import { saveWorkflowToNormalizedTables, toISOStringOrUndefined, } from '@/lib/workflows/db-helpers' -import { readWorkflowAccessContext } from '@/lib/workflows/utils' +import { validateWorkflowPermissions } from '@/lib/workflows/utils' import { sanitizeAgentToolsInBlocks } from '@/lib/workflows/validation' import { tryApplyWorkflowState } from '@/lib/yjs/server/apply-workflow-state' import type { WorkflowSnapshot } from '@/lib/yjs/workflow-session' @@ -131,43 +130,23 @@ export async function PUT(request: NextRequest, { params }: { params: Promise<{ const { id: workflowId } = await params try { - // Get the session - const session = await getSession() - if (!session?.user?.id) { - logger.warn(`[${requestId}] Unauthorized state update attempt for workflow ${workflowId}`) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + const { + error, + session, + workflow: workflowData, + } = await validateWorkflowPermissions(workflowId, requestId, 'write') + if (error || !session?.user?.id || !workflowData) { + return NextResponse.json( + { error: error?.message ?? 'Unauthorized' }, + { status: error?.status ?? 401 } + ) } - const userId = session.user.id // Parse and validate request body const body = await request.json() const state = WorkflowStateSchema.parse(body) - // Fetch the workflow to check ownership/access - const accessContext = await readWorkflowAccessContext(workflowId, userId) - const workflowData = accessContext?.workflow - - if (!workflowData) { - logger.warn(`[${requestId}] Workflow ${workflowId} not found for state update`) - return NextResponse.json({ error: 'Workflow not found' }, { status: 404 }) - } - - // Check if user has permission to update this workflow - const canUpdate = - accessContext?.isOwner || - (workflowData.workspaceId - ? accessContext?.workspacePermission === 'write' || - accessContext?.workspacePermission === 'admin' - : false) - - if (!canUpdate) { - logger.warn( - `[${requestId}] User ${userId} denied permission to update workflow state ${workflowId}` - ) - return NextResponse.json({ error: 'Access denied' }, { status: 403 }) - } - // Sanitize custom tools in agent blocks before saving const { blocks: sanitizedBlocks, warnings } = sanitizeAgentToolsInBlocks(state.blocks as any) diff --git a/apps/tradinggoose/app/api/workflows/route.ts b/apps/tradinggoose/app/api/workflows/route.ts index 7e1c84601..ca3b2ca30 100644 --- a/apps/tradinggoose/app/api/workflows/route.ts +++ b/apps/tradinggoose/app/api/workflows/route.ts @@ -19,7 +19,6 @@ const logger = createLogger('WorkflowAPI') const CreateWorkflowSchema = z.object({ name: z.string().min(1, 'Name is required'), description: z.string().optional().default(''), - color: z.string().optional(), workspaceId: z.string().min(1, 'Workspace ID is required'), folderId: z.string().nullable().optional(), initialWorkflowState: z.any().optional(), @@ -129,7 +128,7 @@ export async function POST(req: NextRequest) { try { const body = await req.json() - const { name, description, color, workspaceId, folderId, initialWorkflowState } = + const { name, description, workspaceId, folderId, initialWorkflowState } = CreateWorkflowSchema.parse(body) const workspaceAccess = await checkWorkspaceAccess(workspaceId, session.user.id) @@ -161,10 +160,7 @@ export async function POST(req: NextRequest) { normalizeVariables(initialState?.variables), workflowId ) - const resolvedColor = - typeof color === 'string' && color.trim().length > 0 - ? color.trim() - : getStableVibrantColor(workflowId) + const resolvedColor = getStableVibrantColor(workflowId) logger.info(`[${requestId}] Creating workflow ${workflowId} for user ${session.user.id}`) diff --git a/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.test.ts b/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.test.ts index 6255d0548..534fa4247 100644 --- a/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.test.ts +++ b/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.test.ts @@ -14,6 +14,8 @@ describe('Workspace permissions PATCH route', () => { })), })), })) + const mockAssertActiveWorkspaceAccess = vi.fn() + const mockGetUserEntityPermissions = vi.fn() const mockHasWorkspaceAdminAccess = vi.fn() const mockGetUsersWithPermissions = vi.fn() const mockAssertWorkspaceBillingOwnerRetainsAdminAccess = vi.fn() @@ -40,6 +42,7 @@ describe('Workspace permissions PATCH route', () => { }, workspace: { id: 'workspace.id', + ownerId: 'workspace.ownerId', billingOwnerType: 'workspace.billingOwnerType', billingOwnerUserId: 'workspace.billingOwnerUserId', }, @@ -65,6 +68,8 @@ describe('Workspace permissions PATCH route', () => { })) vi.doMock('@/lib/permissions/utils', () => ({ + assertActiveWorkspaceAccess: mockAssertActiveWorkspaceAccess, + getUserEntityPermissions: mockGetUserEntityPermissions, getUsersWithPermissions: mockGetUsersWithPermissions, hasWorkspaceAdminAccess: mockHasWorkspaceAdminAccess, })) @@ -73,6 +78,9 @@ describe('Workspace permissions PATCH route', () => { assertWorkspaceBillingOwnerRetainsAdminAccess: mockAssertWorkspaceBillingOwnerRetainsAdminAccess, })) + + mockAssertActiveWorkspaceAccess.mockResolvedValue({}) + mockGetUserEntityPermissions.mockResolvedValue('admin') }) afterEach(() => { @@ -89,6 +97,7 @@ describe('Workspace permissions PATCH route', () => { [{ id: 'permission-1' }], [ { + ownerId: 'owner-1', billingOwnerType: 'user', billingOwnerUserId: 'user-2', }, @@ -119,7 +128,7 @@ describe('Workspace permissions PATCH route', () => { mockGetUsersWithPermissions.mockResolvedValue([]) selectResults.push( [{ id: 'permission-1' }], - [{ billingOwnerType: 'user', billingOwnerUserId: 'user-2' }] + [{ ownerId: 'owner-1', billingOwnerType: 'user', billingOwnerUserId: 'user-2' }] ) const { PATCH } = await import('./route') @@ -138,4 +147,53 @@ describe('Workspace permissions PATCH route', () => { expect(transactionMock).not.toHaveBeenCalled() expect(mockAssertWorkspaceBillingOwnerRetainsAdminAccess).not.toHaveBeenCalled() }) + + it('rejects updates to the canonical workspace owner permission', async () => { + mockHasWorkspaceAdminAccess.mockResolvedValue(true) + selectResults.push([ + { + ownerId: 'owner-1', + billingOwnerType: 'organization', + billingOwnerUserId: null, + }, + ]) + + const { PATCH } = await import('./route') + const response = await PATCH( + new NextRequest('http://localhost/api/workspaces/workspace-1/permissions', { + method: 'PATCH', + body: JSON.stringify({ + updates: [{ userId: 'owner-1', permissions: 'write' }], + }), + }), + { params: Promise.resolve({ id: 'workspace-1' }) } + ) + + expect(response.status).toBe(400) + expect(await response.json()).toEqual({ + error: 'Workspace owner permissions are managed by workspace ownership', + }) + expect(transactionMock).not.toHaveBeenCalled() + expect(mockAssertWorkspaceBillingOwnerRetainsAdminAccess).not.toHaveBeenCalled() + }) + + it('resolves the current user permission independently from the member list', async () => { + mockGetUsersWithPermissions.mockResolvedValue([]) + mockGetUserEntityPermissions.mockResolvedValue('admin') + + const { GET } = await import('./route') + const response = await GET( + new NextRequest('http://localhost/api/workspaces/workspace-1/permissions'), + { params: Promise.resolve({ id: 'workspace-1' }) } + ) + + expect(response.status).toBe(200) + expect(await response.json()).toEqual({ + users: [], + total: 0, + currentUserPermission: 'admin', + }) + expect(mockAssertActiveWorkspaceAccess).toHaveBeenCalledWith('workspace-1', 'user-1') + expect(mockGetUserEntityPermissions).toHaveBeenCalledWith('user-1', 'workspace', 'workspace-1') + }) }) diff --git a/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.ts b/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.ts index 107e15cbd..f2a28fedc 100644 --- a/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.ts +++ b/apps/tradinggoose/app/api/workspaces/[id]/permissions/route.ts @@ -7,6 +7,7 @@ import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' import { assertActiveWorkspaceAccess, + getUserEntityPermissions, getUsersWithPermissions, hasWorkspaceAdminAccess, } from '@/lib/permissions/utils' @@ -62,10 +63,20 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{ } const result = await getUsersWithPermissions(workspaceId) + const currentUserPermission = await getUserEntityPermissions( + session.user.id, + 'workspace', + workspaceId + ) + + if (!currentUserPermission) { + return NextResponse.json({ error: 'Workspace permission state unavailable' }, { status: 403 }) + } return NextResponse.json({ users: result, total: result.length, + currentUserPermission, }) } catch (error) { logger.error('Error fetching workspace permissions:', error) @@ -118,6 +129,7 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise< const workspaceRow = await db .select({ + ownerId: workspace.ownerId, billingOwnerType: workspace.billingOwnerType, billingOwnerUserId: workspace.billingOwnerUserId, }) @@ -129,6 +141,13 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise< return NextResponse.json({ error: 'Workspace not found or access denied' }, { status: 404 }) } + if (body.updates.some((update) => update.userId === workspaceRow[0].ownerId)) { + return NextResponse.json( + { error: 'Workspace owner permissions are managed by workspace ownership' }, + { status: 400 } + ) + } + try { assertWorkspaceBillingOwnerRetainsAdminAccess({ billingOwnerType: workspaceRow[0].billingOwnerType, @@ -167,11 +186,21 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise< }) const updatedUsers = await getUsersWithPermissions(workspaceId) + const currentUserPermission = await getUserEntityPermissions( + session.user.id, + 'workspace', + workspaceId + ) + + if (!currentUserPermission) { + return NextResponse.json({ error: 'Workspace permission state unavailable' }, { status: 403 }) + } return NextResponse.json({ message: 'Permissions updated successfully', users: updatedUsers, total: updatedUsers.length, + currentUserPermission, }) } catch (error) { logger.error('Error updating workspace permissions:', error) diff --git a/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.test.ts b/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.test.ts index d194bf5d7..33f9fb5f4 100644 --- a/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.test.ts +++ b/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.test.ts @@ -37,6 +37,7 @@ describe('Workspace Invitation [invitationId] API Route', () => { let mockDbResults: any[] = [] let mockGetSession: any let mockHasWorkspaceAdminAccess: any + let mockCheckWorkspaceAccess: any let mockTransaction: any beforeEach(async () => { @@ -57,7 +58,9 @@ describe('Workspace Invitation [invitationId] API Route', () => { })) mockHasWorkspaceAdminAccess = vi.fn() + mockCheckWorkspaceAccess = vi.fn().mockResolvedValue({ hasAccess: false }) vi.doMock('@/lib/permissions/utils', () => ({ + checkWorkspaceAccess: mockCheckWorkspaceAccess, hasWorkspaceAdminAccess: mockHasWorkspaceAdminAccess, })) @@ -185,7 +188,6 @@ describe('Workspace Invitation [invitationId] API Route', () => { mockDbResults.push([mockInvitation]) mockDbResults.push([mockWorkspace]) mockDbResults.push([{ ...mockUser, email: 'invited@example.com' }]) - mockDbResults.push([]) mockTransaction.mockImplementation(async (callback: any) => { await callback({ @@ -392,6 +394,7 @@ describe('Workspace Invitation [invitationId] API Route', () => { getSession: vi.fn().mockResolvedValue({ user: mockUser }), })) vi.doMock('@/lib/permissions/utils', () => ({ + checkWorkspaceAccess: vi.fn(), hasWorkspaceAdminAccess: vi.fn(), })) vi.doMock('@/lib/env', () => { diff --git a/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.ts b/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.ts index fabec2ae6..eacc47e21 100644 --- a/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.ts +++ b/apps/tradinggoose/app/api/workspaces/invitations/[invitationId]/route.ts @@ -7,12 +7,12 @@ import { workspace, workspaceInvitation, } from '@tradinggoose/db/schema' -import { and, eq } from 'drizzle-orm' +import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { getEmailSubject, renderWorkspaceInvitationEmail } from '@/components/emails/render-email' import { getSession } from '@/lib/auth' import { resolveEmailLocale } from '@/lib/email/locale' -import { hasWorkspaceAdminAccess } from '@/lib/permissions/utils' +import { checkWorkspaceAccess, hasWorkspaceAdminAccess } from '@/lib/permissions/utils' import { getBaseUrl } from '@/lib/urls/utils' import { defaultLocale, localizeUrl, stripLocaleFromPathname } from '@/i18n/utils' @@ -115,19 +115,8 @@ export async function GET( return NextResponse.redirect(redirectUrl(`/invite/${invitation.id}?error=email-mismatch`)) } - const existingPermission = await db - .select() - .from(permissions) - .where( - and( - eq(permissions.entityId, invitation.workspaceId), - eq(permissions.entityType, 'workspace'), - eq(permissions.userId, session.user.id) - ) - ) - .then((rows) => rows[0]) - - if (existingPermission) { + const existingAccess = await checkWorkspaceAccess(invitation.workspaceId, session.user.id) + if (existingAccess.hasAccess) { await db .update(workspaceInvitation) .set({ diff --git a/apps/tradinggoose/app/api/workspaces/invitations/route.test.ts b/apps/tradinggoose/app/api/workspaces/invitations/route.test.ts index ea72b7e23..033a390b5 100644 --- a/apps/tradinggoose/app/api/workspaces/invitations/route.test.ts +++ b/apps/tradinggoose/app/api/workspaces/invitations/route.test.ts @@ -10,6 +10,8 @@ describe('Workspace Invitations API Route', () => { let mockGetSession: any let mockInsertValues: any let mockSendEmail: any + let mockHasWorkspaceAdminAccess: any + let mockCheckWorkspaceAccess: any beforeEach(() => { vi.resetModules() @@ -29,11 +31,14 @@ describe('Workspace Invitations API Route', () => { })) mockInsertValues = vi.fn().mockResolvedValue(undefined) + mockHasWorkspaceAdminAccess = vi.fn().mockResolvedValue(true) + mockCheckWorkspaceAccess = vi.fn().mockResolvedValue({ hasAccess: false }) const mockDbChain = { select: vi.fn().mockReturnThis(), from: vi.fn().mockReturnThis(), where: vi.fn().mockReturnThis(), innerJoin: vi.fn().mockReturnThis(), + leftJoin: vi.fn().mockReturnThis(), limit: vi.fn().mockReturnThis(), then: vi.fn().mockImplementation((callback: any) => { const result = mockDbResults.shift() || [] @@ -93,6 +98,15 @@ describe('Workspace Invitations API Route', () => { getBaseUrl: vi.fn().mockReturnValue('https://test.tradinggoose.ai'), })) + vi.doMock('@/lib/permissions/utils', () => ({ + buildWorkspaceAccessScope: vi.fn(() => ({ + permissionJoin: 'permission-join', + accessFilter: 'access-filter', + })), + checkWorkspaceAccess: mockCheckWorkspaceAccess, + hasWorkspaceAdminAccess: mockHasWorkspaceAdminAccess, + })) + vi.doMock('drizzle-orm', () => ({ and: vi.fn().mockImplementation((...args) => ({ type: 'and', conditions: args })), eq: vi.fn().mockImplementation((field, value) => ({ type: 'eq', field, value })), @@ -205,7 +219,7 @@ describe('Workspace Invitations API Route', () => { it('should return 403 when user does not have admin permissions', async () => { mockGetSession.mockResolvedValue({ user: { id: 'user-123' } }) - mockDbResults = [[]] // No admin permissions found + mockHasWorkspaceAdminAccess.mockResolvedValue(false) const { POST } = await import('@/app/api/workspaces/invitations/route') const req = createMockRequest('POST', { @@ -222,7 +236,6 @@ describe('Workspace Invitations API Route', () => { it('should return 404 when workspace is not found', async () => { mockGetSession.mockResolvedValue({ user: { id: 'user-123' } }) mockDbResults = [ - [{ permissionType: 'admin' }], // User has admin permissions [], // Workspace not found ] @@ -240,11 +253,10 @@ describe('Workspace Invitations API Route', () => { it('should return 400 when user already has workspace access', async () => { mockGetSession.mockResolvedValue({ user: { id: 'user-123' } }) + mockCheckWorkspaceAccess.mockResolvedValue({ hasAccess: true }) mockDbResults = [ - [{ permissionType: 'admin' }], // User has admin permissions [mockWorkspace], // Workspace exists [mockUser], // User exists - [{ permissionType: 'read' }], // User already has access ] const { POST } = await import('@/app/api/workspaces/invitations/route') @@ -265,7 +277,6 @@ describe('Workspace Invitations API Route', () => { it('should return 400 when invitation already exists', async () => { mockGetSession.mockResolvedValue({ user: { id: 'user-123' } }) mockDbResults = [ - [{ permissionType: 'admin' }], // User has admin permissions [mockWorkspace], // Workspace exists [], // User doesn't exist [mockInvitation], // Invitation exists @@ -291,7 +302,6 @@ describe('Workspace Invitations API Route', () => { user: { id: 'user-123', name: 'Test User', email: 'sender@example.com' }, }) mockDbResults = [ - [{ permissionType: 'admin' }], // User has admin permissions [mockWorkspace], // Workspace exists [], // User doesn't exist [], // No existing invitation diff --git a/apps/tradinggoose/app/api/workspaces/invitations/route.ts b/apps/tradinggoose/app/api/workspaces/invitations/route.ts index ff8a3f749..84fd0c0c7 100644 --- a/apps/tradinggoose/app/api/workspaces/invitations/route.ts +++ b/apps/tradinggoose/app/api/workspaces/invitations/route.ts @@ -15,6 +15,11 @@ import { getSession } from '@/lib/auth' import { resolveEmailLocale } from '@/lib/email/locale' import { sendEmail } from '@/lib/email/mailer' import { createLogger } from '@/lib/logs/console/logger' +import { + buildWorkspaceAccessScope, + checkWorkspaceAccess, + hasWorkspaceAdminAccess, +} from '@/lib/permissions/utils' import { getBaseUrl } from '@/lib/urls/utils' import { localizeUrl } from '@/i18n/utils' @@ -33,18 +38,12 @@ export async function GET(req: NextRequest) { } try { - // Get all workspaces where the user has permissions + const workspaceAccess = buildWorkspaceAccessScope(session.user.id, workspace.id) const userWorkspaces = await db .select({ id: workspace.id }) .from(workspace) - .innerJoin( - permissions, - and( - eq(permissions.entityId, workspace.id), - eq(permissions.entityType, 'workspace'), - eq(permissions.userId, session.user.id) - ) - ) + .leftJoin(permissions, workspaceAccess.permissionJoin) + .where(workspaceAccess.accessFilter) if (userWorkspaces.length === 0) { return NextResponse.json({ invitations: [] }) @@ -90,21 +89,8 @@ export async function POST(req: NextRequest) { ) } - // Check if user has admin permissions for this workspace - const userPermission = await db - .select() - .from(permissions) - .where( - and( - eq(permissions.entityId, workspaceId), - eq(permissions.entityType, 'workspace'), - eq(permissions.userId, session.user.id), - eq(permissions.permissionType, 'admin') - ) - ) - .then((rows) => rows[0]) - - if (!userPermission) { + const hasAdminAccess = await hasWorkspaceAdminAccess(session.user.id, workspaceId) + if (!hasAdminAccess) { return NextResponse.json( { error: 'You need admin permissions to invite users' }, { status: 403 } @@ -131,20 +117,8 @@ export async function POST(req: NextRequest) { .then((rows) => rows[0]) if (existingUser) { - // Check if the user already has permissions for this workspace - const existingPermission = await db - .select() - .from(permissions) - .where( - and( - eq(permissions.entityId, workspaceId), - eq(permissions.entityType, 'workspace'), - eq(permissions.userId, existingUser.id) - ) - ) - .then((rows) => rows[0]) - - if (existingPermission) { + const existingAccess = await checkWorkspaceAccess(workspaceId, existingUser.id) + if (existingAccess.hasAccess) { return NextResponse.json( { error: `${email} already has access to this workspace`, diff --git a/apps/tradinggoose/app/api/workspaces/members/[id]/route.test.ts b/apps/tradinggoose/app/api/workspaces/members/[id]/route.test.ts index 5496227ca..1eb8094fa 100644 --- a/apps/tradinggoose/app/api/workspaces/members/[id]/route.test.ts +++ b/apps/tradinggoose/app/api/workspaces/members/[id]/route.test.ts @@ -4,14 +4,36 @@ import { NextRequest } from 'next/server' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +async function deleteMember(userId: string) { + const { DELETE } = await import('./route') + return DELETE( + new NextRequest(`http://localhost/api/workspaces/members/${userId}`, { + method: 'DELETE', + body: JSON.stringify({ workspaceId: 'workspace-1' }), + }), + { params: Promise.resolve({ id: userId }) } + ) +} + describe('Workspace member DELETE route', () => { const selectResults: any[][] = [] - const deleteMock = vi.fn() + const deleteWhereMock = vi.fn() + const deleteMock = vi.fn(() => ({ + where: deleteWhereMock, + })) const selectMock = vi.fn(() => ({ from: vi.fn(() => ({ - where: vi.fn(() => ({ - limit: vi.fn(() => selectResults.shift() ?? []), - })), + where: vi.fn(() => { + const rows = selectResults.shift() ?? [] + + return { + limit: vi.fn(() => rows), + then: ( + onFulfilled: (value: any[]) => unknown, + onRejected?: (reason: unknown) => unknown + ) => Promise.resolve(rows).then(onFulfilled, onRejected), + } + }), })), })) const mockHasWorkspaceAdminAccess = vi.fn() @@ -34,6 +56,7 @@ describe('Workspace member DELETE route', () => { }, workspace: { id: 'workspace.id', + ownerId: 'workspace.ownerId', billingOwnerType: 'workspace.billingOwnerType', billingOwnerUserId: 'workspace.billingOwnerUserId', }, @@ -70,6 +93,8 @@ describe('Workspace member DELETE route', () => { } ), })) + + mockHasWorkspaceAdminAccess.mockResolvedValue(true) }) afterEach(() => { @@ -79,25 +104,71 @@ describe('Workspace member DELETE route', () => { it('blocks removing the workspace billing owner until billing is reassigned', async () => { selectResults.push([ { + ownerId: 'user-1', billingOwnerType: 'user', billingOwnerUserId: 'user-2', }, ]) - - const { DELETE } = await import('./route') - const response = await DELETE( - new NextRequest('http://localhost/api/workspaces/members/user-2', { - method: 'DELETE', - body: JSON.stringify({ workspaceId: 'workspace-1' }), - }), - { params: Promise.resolve({ id: 'user-2' }) } - ) + const response = await deleteMember('user-2') expect(response.status).toBe(400) expect(await response.json()).toEqual({ error: 'Cannot remove the workspace billing owner. Please reassign billing first.', }) expect(deleteMock).not.toHaveBeenCalled() - expect(mockHasWorkspaceAdminAccess).not.toHaveBeenCalled() + expect(mockHasWorkspaceAdminAccess).toHaveBeenCalledWith('user-1', 'workspace-1') + }) + + it('blocks removing the canonical workspace owner', async () => { + selectResults.push([ + { + ownerId: 'user-2', + billingOwnerType: 'user', + billingOwnerUserId: 'user-1', + }, + ]) + + const response = await deleteMember('user-2') + + expect(response.status).toBe(400) + expect(await response.json()).toEqual({ error: 'Cannot remove the workspace owner' }) + expect(deleteMock).not.toHaveBeenCalled() + expect(mockHasWorkspaceAdminAccess).toHaveBeenCalledWith('user-1', 'workspace-1') + }) + + it('does not disclose canonical owner state to callers without admin access', async () => { + mockHasWorkspaceAdminAccess.mockResolvedValue(false) + selectResults.push([ + { + ownerId: 'user-2', + billingOwnerType: 'user', + billingOwnerUserId: 'user-1', + }, + ]) + + const response = await deleteMember('user-2') + + expect(response.status).toBe(403) + expect(await response.json()).toEqual({ error: 'Insufficient permissions' }) + expect(deleteMock).not.toHaveBeenCalled() + expect(mockHasWorkspaceAdminAccess).toHaveBeenCalledWith('user-1', 'workspace-1') + }) + + it('allows a non-owner admin to leave when the canonical owner remains admin', async () => { + selectResults.push([ + { + ownerId: 'owner-1', + billingOwnerType: 'user', + billingOwnerUserId: 'owner-1', + }, + ]) + selectResults.push([{ userId: 'user-1', permissionType: 'admin' }]) + + const response = await deleteMember('user-1') + + expect(response.status).toBe(200) + expect(await response.json()).toEqual({ success: true }) + expect(deleteMock).toHaveBeenCalledWith(expect.anything()) + expect(deleteWhereMock).toHaveBeenCalled() }) }) diff --git a/apps/tradinggoose/app/api/workspaces/members/[id]/route.ts b/apps/tradinggoose/app/api/workspaces/members/[id]/route.ts index 5d95c950a..46dab9619 100644 --- a/apps/tradinggoose/app/api/workspaces/members/[id]/route.ts +++ b/apps/tradinggoose/app/api/workspaces/members/[id]/route.ts @@ -31,6 +31,7 @@ export async function DELETE(req: NextRequest, { params }: { params: Promise<{ i const workspaceRow = await db .select({ + ownerId: workspace.ownerId, billingOwnerType: workspace.billingOwnerType, billingOwnerUserId: workspace.billingOwnerUserId, }) @@ -42,6 +43,17 @@ export async function DELETE(req: NextRequest, { params }: { params: Promise<{ i return NextResponse.json({ error: 'Workspace not found' }, { status: 404 }) } + const hasAdminAccess = await hasWorkspaceAdminAccess(session.user.id, workspaceId) + const isSelf = userId === session.user.id + + if (!hasAdminAccess && !isSelf) { + return NextResponse.json({ error: 'Insufficient permissions' }, { status: 403 }) + } + + if (workspaceRow[0].ownerId === userId) { + return NextResponse.json({ error: 'Cannot remove the workspace owner' }, { status: 400 }) + } + try { assertWorkspaceBillingOwnerCanBeRemoved({ billingOwnerType: workspaceRow[0].billingOwnerType, @@ -72,36 +84,6 @@ export async function DELETE(req: NextRequest, { params }: { params: Promise<{ i return NextResponse.json({ error: 'User not found in workspace' }, { status: 404 }) } - // Check if current user has admin access to this workspace - const hasAdminAccess = await hasWorkspaceAdminAccess(session.user.id, workspaceId) - const isSelf = userId === session.user.id - - if (!hasAdminAccess && !isSelf) { - return NextResponse.json({ error: 'Insufficient permissions' }, { status: 403 }) - } - - // Prevent removing yourself if you're the last admin - if (isSelf && userPermission.permissionType === 'admin') { - const otherAdmins = await db - .select() - .from(permissions) - .where( - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workspaceId), - eq(permissions.permissionType, 'admin') - ) - ) - .then((rows) => rows.filter((row) => row.userId !== session.user.id)) - - if (otherAdmins.length === 0) { - return NextResponse.json( - { error: 'Cannot remove the last admin from a workspace' }, - { status: 400 } - ) - } - } - // Delete the user's permissions for this workspace await db .delete(permissions) diff --git a/apps/tradinggoose/app/api/workspaces/members/route.ts b/apps/tradinggoose/app/api/workspaces/members/route.ts index 365dce01d..54fafe18f 100644 --- a/apps/tradinggoose/app/api/workspaces/members/route.ts +++ b/apps/tradinggoose/app/api/workspaces/members/route.ts @@ -1,9 +1,9 @@ import { db } from '@tradinggoose/db' import { permissions, type permissionTypeEnum, user } from '@tradinggoose/db/schema' -import { and, eq } from 'drizzle-orm' +import { eq } from 'drizzle-orm' import { NextResponse } from 'next/server' import { getSession } from '@/lib/auth' -import { hasAdminPermission } from '@/lib/permissions/utils' +import { checkWorkspaceAccess, hasWorkspaceAdminAccess } from '@/lib/permissions/utils' type PermissionType = (typeof permissionTypeEnum.enumValues)[number] @@ -35,7 +35,7 @@ export async function POST(req: Request) { } // Check if current user has admin permission for the workspace - const hasAdmin = await hasAdminPermission(session.user.id, workspaceId) + const hasAdmin = await hasWorkspaceAdminAccess(session.user.id, workspaceId) if (!hasAdmin) { return NextResponse.json({ error: 'Insufficient permissions' }, { status: 403 }) @@ -52,21 +52,10 @@ export async function POST(req: Request) { return NextResponse.json({ error: 'User not found' }, { status: 404 }) } - // Check if user already has permissions for this workspace - const existingPermissions = await db - .select() - .from(permissions) - .where( - and( - eq(permissions.userId, targetUser.id), - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workspaceId) - ) - ) - - if (existingPermissions.length > 0) { + const existingAccess = await checkWorkspaceAccess(workspaceId, targetUser.id) + if (existingAccess.hasAccess) { return NextResponse.json( - { error: 'User already has permissions for this workspace' }, + { error: 'User already has access to this workspace' }, { status: 400 } ) } diff --git a/apps/tradinggoose/app/api/workspaces/route.test.ts b/apps/tradinggoose/app/api/workspaces/route.test.ts index 30113d01a..1e282e324 100644 --- a/apps/tradinggoose/app/api/workspaces/route.test.ts +++ b/apps/tradinggoose/app/api/workspaces/route.test.ts @@ -6,12 +6,22 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' describe('Workspaces API Route', () => { const transactionMock = vi.fn() + const txInsertValuesMock = vi.fn() + const txInsertMock = vi.fn(() => ({ + values: txInsertValuesMock, + })) + const deleteWhereMock = vi.fn() + const deleteMock = vi.fn((_table: unknown) => ({ + where: deleteWhereMock, + })) const updateWhereMock = vi.fn() const updateSetMock = vi.fn() const updateMock = vi.fn() + const mockSaveWorkflowToNormalizedTables = vi.fn() + const mockTryApplyWorkflowState = vi.fn() let userWorkspaces: Array<{ workspace: Record - permissionType: 'admin' | 'write' | 'read' + permissionType: 'admin' | 'write' | 'read' | null }> = [] beforeEach(() => { @@ -19,14 +29,27 @@ describe('Workspaces API Route', () => { vi.clearAllMocks() userWorkspaces = [] + txInsertValuesMock.mockResolvedValue(undefined) + transactionMock.mockImplementation(async (callback) => + callback({ insert: txInsertMock, delete: deleteMock }) + ) + deleteWhereMock.mockResolvedValue(undefined) updateWhereMock.mockResolvedValue([]) updateSetMock.mockReturnValue({ where: updateWhereMock }) updateMock.mockReturnValue({ set: updateSetMock }) + mockSaveWorkflowToNormalizedTables.mockResolvedValue({ success: true }) + mockTryApplyWorkflowState.mockResolvedValue({ success: true }) vi.doMock('@tradinggoose/db', () => ({ db: { + delete: deleteMock, select: vi.fn(() => ({ from: vi.fn(() => ({ + leftJoin: vi.fn(() => ({ + where: vi.fn(() => ({ + orderBy: vi.fn(() => userWorkspaces), + })), + })), innerJoin: vi.fn(() => ({ where: vi.fn(() => ({ orderBy: vi.fn(() => userWorkspaces), @@ -53,6 +76,7 @@ describe('Workspaces API Route', () => { }, workspace: { id: 'workspace.id', + ownerId: 'workspace.ownerId', createdAt: 'workspace.createdAt', }, })) @@ -85,11 +109,11 @@ describe('Workspaces API Route', () => { })) vi.doMock('@/lib/workflows/db-helpers', () => ({ - saveWorkflowToNormalizedTables: vi.fn().mockResolvedValue({ success: true }), + saveWorkflowToNormalizedTables: mockSaveWorkflowToNormalizedTables, })) vi.doMock('@/lib/yjs/server/apply-workflow-state', () => ({ - tryApplyWorkflowState: vi.fn().mockResolvedValue(undefined), + tryApplyWorkflowState: mockTryApplyWorkflowState, })) vi.doMock('@/lib/yjs/workflow-session', () => ({ @@ -113,6 +137,16 @@ describe('Workspaces API Route', () => { vi.clearAllMocks() }) + async function postWorkspace() { + const { POST } = await import('@/app/api/workspaces/route') + return POST( + new Request('http://localhost/api/workspaces', { + method: 'POST', + body: JSON.stringify({ name: 'New Workspace' }), + }) + ) + } + it('returns an empty list without creating a default workspace when autoCreate=false', async () => { const { GET } = await import('@/app/api/workspaces/route') @@ -161,4 +195,93 @@ describe('Workspaces API Route', () => { expect(updateMock).not.toHaveBeenCalled() expect(transactionMock).not.toHaveBeenCalled() }) + + it('lists owned workspaces without requiring an owner permission row', async () => { + userWorkspaces = [ + { + workspace: { + id: 'workspace-owned', + name: 'Owned Workspace', + ownerId: 'user-1', + billingOwnerType: 'user', + billingOwnerUserId: 'user-1', + billingOwnerOrganizationId: null, + createdAt: new Date('2026-04-10T00:00:00.000Z'), + updatedAt: new Date('2026-04-10T00:00:00.000Z'), + }, + permissionType: null, + }, + ] + + const { GET } = await import('@/app/api/workspaces/route') + + const response = await GET(new NextRequest('http://localhost/api/workspaces?autoCreate=false')) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.workspaces).toEqual([ + expect.objectContaining({ + id: 'workspace-owned', + role: 'owner', + permissions: 'admin', + }), + ]) + expect(transactionMock).not.toHaveBeenCalled() + }) + + it('auto-creates a default workspace with the canonical workspace shape', async () => { + const { GET } = await import('@/app/api/workspaces/route') + + const response = await GET(new NextRequest('http://localhost/api/workspaces')) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.workspaces).toEqual([ + expect.objectContaining({ + name: "Bruz's Workspace", + role: 'owner', + permissions: 'admin', + billingOwner: { + type: 'user', + userId: 'user-1', + }, + }), + ]) + expect(transactionMock).toHaveBeenCalled() + expect(updateMock).toHaveBeenCalled() + }) + + it.each([ + [ + 'persistence fails', + () => + mockSaveWorkflowToNormalizedTables.mockResolvedValue({ + success: false, + error: 'Failed to persist normalized workflow state', + }), + ], + [ + 'persistence throws', + () => mockSaveWorkflowToNormalizedTables.mockRejectedValue(new Error('database unavailable')), + ], + [ + 'Yjs seeding fails', + () => + mockTryApplyWorkflowState.mockResolvedValue({ + success: false, + error: new Error('socket unavailable'), + }), + ], + ])('removes a newly created workspace when default workflow %s', async (_case, fail) => { + fail() + const response = await postWorkspace() + + expect(response.status).toBe(500) + expect(await response.json()).toEqual({ error: 'Failed to create workspace' }) + expect(deleteMock.mock.calls.map(([table]) => table)).toEqual([ + expect.objectContaining({ workspaceId: 'workflow.workspaceId' }), + expect.objectContaining({ ownerId: 'workspace.ownerId' }), + ]) + expect(deleteWhereMock).toHaveBeenCalledTimes(2) + }) }) diff --git a/apps/tradinggoose/app/api/workspaces/route.ts b/apps/tradinggoose/app/api/workspaces/route.ts index c3b2c203c..444eee6f4 100644 --- a/apps/tradinggoose/app/api/workspaces/route.ts +++ b/apps/tradinggoose/app/api/workspaces/route.ts @@ -1,15 +1,8 @@ -import { db } from '@tradinggoose/db' -import { permissions, workflow, workspace } from '@tradinggoose/db/schema' -import { and, desc, eq, isNull } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console/logger' -import { saveWorkflowToNormalizedTables } from '@/lib/workflows/db-helpers' -import { buildDefaultWorkflowArtifacts } from '@/lib/workflows/defaults' -import { toWorkspaceApiRecord } from '@/lib/workspaces/billing-owner' -import { tryApplyWorkflowState } from '@/lib/yjs/server/apply-workflow-state' -import { createWorkflowSnapshot } from '@/lib/yjs/workflow-session' +import { createWorkspace, getUserWorkspaces } from '@/lib/workspaces/service' const logger = createLogger('Workspaces') const createWorkspaceSchema = z.object({ @@ -25,46 +18,13 @@ export async function GET(request: NextRequest) { return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) } - // Get all workspaces where the user has permissions - const userWorkspaces = await db - .select({ - workspace: workspace, - permissionType: permissions.permissionType, - }) - .from(permissions) - .innerJoin(workspace, eq(permissions.entityId, workspace.id)) - .where(and(eq(permissions.userId, session.user.id), eq(permissions.entityType, 'workspace'))) - .orderBy(desc(workspace.createdAt)) + const workspaces = await getUserWorkspaces({ + userId: session.user.id, + userName: session.user.name, + autoCreate: allowWorkspaceBootstrap, + }) - if (userWorkspaces.length === 0) { - if (!allowWorkspaceBootstrap) { - return NextResponse.json({ workspaces: [] }) - } - - // Create a default workspace for the user - const defaultWorkspace = await createDefaultWorkspace(session.user.id, session.user.name) - - // Migrate existing workflows to the default workspace - await migrateExistingWorkflows(session.user.id, defaultWorkspace.id) - - return NextResponse.json({ workspaces: [defaultWorkspace] }) - } - - if (allowWorkspaceBootstrap) { - // If user has workspaces but might have orphaned workflows, migrate them - await ensureWorkflowsHaveWorkspace(session.user.id, userWorkspaces[0].workspace.id) - } - - // Format the response with permission information - const workspacesWithPermissions = userWorkspaces.map( - ({ workspace: workspaceDetails, permissionType }) => ({ - ...toWorkspaceApiRecord(workspaceDetails), - role: permissionType === 'admin' ? 'owner' : 'member', // Map admin to owner for compatibility - permissions: permissionType, - }) - ) - - return NextResponse.json({ workspaces: workspacesWithPermissions }) + return NextResponse.json({ workspaces }) } // POST /api/workspaces - Create a new workspace @@ -86,165 +46,3 @@ export async function POST(req: Request) { return NextResponse.json({ error: 'Failed to create workspace' }, { status: 500 }) } } - -// Helper function to create a default workspace -async function createDefaultWorkspace(userId: string, userName?: string | null) { - const firstName = userName?.split(' ')[0] || null - const workspaceName = firstName ? `${firstName}'s Workspace` : 'My Workspace' - return createWorkspace(userId, workspaceName) -} - -// Helper function to create a workspace -async function createWorkspace(userId: string, name: string) { - const workspaceId = crypto.randomUUID() - const workflowId = crypto.randomUUID() - const now = new Date() - - // Create the workspace and initial workflow in a transaction - try { - await db.transaction(async (tx) => { - // Create the workspace - await tx.insert(workspace).values({ - id: workspaceId, - name, - ownerId: userId, - billingOwnerType: 'user', - billingOwnerUserId: userId, - billingOwnerOrganizationId: null, - allowPersonalApiKeys: true, - createdAt: now, - updatedAt: now, - }) - - // Create admin permissions for the workspace owner - await tx.insert(permissions).values({ - id: crypto.randomUUID(), - entityType: 'workspace' as const, - entityId: workspaceId, - userId: userId, - permissionType: 'admin' as const, - createdAt: now, - updatedAt: now, - }) - - // Create initial workflow for the workspace (empty canvas) - // Create the workflow - await tx.insert(workflow).values({ - id: workflowId, - userId, - workspaceId, - folderId: null, - name: 'default-agent', - description: 'Your first workflow - start building here!', - color: '#3972F6', - lastSynced: now, - createdAt: now, - updatedAt: now, - isDeployed: false, - collaborators: [], - runCount: 0, - variables: {}, - isPublished: false, - marketplaceData: null, - }) - - // No blocks are inserted - empty canvas - - logger.info( - `Created workspace ${workspaceId} with initial workflow ${workflowId} for user ${userId}` - ) - }) - - const { workflowState } = buildDefaultWorkflowArtifacts() - const lastSaved = now.toISOString() - - // Seed the Yjs doc and persist to normalized tables in parallel - const [, seedResult] = await Promise.all([ - tryApplyWorkflowState( - workflowId, - createWorkflowSnapshot({ - blocks: workflowState.blocks, - edges: workflowState.edges, - loops: workflowState.loops, - parallels: workflowState.parallels, - lastSaved, - isDeployed: false, - }), - undefined, - 'default-agent' - ), - saveWorkflowToNormalizedTables(workflowId, workflowState), - ]) - - if (!seedResult.success) { - throw new Error(seedResult.error || 'Failed to seed default workflow state') - } - } catch (error) { - logger.error(`Failed to create workspace ${workspaceId} with initial workflow:`, error) - throw error - } - - // Return the workspace data directly instead of querying again - return { - ...toWorkspaceApiRecord({ - id: workspaceId, - name, - ownerId: userId, - billingOwnerType: 'user', - billingOwnerUserId: userId, - billingOwnerOrganizationId: null, - allowPersonalApiKeys: true, - createdAt: now, - updatedAt: now, - }), - role: 'owner', - } -} - -// Helper function to migrate existing workflows to a workspace -async function migrateExistingWorkflows(userId: string, workspaceId: string) { - // Find all workflows that have no workspace ID - const orphanedWorkflows = await db - .select({ id: workflow.id }) - .from(workflow) - .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) - - if (orphanedWorkflows.length === 0) { - return // No orphaned workflows to migrate - } - - logger.info( - `Migrating ${orphanedWorkflows.length} workflows to workspace ${workspaceId} for user ${userId}` - ) - - // Bulk update all orphaned workflows at once - await db - .update(workflow) - .set({ - workspaceId: workspaceId, - updatedAt: new Date(), - }) - .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) -} - -// Helper function to ensure all workflows have a workspace -async function ensureWorkflowsHaveWorkspace(userId: string, defaultWorkspaceId: string) { - // First check if there are any orphaned workflows - const orphanedWorkflows = await db - .select() - .from(workflow) - .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) - - if (orphanedWorkflows.length > 0) { - // Directly update any workflows that don't have a workspace ID in a single query - await db - .update(workflow) - .set({ - workspaceId: defaultWorkspaceId, - updatedAt: new Date(), - }) - .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) - - logger.info(`Fixed ${orphanedWorkflows.length} orphaned workflows for user ${userId}`) - } -} diff --git a/apps/tradinggoose/app/app-bootstrap.tsx b/apps/tradinggoose/app/app-bootstrap.tsx new file mode 100644 index 000000000..b08fa6051 --- /dev/null +++ b/apps/tradinggoose/app/app-bootstrap.tsx @@ -0,0 +1,59 @@ +'use client' + +import { useEffect } from 'react' +import { useLocale } from 'next-intl' +import { useSession } from '@/lib/auth-client' +import { useGeneralSettings } from '@/hooks/queries/general-settings' +import { replaceLocaleDocument, usePathname } from '@/i18n/navigation' +import { bootstrapProviderModels } from '@/stores/providers/store' +import { useGeneralStore } from '@/stores/settings/general/store' + +const USER_LOCALE_OWNED_ROUTE_PREFIXES = ['/workspace', '/admin', '/chat'] as const +const PROVIDER_BOOTSTRAP_DELAY_MS = 1000 + +const isUserLocaleOwnedRoute = (pathname: string) => + USER_LOCALE_OWNED_ROUTE_PREFIXES.some( + (route) => pathname === route || pathname.startsWith(`${route}/`) + ) + +export function AppBootstrap() { + const pathname = usePathname() ?? '/' + const locale = useLocale() + const { data: session, isPending } = useSession() + const userId = session?.user?.id ?? null + const settingsQuery = useGeneralSettings({ enabled: !isPending, userId }) + const preferredLocale = settingsQuery.data?.preferredLocale + + useEffect(() => { + useGeneralStore.setState({ + isLoading: isPending || (Boolean(userId) && settingsQuery.isPending), + }) + }, [isPending, settingsQuery.isPending, userId]) + + useEffect(() => { + if ( + userId && + preferredLocale && + preferredLocale !== locale && + isUserLocaleOwnedRoute(pathname) + ) { + replaceLocaleDocument(preferredLocale, `${pathname}${window.location.search}`) + } + }, [locale, pathname, preferredLocale, userId]) + + useEffect(() => { + if (!isUserLocaleOwnedRoute(pathname)) { + return + } + + const timeoutId = window.setTimeout(() => { + bootstrapProviderModels() + }, PROVIDER_BOOTSTRAP_DELAY_MS) + + return () => { + window.clearTimeout(timeoutId) + } + }, [pathname]) + + return null +} diff --git a/apps/tradinggoose/app/changelog.xml/route.ts b/apps/tradinggoose/app/changelog.xml/route.ts index 903b24dd9..8db6c30c9 100644 --- a/apps/tradinggoose/app/changelog.xml/route.ts +++ b/apps/tradinggoose/app/changelog.xml/route.ts @@ -1,7 +1,7 @@ import { NextResponse } from 'next/server' -import { SITE_BASE_URL } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' -export const dynamic = 'force-static' +export const dynamic = 'force-dynamic' export const revalidate = 3600 interface Release { @@ -25,6 +25,7 @@ function escapeXml(str: string) { export async function GET() { try { + const siteBaseUrl = getBaseUrl() const res = await fetch( 'https://api.github.com/repos/TradingGoose/TradingGoose-Studio/releases', { @@ -52,7 +53,7 @@ export async function GET() { TradingGoose Changelog - ${SITE_BASE_URL}/changelog + ${siteBaseUrl}/changelog Latest changes, fixes and updates in TradingGoose. en-us ${items} diff --git a/apps/tradinggoose/app/llms-full.txt/route.ts b/apps/tradinggoose/app/llms-full.txt/route.ts index 8fed55dd5..cfdfe97ad 100644 --- a/apps/tradinggoose/app/llms-full.txt/route.ts +++ b/apps/tradinggoose/app/llms-full.txt/route.ts @@ -4,9 +4,12 @@ import { buildHostedPricingNarrative, buildHostedPricingSentence, } from '@/lib/billing/public-catalog' -import { SITE_BASE_URL } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' + +export const dynamic = 'force-dynamic' export async function GET() { + const siteBaseUrl = getBaseUrl() const billingCatalog = await getPublicBillingCatalog() const hostedPricingSentence = billingCatalog.billingEnabled ? buildHostedPricingSentence(billingCatalog) @@ -31,7 +34,7 @@ export async function GET() { > information to cite TradingGoose accurately without hallucinating features, > pricing, or positioning. -Canonical URL: ${SITE_BASE_URL} +Canonical URL: ${siteBaseUrl} Source code: https://github.com/tradinggoose/tradinggoose-studio (open source, self-hostable) Documentation: https://docs.tradinggoose.ai Last updated: 2026-04-04 @@ -92,7 +95,7 @@ TradingGoose ships in two forms: - Self-hosting supported - Community-maintained -**TradingGoose Hosted (${SITE_BASE_URL})** — current managed cloud tiers: +**TradingGoose Hosted (${siteBaseUrl})** — current managed cloud tiers: ${hostedPricingTable} @@ -179,7 +182,7 @@ Calendly, Webflow, WordPress, Firecrawl, BrowserUse. **Is TradingGoose free?** Yes. TradingGoose Studio is open source under the license at -${SITE_BASE_URL}/licenses and can be self-hosted at no cost. The hosted +${siteBaseUrl}/licenses and can be self-hosted at no cost. The hosted edition at tradinggoose.ai ${hostedPricingSentence ? `currently offers ${hostedPricingSentence}.` : 'does not currently expose public billing tiers.'} Paid tiers start above the default free tier. @@ -242,11 +245,11 @@ any market data provider and stream live prices into the workspace. ## 12. Getting started -- Homepage: ${SITE_BASE_URL} +- Homepage: ${siteBaseUrl} - Documentation: https://docs.tradinggoose.ai - GitHub (open source): https://github.com/TradingGoose/TradingGoose-Studio -- Sign up (hosted, free tier): ${SITE_BASE_URL}/signup -- Changelog: ${SITE_BASE_URL}/changelog +- Sign up (hosted, free tier): ${siteBaseUrl}/signup +- Changelog: ${siteBaseUrl}/changelog ${enterpriseContactLine} ## 13. Community @@ -257,9 +260,9 @@ ${enterpriseContactLine} ## 14. Legal, licensing, and disclaimers -- Terms: ${SITE_BASE_URL}/terms -- Privacy: ${SITE_BASE_URL}/privacy -- License and third-party notices: ${SITE_BASE_URL}/licenses +- Terms: ${siteBaseUrl}/terms +- Privacy: ${siteBaseUrl}/privacy +- License and third-party notices: ${siteBaseUrl}/licenses - Support: support@tradinggoose.ai **Trading risk disclaimer:** TradingGoose is software. Trading in financial @@ -281,9 +284,9 @@ When citing TradingGoose, prefer this canonical phrasing: > can be self-hosted. Canonical sources to cite: -- Website: ${SITE_BASE_URL} +- Website: ${siteBaseUrl} - GitHub: https://github.com/tradinggoose/tradinggoose-studio -- This file: ${SITE_BASE_URL}/llms-full.txt +- This file: ${siteBaseUrl}/llms-full.txt ` return new Response(llmsFullContent, { diff --git a/apps/tradinggoose/app/llms.txt/route.ts b/apps/tradinggoose/app/llms.txt/route.ts index 27fa106f9..87a8a67fb 100644 --- a/apps/tradinggoose/app/llms.txt/route.ts +++ b/apps/tradinggoose/app/llms.txt/route.ts @@ -1,8 +1,11 @@ import { getPublicBillingCatalog } from '@/lib/billing/catalog' import { buildHostedPricingSentence } from '@/lib/billing/public-catalog' -import { SITE_BASE_URL } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' + +export const dynamic = 'force-dynamic' export async function GET() { + const siteBaseUrl = getBaseUrl() const billingCatalog = await getPublicBillingCatalog() const hostedPricingSentence = billingCatalog.billingEnabled ? buildHostedPricingSentence(billingCatalog) @@ -53,11 +56,11 @@ ${ - Widget: a composable workspace panel (chart, indicator view, workflow status, etc.) ## Getting started -- Homepage: ${SITE_BASE_URL} +- Homepage: ${siteBaseUrl} - Documentation: https://docs.tradinggoose.ai - GitHub (open source): https://github.com/TradingGoose/TradingGoose-Studio -- Sign up (hosted): ${SITE_BASE_URL}/signup -- Changelog: ${SITE_BASE_URL}/changelog +- Sign up (hosted): ${siteBaseUrl}/signup +- Changelog: ${siteBaseUrl}/changelog ## Community - GitHub: https://github.com/TradingGoose/TradingGoose-Studio @@ -65,11 +68,11 @@ ${ - X / Twitter: https://x.com/tradinggoose ## License -See ${SITE_BASE_URL}/licenses for license and third-party notices. +See ${siteBaseUrl}/licenses for license and third-party notices. ## Full reference For a deeper, AI-readable reference (features, pricing tiers, FAQ, example -workflow, integrations, glossary), see ${SITE_BASE_URL}/llms-full.txt +workflow, integrations, glossary), see ${siteBaseUrl}/llms-full.txt ` return new Response(llmsContent, { diff --git a/apps/tradinggoose/app/provider-models-bootstrap.tsx b/apps/tradinggoose/app/provider-models-bootstrap.tsx deleted file mode 100644 index 21c7843ce..000000000 --- a/apps/tradinggoose/app/provider-models-bootstrap.tsx +++ /dev/null @@ -1,37 +0,0 @@ -'use client' - -import { useEffect } from 'react' -import { usePathname } from '@/i18n/navigation' -import { bootstrapProviderModels } from '@/stores/providers/store' - -const PUBLIC_LANDING_ROUTE_PREFIXES = [ - '/privacy', - '/terms', - '/careers', - '/licenses', - '/blog', -] as const -const PROVIDER_BOOTSTRAP_DELAY_MS = 1000 - -const isPublicLandingRoute = (pathname: string) => - pathname === '/' || PUBLIC_LANDING_ROUTE_PREFIXES.some((route) => pathname.startsWith(route)) - -export function ProviderModelsBootstrap() { - const pathname = usePathname() ?? '/' - - useEffect(() => { - if (isPublicLandingRoute(pathname)) { - return - } - - const timeoutId = window.setTimeout(() => { - bootstrapProviderModels() - }, PROVIDER_BOOTSTRAP_DELAY_MS) - - return () => { - window.clearTimeout(timeoutId) - } - }, [pathname]) - - return null -} diff --git a/apps/tradinggoose/app/sitemap.ts b/apps/tradinggoose/app/sitemap.ts index c75a3ddb7..213bd3057 100644 --- a/apps/tradinggoose/app/sitemap.ts +++ b/apps/tradinggoose/app/sitemap.ts @@ -3,6 +3,8 @@ import { getAllPosts } from '@/app/(landing)/blog/lib/posts' import { locales } from '@/i18n/routing' import { localizeSiteUrl } from '@/i18n/utils' +export const dynamic = 'force-dynamic' + type SitemapEntry = Omit function localizedEntries(pathname: string, entry: SitemapEntry): MetadataRoute.Sitemap { diff --git a/apps/tradinggoose/app/workspace/[workspaceId]/integrations/integrations.tsx b/apps/tradinggoose/app/workspace/[workspaceId]/integrations/integrations.tsx index 6c49d3b7c..b99ff9085 100644 --- a/apps/tradinggoose/app/workspace/[workspaceId]/integrations/integrations.tsx +++ b/apps/tradinggoose/app/workspace/[workspaceId]/integrations/integrations.tsx @@ -8,23 +8,24 @@ import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { Skeleton } from '@/components/ui/skeleton' +import { createLogger } from '@/lib/logs/console/logger' +import { OAUTH_PROVIDERS } from '@/lib/oauth/oauth' +import { cn } from '@/lib/utils' +import { GlobalNavbarHeader } from '@/global-navbar' import { type ServiceInfo, useConnectOAuthService, useDisconnectOAuthService, useOAuthConnections, } from '@/hooks/queries/oauth-connections' -import { createLogger } from '@/lib/logs/console/logger' -import { OAUTH_PROVIDERS } from '@/lib/oauth/oauth' -import { cn } from '@/lib/utils' -import { GlobalNavbarHeader } from '@/global-navbar' -import { useRouter } from '@/i18n/navigation' +import { usePathname, useRouter } from '@/i18n/navigation' const logger = createLogger('Integrations') export function Integrations() { const t = useTranslations('workspace.integrations') const router = useRouter() + const pathname = usePathname() const searchParams = useSearchParams() const params = useParams() const workspaceId = params.workspaceId as string @@ -154,7 +155,7 @@ export function Integrations() { await connectService.mutateAsync({ providerId: service.providerId, - callbackURL: window.location.href, + callbackURL: `${pathname}${window.location.search}${window.location.hash}`, }) } catch (error) { logger.error('OAuth connection error:', { error }) @@ -404,8 +405,8 @@ export function Integrations() { > {t('connect')} - )} -
+ )} + ))} ) @@ -414,10 +415,10 @@ export function Integrations() { {!isLoading && !searchTerm.trim() && Object.keys(filteredGroupedServices).length === 0 && ( -
+
{t('emptyState.noConnectible')} -
- )} +
+ )} {/* Show message when search has no results */} {searchTerm.trim() && Object.keys(filteredGroupedServices).length === 0 && ( diff --git a/apps/tradinggoose/app/workspace/[workspaceId]/providers/providers.tsx b/apps/tradinggoose/app/workspace/[workspaceId]/providers/providers.tsx index 6ea02c2ad..61da69244 100644 --- a/apps/tradinggoose/app/workspace/[workspaceId]/providers/providers.tsx +++ b/apps/tradinggoose/app/workspace/[workspaceId]/providers/providers.tsx @@ -4,15 +4,19 @@ import React from 'react' import { TooltipProvider } from '@/components/ui/tooltip' import { WorkspacePermissionsProvider } from '@/app/workspace/[workspaceId]/providers/workspace-permissions-provider' -interface ProvidersProps { +type ProvidersProps = { children: React.ReactNode - workspaceId?: string -} + workspaceId: string +} & ({ userId: string; inheritUser?: never } | { inheritUser: true; userId?: never }) + +const Providers = React.memo((props) => { + const { children, workspaceId } = props + const workspaceIdentityProps = + props.inheritUser === true ? { inheritUser: true as const } : { userId: props.userId } -const Providers = React.memo(({ children, workspaceId }) => { return ( - + {children} diff --git a/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.test.tsx b/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.test.tsx index 5f7690eff..cbdc7a695 100644 --- a/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.test.tsx +++ b/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.test.tsx @@ -20,7 +20,6 @@ vi.mock('next/navigation', () => ({ })) vi.mock('@/i18n/navigation', () => ({ - usePathname: () => '/workspace/ws-1/dashboard', useRouter: () => ({ replace: mockReplace, }), @@ -52,7 +51,6 @@ describe('WorkspacePermissionsProvider', () => { beforeEach(() => { vi.clearAllMocks() - window.history.replaceState({}, '', '/workspace/ws-1/dashboard?layoutId=layout-1') mockUseWorkspacePermissions.mockReturnValue({ permissions: null, @@ -92,7 +90,7 @@ describe('WorkspacePermissionsProvider', () => { reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = previousActEnvironment }) - it('redirects missing sessions to login with the current callback target', async () => { + it('redirects authenticated users without access back to the workspace index', async () => { mockUseWorkspacePermissions.mockReturnValue({ permissions: null, loading: false, @@ -107,34 +105,81 @@ describe('WorkspacePermissionsProvider', () => { canAdmin: false, userPermissions: 'read', isLoading: false, - error: 'Authentication required', + error: 'Workspace not found or access denied', }) const { WorkspacePermissionsProvider } = await import('./workspace-permissions-provider') await act(async () => { root?.render( - +
workspace
) }) - expect(mockReplace).toHaveBeenCalledWith( - '/login?reauth=1&callbackUrl=%2Fworkspace%2Fws-1%2Fdashboard%3FlayoutId%3Dlayout-1' - ) + expect(mockReplace).toHaveBeenCalledWith('/workspace') expect(container?.textContent).toBe('') }) - it('redirects authenticated users without access back to the workspace index', async () => { + it('blocks rendering during auth recovery without replacing the auth redirect', async () => { mockUseWorkspacePermissions.mockReturnValue({ permissions: null, loading: false, - error: 'Workspace not found or access denied', + error: 'SESSION_EXPIRED', updatePermissions: mockUpdatePermissions, refetch: mockRefetchPermissions, }) + mockUseUserPermissions.mockReturnValue({ + canRead: false, + canEdit: false, + canAdmin: false, + userPermissions: 'read', + isLoading: false, + error: 'SESSION_EXPIRED', + }) + + const { WorkspacePermissionsProvider } = await import('./workspace-permissions-provider') + + await act(async () => { + root?.render( + +
workspace
+
+ ) + }) + + expect(mockReplace).not.toHaveBeenCalled() + expect(container?.textContent).toBe('') + }) + + it('inherits the server-authenticated user id for nested workspace providers', async () => { + const { WorkspacePermissionsProvider } = await import('./workspace-permissions-provider') + + await act(async () => { + root?.render( + + +
workspace
+
+
+ ) + }) + + expect(mockUseWorkspacePermissions).toHaveBeenCalledWith('ws-1', 'user-1') + expect(mockUseWorkspacePermissions).toHaveBeenCalledWith('ws-2', 'user-1') + expect(container?.textContent).toBe('workspace') + }) + + it('unblocks children when the authenticated user changes on the same workspace', async () => { + mockUseWorkspacePermissions.mockReturnValue({ + permissions: null, + loading: false, + error: 'Workspace not found or access denied', + updatePermissions: mockUpdatePermissions, + refetch: mockRefetchPermissions, + }) mockUseUserPermissions.mockReturnValue({ canRead: false, canEdit: false, @@ -148,13 +193,41 @@ describe('WorkspacePermissionsProvider', () => { await act(async () => { root?.render( - +
workspace
) }) - expect(mockReplace).toHaveBeenCalledWith('/workspace') expect(container?.textContent).toBe('') + + mockUseWorkspacePermissions.mockReturnValue({ + permissions: { + users: [], + total: 0, + currentUserPermission: 'admin', + }, + loading: false, + error: null, + updatePermissions: mockUpdatePermissions, + refetch: mockRefetchPermissions, + }) + mockUseUserPermissions.mockReturnValue({ + canRead: true, + canEdit: true, + canAdmin: true, + userPermissions: 'admin', + isLoading: false, + error: null, + }) + await act(async () => { + root?.render( + +
workspace
+
+ ) + }) + + expect(container?.textContent).toBe('workspace') }) }) diff --git a/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.tsx b/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.tsx index 144c794ce..01576de6f 100644 --- a/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.tsx +++ b/apps/tradinggoose/app/workspace/[workspaceId]/providers/workspace-permissions-provider.tsx @@ -2,18 +2,17 @@ import type React from 'react' import { createContext, useContext, useEffect, useMemo, useState } from 'react' -import { useParams } from 'next/navigation' import { createLogger } from '@/lib/logs/console/logger' +import { isSessionRecoveryAuthError } from '@/lib/auth/auth-error-copy' import { useUserPermissions, type WorkspaceUserPermissions } from '@/hooks/use-user-permissions' import { useWorkspacePermissions, type WorkspacePermissions, } from '@/hooks/use-workspace-permissions' -import { usePathname, useRouter } from '@/i18n/navigation' +import { useRouter } from '@/i18n/navigation' const logger = createLogger('WorkspacePermissionsProvider') const ACCESS_DENIED_PATTERNS = ['access denied', 'workspace not found', 'user not found'] -const AUTH_ERROR_PATTERNS = ['authentication required', 'failed to get session'] interface WorkspacePermissionsContextType { workspacePermissions: WorkspacePermissions | null @@ -25,43 +24,47 @@ interface WorkspacePermissionsContextType { setOfflineMode: (isOffline: boolean) => void } -const WorkspacePermissionsContext = createContext({ - workspacePermissions: null, - permissionsLoading: false, - permissionsError: null, - updatePermissions: () => {}, - refetchPermissions: async () => {}, - userPermissions: { - canRead: false, - canEdit: false, - canAdmin: false, - userPermissions: 'read', - isLoading: false, - error: null, - }, - setOfflineMode: () => {}, -}) - -interface WorkspacePermissionsProviderProps { +const WorkspaceAuthenticatedUserContext = createContext(null) +const WorkspacePermissionsContext = createContext(null) + +type WorkspacePermissionsProviderProps = { children: React.ReactNode - workspaceId?: string + workspaceId: string +} & ({ userId: string; inheritUser?: never } | { inheritUser: true; userId?: never }) + +export function WorkspacePermissionsProvider(props: WorkspacePermissionsProviderProps) { + const { children, workspaceId } = props + const inheritedUserId = useContext(WorkspaceAuthenticatedUserContext) + const workspaceUserId = props.userId ?? inheritedUserId + + if (!workspaceUserId) { + throw new Error( + 'WorkspacePermissionsProvider requires userId or inheritUser inside an existing WorkspacePermissionsProvider' + ) + } + + return ( + + {children} + + ) } -export function WorkspacePermissionsProvider({ +function WorkspacePermissionsProviderInner({ children, - workspaceId: workspaceIdProp, -}: WorkspacePermissionsProviderProps) { - const params = useParams() + workspaceId, + userId, +}: { + children: React.ReactNode + workspaceId: string + userId: string +}) { const router = useRouter() - const pathname = usePathname() - const workspaceId = workspaceIdProp ?? (params?.workspaceId as string | undefined) ?? null const [isOfflineMode, setIsOfflineMode] = useState(false) - const [hasRedirected, setHasRedirected] = useState(false) - - useEffect(() => { - setHasRedirected(false) - }, [workspaceId]) + const [redirectedAccessKey, setRedirectedAccessKey] = useState(null) + const accessKey = `${userId}:${workspaceId}` + const hasRedirected = redirectedAccessKey === accessKey const { permissions: workspacePermissions, @@ -69,7 +72,7 @@ export function WorkspacePermissionsProvider({ error: permissionsError, updatePermissions, refetch: refetchPermissions, - } = useWorkspacePermissions(workspaceId) + } = useWorkspacePermissions(workspaceId, userId) const baseUserPermissions = useUserPermissions( workspacePermissions, @@ -115,18 +118,17 @@ export function WorkspacePermissionsProvider({ ) const combinedError = userPermissions.error || permissionsError + const isAuthRecoveryError = isSessionRecoveryAuthError(permissionsError) const normalizedError = combinedError?.toLowerCase() ?? '' const isAccessDeniedError = normalizedError ? ACCESS_DENIED_PATTERNS.some((pattern) => normalizedError.includes(pattern)) : false - const isAuthError = normalizedError - ? AUTH_ERROR_PATTERNS.some((pattern) => normalizedError.includes(pattern)) - : false const shouldTriggerRedirect = Boolean( workspaceId && + !isAuthRecoveryError && !permissionsLoading && !userPermissions.isLoading && - (isAuthError || isAccessDeniedError || !userPermissions.canRead) + (isAccessDeniedError || !userPermissions.canRead) ) useEffect(() => { @@ -134,35 +136,22 @@ export function WorkspacePermissionsProvider({ return } - if (isAuthError) { - const callbackTarget = - typeof window === 'undefined' - ? `/workspace/${workspaceId}/dashboard` - : `${pathname ?? `/workspace/${workspaceId}/dashboard`}${window.location.search}` - - setHasRedirected(true) - logger.warn('Redirecting unauthenticated user from protected workspace route', { - workspaceId, - error: combinedError ?? 'missing session', - }) - router.replace(`/login?reauth=1&callbackUrl=${encodeURIComponent(callbackTarget)}`) - return - } - - setHasRedirected(true) + setRedirectedAccessKey(accessKey) logger.warn('Redirecting user without workspace access', { workspaceId, error: combinedError ?? 'missing read permissions', }) router.replace('/workspace') - }, [combinedError, hasRedirected, isAuthError, pathname, router, shouldTriggerRedirect, workspaceId]) + }, [accessKey, combinedError, hasRedirected, router, shouldTriggerRedirect, workspaceId]) - const shouldBlockRender = hasRedirected || shouldTriggerRedirect + const shouldBlockRender = isAuthRecoveryError || hasRedirected || shouldTriggerRedirect return ( - - {shouldBlockRender ? null : children} - + + + {shouldBlockRender ? null : children} + + ) } diff --git a/apps/tradinggoose/app/workspace/[workspaceId]/templates/[id]/template.tsx b/apps/tradinggoose/app/workspace/[workspaceId]/templates/[id]/template.tsx index 9f1f36b1c..32ad89859 100644 --- a/apps/tradinggoose/app/workspace/[workspaceId]/templates/[id]/template.tsx +++ b/apps/tradinggoose/app/workspace/[workspaceId]/templates/[id]/template.tsx @@ -216,7 +216,6 @@ export default function TemplateDetails({ template, workspaceId }: TemplateDetai body: JSON.stringify({ name: `${template.name} (Copy)`, description: `Created from template: ${template.name}`, - color: template.color, workspaceId, folderId: null, }), diff --git a/apps/tradinggoose/background/indicator-monitor-execution.test.ts b/apps/tradinggoose/background/indicator-monitor-execution.test.ts index e5e47b3a4..6b4da16d9 100644 --- a/apps/tradinggoose/background/indicator-monitor-execution.test.ts +++ b/apps/tradinggoose/background/indicator-monitor-execution.test.ts @@ -125,7 +125,7 @@ describe('executeIndicatorMonitorJob', () => { workspaceId: 'workspace-1', triggerType: 'webhook', executionTarget: 'deployed', - startBlockId: 'trigger-block', + triggerBlockId: 'trigger-block', }), }) ) diff --git a/apps/tradinggoose/background/indicator-monitor-execution.ts b/apps/tradinggoose/background/indicator-monitor-execution.ts index 8391b64ca..c0d6f8dd7 100644 --- a/apps/tradinggoose/background/indicator-monitor-execution.ts +++ b/apps/tradinggoose/background/indicator-monitor-execution.ts @@ -272,7 +272,7 @@ export async function executeIndicatorMonitorJob(payload: IndicatorMonitorExecut input: budgetResult.payload, triggerType: 'webhook', executionTarget: 'deployed', - startBlockId: payload.monitor.blockId, + triggerBlockId: payload.monitor.blockId, triggerData: { source: INDICATOR_MONITOR_TRIGGER_ID, executionTarget: 'deployed', diff --git a/apps/tradinggoose/background/portfolio-monitor-execution.ts b/apps/tradinggoose/background/portfolio-monitor-execution.ts index 9c748711a..a502633e5 100644 --- a/apps/tradinggoose/background/portfolio-monitor-execution.ts +++ b/apps/tradinggoose/background/portfolio-monitor-execution.ts @@ -79,7 +79,7 @@ export async function executePortfolioMonitorJob(payload: PortfolioMonitorExecut workflowInput, executionTarget: 'deployed', workflowContext: { workspaceId: payload.monitor.workspaceId }, - start: { + triggerTarget: { kind: 'block', blockId: payload.monitor.blockId, }, diff --git a/apps/tradinggoose/background/schedule-execution.ts b/apps/tradinggoose/background/schedule-execution.ts index 7aba4e817..c8ff4891f 100644 --- a/apps/tradinggoose/background/schedule-execution.ts +++ b/apps/tradinggoose/background/schedule-execution.ts @@ -26,7 +26,7 @@ export type ScheduleExecutionPayload = { scheduleId: string workflowId: string executionId?: string - blockId?: string + blockId: string cronExpression?: string lastRanAt?: string failedCount?: number @@ -43,18 +43,19 @@ export function isScheduleExecutionPayload(value: unknown): value is ScheduleExe return ( typeof candidate.scheduleId === 'string' && typeof candidate.workflowId === 'string' && + typeof candidate.blockId === 'string' && typeof candidate.timezone === 'string' && typeof candidate.now === 'string' ) } async function calculateNextRunTime( - schedule: { cronExpression?: string; lastRanAt?: string }, + schedule: { blockId: string; cronExpression?: string; lastRanAt?: string }, blocks: Record, timezone: string ): Promise { - const scheduleBlock = Object.values(blocks).find((block) => block.type === 'schedule') - if (!scheduleBlock) throw new Error('No schedule trigger block found') + const scheduleBlock = blocks[schedule.blockId] + if (!scheduleBlock) throw new Error(`Schedule trigger block ${schedule.blockId} not found`) const scheduleType = getSubBlockValue(scheduleBlock, 'scheduleType') const scheduleValues = getScheduleTimeValues(scheduleBlock) @@ -184,10 +185,11 @@ export async function executeScheduleJob(payload: ScheduleExecutionPayload) { }) const scheduleBlocks = blueprint.workflowData.blocks as Record - if (payload.blockId && !scheduleBlocks[payload.blockId]) { + if (!scheduleBlocks[payload.blockId]) { logger.warn( - `[${requestId}] Schedule trigger block ${payload.blockId} not found in deployed workflow ${payload.workflowId}. Skipping execution.` + `[${requestId}] Schedule trigger block ${payload.blockId} not found in deployed workflow ${payload.workflowId}. Removing schedule.` ) + await db.delete(workflowSchedule).where(eq(workflowSchedule.id, payload.scheduleId)) return } @@ -202,9 +204,9 @@ export async function executeScheduleJob(payload: ScheduleExecutionPayload) { workflowId: payload.workflowId, }, }, - start: { + triggerTarget: { kind: 'block', - blockId: payload.blockId || undefined, + blockId: payload.blockId, }, }) diff --git a/apps/tradinggoose/background/webhook-execution.ts b/apps/tradinggoose/background/webhook-execution.ts index 3c4372f59..4629be449 100644 --- a/apps/tradinggoose/background/webhook-execution.ts +++ b/apps/tradinggoose/background/webhook-execution.ts @@ -63,7 +63,7 @@ export type WebhookExecutionPayload = { provider: string body: any headers: Record - blockId?: string + blockId: string testMode?: boolean executionTarget?: 'deployed' | 'live' } @@ -78,7 +78,8 @@ export function isWebhookExecutionPayload(value: unknown): value is WebhookExecu typeof candidate.webhookId === 'string' && typeof candidate.workflowId === 'string' && typeof candidate.userId === 'string' && - typeof candidate.provider === 'string' + typeof candidate.provider === 'string' && + typeof candidate.blockId === 'string' ) } @@ -257,7 +258,7 @@ export async function executeWebhookJob(payload: WebhookExecutionPayload) { executionId, triggerType: 'webhook', workflowInput: airtableInput, - start: { + triggerTarget: { kind: 'block', blockId: payload.blockId, }, @@ -348,7 +349,7 @@ export async function executeWebhookJob(payload: WebhookExecutionPayload) { executionId, triggerType: 'webhook', workflowInput: input || {}, - start: { + triggerTarget: { kind: 'block', blockId: payload.blockId, }, diff --git a/apps/tradinggoose/background/workflow-execution.test.ts b/apps/tradinggoose/background/workflow-execution.test.ts index c354156b5..15550a434 100644 --- a/apps/tradinggoose/background/workflow-execution.test.ts +++ b/apps/tradinggoose/background/workflow-execution.test.ts @@ -153,7 +153,7 @@ describe('executeWorkflowJob', () => { executionTarget: 'live', workflowData, workflowVariables: { risk: { value: 1 } }, - startBlockId: 'trigger-1', + triggerBlockId: 'trigger-1', metadata: { source: 'workflow_queue', }, @@ -170,7 +170,7 @@ describe('executeWorkflowJob', () => { workspaceId: 'workspace-1', variables: { risk: { value: 1 } }, }, - start: { + triggerTarget: { kind: 'block', blockId: 'trigger-1', }, @@ -178,7 +178,7 @@ describe('executeWorkflowJob', () => { ) }) - it('preserves manual queued starts when no explicit start block is supplied', async () => { + it('preserves manual queued starts when no explicit trigger block is supplied', async () => { await executeWorkflowJob({ workflowId: 'workflow-1', userId: 'user-1', @@ -191,7 +191,7 @@ describe('executeWorkflowJob', () => { expect(runWorkflowExecutionMock).toHaveBeenCalledWith( expect.objectContaining({ triggerType: 'manual', - start: { + triggerTarget: { kind: 'trigger', triggerType: 'manual', }, @@ -227,7 +227,7 @@ describe('executeWorkflowJob', () => { workflowId: 'workflow-1', userId: 'user-1', triggerType: 'webhook', - startBlockId: 'trigger-1', + triggerBlockId: 'trigger-1', triggerData: { source: 'indicator_trigger', monitor: { id: 'monitor-1' }, diff --git a/apps/tradinggoose/background/workflow-execution.ts b/apps/tradinggoose/background/workflow-execution.ts index 23d17e114..e12502e9b 100644 --- a/apps/tradinggoose/background/workflow-execution.ts +++ b/apps/tradinggoose/background/workflow-execution.ts @@ -8,14 +8,14 @@ import { createWorkflowExecutionTerminalEventInput } from '@/lib/workflows/execu import { runWorkflowExecution, type WorkflowExecutionBlueprint, - type WorkflowStart, + type WorkflowTriggerTarget, } from '@/lib/workflows/execution-runner' import type { TriggerType } from '@/services/queue' import { disableMonitor } from './monitor-disable' const logger = createLogger('TriggerWorkflowExecution') -type WorkflowStartTriggerType = Extract['triggerType'] +type WorkflowTriggerTargetType = Extract['triggerType'] export type WorkflowExecutionPayload = { workflowId: string @@ -24,7 +24,7 @@ export type WorkflowExecutionPayload = { executionId?: string input?: any triggerType?: TriggerType - startBlockId?: string + triggerBlockId?: string executionTarget?: 'deployed' | 'live' workflowData?: WorkflowExecutionBlueprint['workflowData'] workflowVariables?: Record @@ -35,11 +35,11 @@ export type WorkflowExecutionPayload = { metadata?: Record } -function resolveWorkflowStartTriggerType(triggerType: TriggerType): WorkflowStartTriggerType { +function resolveWorkflowTriggerTargetType(triggerType: TriggerType): WorkflowTriggerTargetType { if (triggerType === 'chat') return 'chat' if (triggerType === 'api' || triggerType === 'api-endpoint') return 'api' if (triggerType === 'manual') return 'manual' - throw new Error(`Queued ${triggerType} workflow execution requires an explicit start block`) + throw new Error(`Queued ${triggerType} workflow execution requires an explicit trigger block`) } export function isWorkflowExecutionPayload( @@ -68,14 +68,14 @@ export async function executeWorkflowJob(payload: WorkflowExecutionPayload) { const isLiveExecution = executionTarget === 'live' const isChildExecution = payload.metadata?.source === 'workflow_block' const triggerType = payload.triggerType ?? 'manual' - const start: WorkflowStart = payload.startBlockId + const triggerTarget: WorkflowTriggerTarget = payload.triggerBlockId ? { kind: 'block', - blockId: payload.startBlockId, + blockId: payload.triggerBlockId, } : { kind: 'trigger', - triggerType: resolveWorkflowStartTriggerType(triggerType), + triggerType: resolveWorkflowTriggerTargetType(triggerType), } logger.info(`[${requestId}] Starting workflow execution: ${workflowId}`, { @@ -112,7 +112,7 @@ export async function executeWorkflowJob(payload: WorkflowExecutionPayload) { } : undefined, workflowData: isLiveExecution ? payload.workflowData : undefined, - start, + triggerTarget, triggerData, contextExtensions: { workflowDepth: payload.workflowDepth ?? 0, diff --git a/apps/tradinggoose/components/emails/header.tsx b/apps/tradinggoose/components/emails/header.tsx index 0578b0610..b948b0b1c 100644 --- a/apps/tradinggoose/components/emails/header.tsx +++ b/apps/tradinggoose/components/emails/header.tsx @@ -6,13 +6,13 @@ import { getBaseUrl } from '@/lib/urls/utils' import { type EmailLocale, getEmailCopy } from '@/components/emails/email-copy' interface EmailHeaderProps { + baseUrl?: string tagline?: string locale?: EmailLocale } -export const EmailHeader = ({ tagline, locale }: EmailHeaderProps) => { +export const EmailHeader = ({ baseUrl = getBaseUrl(), tagline, locale }: EmailHeaderProps) => { const brand = getBrandConfig() - const baseUrl = getBaseUrl() const logoSrc = `${baseUrl}/favicon/goose.png` const copy = getEmailCopy(locale) const resolvedTagline = tagline ?? copy.shared.tagline diff --git a/apps/tradinggoose/components/emails/localized-email.tsx b/apps/tradinggoose/components/emails/localized-email.tsx index 11d8c07c3..4bc212dfa 100644 --- a/apps/tradinggoose/components/emails/localized-email.tsx +++ b/apps/tradinggoose/components/emails/localized-email.tsx @@ -49,7 +49,7 @@ export function LocalizedEmail({ {preview} - +
{title} diff --git a/apps/tradinggoose/components/oauth/oauth-required-modal.test.tsx b/apps/tradinggoose/components/oauth/oauth-required-modal.test.tsx index d28451110..31f08b000 100644 --- a/apps/tradinggoose/components/oauth/oauth-required-modal.test.tsx +++ b/apps/tradinggoose/components/oauth/oauth-required-modal.test.tsx @@ -13,6 +13,10 @@ vi.mock('@/lib/oauth/connect', () => ({ startOAuthConnectFlow: (...args: unknown[]) => mockStartOAuthConnectFlow(...args), })) +vi.mock('@/i18n/navigation', () => ({ + usePathname: () => '/workspace/ws-1/integrations', +})) + describe('OAuthRequiredModal', () => { let container: HTMLDivElement let root: Root @@ -66,7 +70,7 @@ describe('OAuthRequiredModal', () => { expect(onClose).toHaveBeenCalledTimes(1) expect(mockStartOAuthConnectFlow).toHaveBeenCalledWith({ providerId: 'alpaca-paper', - callbackURL: window.location.href, + callbackURL: '/workspace/ws-1/integrations', }) }) }) diff --git a/apps/tradinggoose/components/oauth/oauth-required-modal.tsx b/apps/tradinggoose/components/oauth/oauth-required-modal.tsx index 279fb7303..816f7692a 100644 --- a/apps/tradinggoose/components/oauth/oauth-required-modal.tsx +++ b/apps/tradinggoose/components/oauth/oauth-required-modal.tsx @@ -19,6 +19,7 @@ import { parseProvider, } from '@/lib/oauth' import { startOAuthConnectFlow } from '@/lib/oauth/connect' +import { usePathname } from '@/i18n/navigation' import { formatTemplate } from '@/i18n/utils' import { useWorkflowBlockEditorCopy } from '@/widgets/widgets/editor_workflow/copy' @@ -141,6 +142,7 @@ export function OAuthRequiredModal({ serviceIds, }: OAuthRequiredModalProps) { const copy = useWorkflowBlockEditorCopy().oauthRequiredModal + const pathname = usePathname() const { baseProvider } = parseProvider(provider) const baseProviderConfig = OAUTH_PROVIDERS[baseProvider] const resolveExplicitServiceId = (candidate?: string) => { @@ -216,7 +218,7 @@ export function OAuthRequiredModal({ await startOAuthConnectFlow({ providerId, - callbackURL: window.location.href, + callbackURL: `${pathname}${window.location.search}${window.location.hash}`, }) } catch (error) { logger.error('Error initiating OAuth flow:', { error }) diff --git a/apps/tradinggoose/components/ui/sidebar.tsx b/apps/tradinggoose/components/ui/sidebar.tsx index db2c88fdc..ba48968a4 100644 --- a/apps/tradinggoose/components/ui/sidebar.tsx +++ b/apps/tradinggoose/components/ui/sidebar.tsx @@ -367,7 +367,7 @@ const SidebarRail = React.forwardRef< if (typeof ref === 'function') { ref(node) } else if (ref) { - ; (ref as React.MutableRefObject).current = node + ;(ref as React.MutableRefObject).current = node } dragRef.current = node }, @@ -589,7 +589,7 @@ const SidebarMenuItem = React.forwardRefspan:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0', + 'peer/menu-button flex w-full items-center gap-2 overflow-hidden rounded-md p-2 text-left text-sm outline-none ring-sidebar-ring transition-[width,height,padding] hover:bg-sidebar-accent hover:text-sidebar-accent-foreground focus-visible:ring-2 active:bg-sidebar-accent active:text-sidebar-accent-foreground disabled:pointer-events-none disabled:opacity-50 group-has-[[data-sidebar=menu-action]]/menu-item:pr-8 aria-disabled:pointer-events-none aria-disabled:opacity-50 data-[active=true]:bg-sidebar-accent data-[active=true]:font-medium data-[active=true]:text-sidebar-accent-foreground data-[state=open]:hover:bg-sidebar-accent data-[state=open]:hover:text-sidebar-accent-foreground group-data-[collapsible=icon]:!size-8 group-data-[collapsible=icon]:!p-2 [&>*:first-child]:shrink-0 [&>span:last-child]:truncate [&>svg]:size-4 [&>svg]:shrink-0', { variants: { variant: { @@ -692,7 +692,7 @@ const SidebarMenuAction = React.forwardRef< 'peer-data-[size=lg]/menu-button:top-2.5', 'group-data-[collapsible=icon]:hidden', showOnHover && - 'group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-[state=open]:opacity-100 peer-data-[active=true]/menu-button:text-sidebar-accent-foreground md:opacity-0', + 'group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-[state=open]:opacity-100 peer-data-[active=true]/menu-button:text-sidebar-accent-foreground md:opacity-0', className )} {...props} diff --git a/apps/tradinggoose/contexts/socket-context.tsx b/apps/tradinggoose/contexts/socket-context.tsx index 07735a8d2..ecd5ea07a 100644 --- a/apps/tradinggoose/contexts/socket-context.tsx +++ b/apps/tradinggoose/contexts/socket-context.tsx @@ -2,9 +2,9 @@ import { createContext, type ReactNode, useContext, useEffect, useRef, useState } from 'react' import { io, type Socket } from 'socket.io-client' -import { handleAuthError } from '@/lib/auth/auth-error-handler' import { getEnv } from '@/lib/env' import { createLogger } from '@/lib/logs/console/logger' +import { usePathname } from '@/i18n/navigation' const logger = createLogger('SocketContext') const isSocketAuthError = (message: string) => @@ -17,11 +17,11 @@ const logSocketIssue = ( details: { message: string type?: string - } + }, + callbackPathname: string ) => { if (isSocketAuthError(details.message)) { - logger.warn(event, details) - void handleAuthError('socket-auth') + logger.warn(event, { ...details, callbackPathname }) } else { logger.error(event, details) } @@ -133,9 +133,12 @@ const getGlobalSocketRegistry = (): Map => { } export function SocketProvider({ children, user }: SocketProviderProps) { + const pathname = usePathname() const [socket, setSocket] = useState(null) const [isConnected, setIsConnected] = useState(false) const [isConnecting, setIsConnecting] = useState(false) + const callbackPathnameRef = useRef(pathname) + callbackPathnameRef.current = pathname // Track socket in a ref so the cleanup closure always sees the latest value, // avoiding the race where `socket` state is still null during fast unmount. @@ -207,10 +210,14 @@ export function SocketProvider({ children, user }: SocketProviderProps) { const onConnectError = (error: any) => { setIsConnected(false) setIsConnecting(false) - logSocketIssue('Socket connection error:', { - message: error instanceof Error ? error.message : String(error), - type: error?.type, - }) + logSocketIssue( + 'Socket connection error:', + { + message: error instanceof Error ? error.message : String(error), + type: error?.type, + }, + callbackPathnameRef.current + ) } socketInstance.on('connect', onConnect) @@ -247,9 +254,13 @@ export function SocketProvider({ children, user }: SocketProviderProps) { setupSocketCleanup = setupSocket(socket) }) .catch((err) => { - logSocketIssue('Shared socket initialization failed', { - message: err instanceof Error ? err.message : String(err), - }) + logSocketIssue( + 'Shared socket initialization failed', + { + message: err instanceof Error ? err.message : String(err), + }, + callbackPathnameRef.current + ) if (!disposed) setIsConnecting(false) registry.delete(user.id) // Allow retry }) @@ -284,9 +295,13 @@ export function SocketProvider({ children, user }: SocketProviderProps) { const freshToken = await generateSocketToken() cb({ token: freshToken }) } catch (error) { - logSocketIssue('Failed to generate fresh token for connection:', { - message: error instanceof Error ? error.message : String(error), - }) + logSocketIssue( + 'Failed to generate fresh token for connection:', + { + message: error instanceof Error ? error.message : String(error), + }, + callbackPathnameRef.current + ) cb({ token: null }) } }, @@ -315,9 +330,13 @@ export function SocketProvider({ children, user }: SocketProviderProps) { setupSocketCleanup = setupSocket(socket) }) .catch((err) => { - logSocketIssue('Failed to initialize socket:', { - message: err instanceof Error ? err.message : String(err), - }) + logSocketIssue( + 'Failed to initialize socket:', + { + message: err instanceof Error ? err.message : String(err), + }, + callbackPathnameRef.current + ) if (!disposed) setIsConnecting(false) registry.delete(user.id) }) diff --git a/apps/tradinggoose/executor/__test-utils__/test-executor.ts b/apps/tradinggoose/executor/__test-utils__/test-executor.ts index 281865b29..caded0eb6 100644 --- a/apps/tradinggoose/executor/__test-utils__/test-executor.ts +++ b/apps/tradinggoose/executor/__test-utils__/test-executor.ts @@ -16,11 +16,11 @@ export class TestExecutor extends Executor { /** * Override the execute method to return a pre-defined result for testing */ - async execute(workflowId: string): Promise { + async execute(workflowId: string, triggerBlockId: string): Promise { try { // Call validateWorkflow to ensure we validate the workflow // even though we're not actually executing it - ;(this as any).validateWorkflow() + ;(this as any).validateWorkflow(triggerBlockId) // Return a successful result return { diff --git a/apps/tradinggoose/executor/index.test.ts b/apps/tradinggoose/executor/index.test.ts index c0ee1940c..933ea25f4 100644 --- a/apps/tradinggoose/executor/index.test.ts +++ b/apps/tradinggoose/executor/index.test.ts @@ -141,7 +141,7 @@ describe('Executor', () => { const validateSpy = vi.spyOn(executor as any, 'validateWorkflow') validateSpy.mockClear() - await executor.execute('test-workflow-id') + await executor.execute('test-workflow-id', 'trigger') expect(validateSpy).toHaveBeenCalledTimes(1) }) @@ -283,7 +283,7 @@ describe('Executor', () => { const workflow = createMinimalWorkflow() const executor = createTestExecutor(workflow) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result).toHaveProperty('success') expect(result).toHaveProperty('output') @@ -302,7 +302,7 @@ describe('Executor', () => { }, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result).toHaveProperty('success') expect(result).toHaveProperty('output') @@ -326,7 +326,7 @@ describe('Executor', () => { // Spy on createExecutionContext to verify context extensions are passed const createContextSpy = vi.spyOn(executor as any, 'createExecutionContext') - await executor.execute('test-workflow-id') + await executor.execute('test-workflow-id', 'trigger') expect(createContextSpy).toHaveBeenCalled() const contextArg = createContextSpy.mock.calls[0][2] // third argument is startTime, context is created internally @@ -341,7 +341,7 @@ describe('Executor', () => { const workflow = createWorkflowWithCondition() const executor = createTestExecutor(workflow) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Verify execution completes and returns expected structure if ('success' in result) { @@ -357,7 +357,7 @@ describe('Executor', () => { const workflow = createWorkflowWithLoop() const executor = createTestExecutor(workflow) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result).toHaveProperty('success') expect(result).toHaveProperty('output') @@ -555,7 +555,7 @@ describe('Executor', () => { }, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result).toHaveProperty('success') expect(result).toHaveProperty('output') @@ -573,7 +573,7 @@ describe('Executor', () => { const createContextSpy = vi.spyOn(executor as any, 'createExecutionContext') - await executor.execute('test-workflow-id') + await executor.execute('test-workflow-id', 'trigger') expect(createContextSpy).toHaveBeenCalled() }) @@ -610,7 +610,7 @@ describe('Executor', () => { }, ] - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result.success).toBe(false) expect(result.error).toContain('Provider stream failed') @@ -913,7 +913,7 @@ describe('Executor', () => { executor.cancel() // Try to execute - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Should immediately return cancelled result if ('success' in result) { @@ -931,7 +931,7 @@ describe('Executor', () => { ;(executor as any).isCancelled = true - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Should return cancelled result if ('success' in result) { @@ -1019,7 +1019,7 @@ describe('Executor', () => { updateExecutionPaths: vi.fn(), } - const result = await executor.execute('test-workflow') + const result = await executor.execute('test-workflow', 'trigger') // Should succeed with partial results - not throw an error expect(result).toBeDefined() @@ -1143,7 +1143,7 @@ describe('Executor', () => { workflowInput: {}, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Verify execution completed (may succeed or fail depending on child workflow availability) expect(result).toBeDefined() @@ -1192,7 +1192,7 @@ describe('Executor', () => { }) // Verify that child executor is created with isChildExecution flag - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') expect(result).toBeDefined() if ('success' in result) { @@ -1274,7 +1274,7 @@ describe('Executor', () => { workflowInput: {}, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Verify execution completed (may succeed or fail depending on child workflow availability) expect(result).toBeDefined() @@ -1342,7 +1342,7 @@ describe('Executor', () => { workflowInput: {}, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Verify execution completed (may succeed or fail depending on child workflow availability) expect(result).toBeDefined() @@ -1396,7 +1396,7 @@ describe('Executor', () => { workflowInput: {}, }) - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Verify that child workflow errors propagate to parent expect(result).toBeDefined() diff --git a/apps/tradinggoose/executor/index.ts b/apps/tradinggoose/executor/index.ts index 3265c396b..2b10dcf92 100644 --- a/apps/tradinggoose/executor/index.ts +++ b/apps/tradinggoose/executor/index.ts @@ -260,10 +260,10 @@ export class Executor { * Executes the workflow and returns the result. * * @param workflowId - Unique identifier for the workflow execution - * @param startBlockId - Optional block ID to start execution from (for webhook or schedule triggers) + * @param triggerBlockId - Trigger block ID to execute from * @returns Execution result containing output, logs, and metadata */ - async execute(workflowId: string, startBlockId?: string): Promise { + async execute(workflowId: string, triggerBlockId: string): Promise { const startTime = new Date() let finalOutput: NormalizedBlockOutput = {} @@ -275,9 +275,9 @@ export class Executor { startTime: startTime.toISOString(), }) - this.validateWorkflow(startBlockId) + this.validateWorkflow(triggerBlockId) - const context = this.createExecutionContext(workflowId, startTime, startBlockId) + const context = this.createExecutionContext(workflowId, startTime, triggerBlockId) try { let hasMoreLayers = true @@ -517,16 +517,15 @@ export class Executor { * Validates that the workflow meets requirements for execution. * Ensures trigger blocks exist along with valid connections and loop configurations. * - * @param startBlockId - Optional specific block to start from + * @param triggerBlockId - Trigger block to execute from * @throws Error if workflow validation fails */ - private validateWorkflow(startBlockId?: string): void { - if (startBlockId) { - const startBlock = this.actualWorkflow.blocks.find((block) => block.id === startBlockId) - if (!startBlock || !startBlock.enabled) { - throw new Error(`Start block ${startBlockId} not found or disabled`) + private validateWorkflow(triggerBlockId?: string): void { + if (triggerBlockId !== undefined) { + const triggerBlock = this.actualWorkflow.blocks.find((block) => block.id === triggerBlockId) + if (!triggerBlock || !triggerBlock.enabled) { + throw new Error(`Trigger block ${triggerBlockId} not found or disabled`) } - return } // Check for any type of trigger block (dedicated triggers or trigger-mode blocks) @@ -584,13 +583,13 @@ export class Executor { * * @param workflowId - Unique identifier for the workflow execution * @param startTime - Execution start time - * @param startBlockId - Optional specific block to start from + * @param triggerBlockId - Trigger block to execute from * @returns Initialized execution context */ private createExecutionContext( workflowId: string, startTime: Date, - startBlockId?: string + triggerBlockId: string ): ExecutionContext { const workspaceId = this.requireExecutionWorkspaceId() const context: ExecutionContext = { @@ -601,6 +600,7 @@ export class Executor { workflowLogId: this.contextExtensions.workflowLogId, submissionSource: this.contextExtensions.submissionSource, triggerType: this.contextExtensions.triggerType, + triggerBlockId: undefined, workflowDepth: this.contextExtensions.workflowDepth ?? 0, isDeployedContext: this.contextExtensions.isDeployedContext || false, blockStates: new Map(), @@ -645,31 +645,11 @@ export class Executor { } } - // Determine which block to initialize as the starting point - let initBlock: SerializedBlock | undefined - if (startBlockId) { - initBlock = this.actualWorkflow.blocks.find((block) => block.id === startBlockId) - } else if (this.isChildExecution) { - const inputTriggerBlocks = this.actualWorkflow.blocks.filter( - (block) => block.metadata?.id === 'input_trigger' - ) - if (inputTriggerBlocks.length === 1) { - initBlock = inputTriggerBlocks[0] - } else if (inputTriggerBlocks.length > 1) { - throw new Error('Child workflow has multiple Input Trigger blocks. Keep only one.') - } - } else { - const triggerBlocks = this.actualWorkflow.blocks.filter((block) => - isSerializedTriggerBlock(block) - ) - if (triggerBlocks.length > 0) { - initBlock = triggerBlocks[0] - } - } - + const initBlock = this.actualWorkflow.blocks.find((block) => block.id === triggerBlockId) if (!initBlock) { - throw new Error('Unable to determine a trigger block to initialize') + throw new Error(`Trigger block ${triggerBlockId} not found or disabled`) } + context.triggerBlockId = initBlock.id // Remove any pre-populated state for the init block so we can inject runtime trigger input. if (context.blockStates.has(initBlock.id)) { diff --git a/apps/tradinggoose/executor/resolver/resolver.test.ts b/apps/tradinggoose/executor/resolver/resolver.test.ts index 5f78ff543..18da3034f 100644 --- a/apps/tradinggoose/executor/resolver/resolver.test.ts +++ b/apps/tradinggoose/executor/resolver/resolver.test.ts @@ -87,6 +87,7 @@ describe('InputResolver', () => { mockContext = { workflowId: 'test-workflow', workflow: sampleWorkflow, + triggerBlockId: 'trigger-block', blockStates: new Map([ [ 'trigger-block', @@ -341,7 +342,7 @@ describe('InputResolver', () => { expect(result.nameRef).toBe('Hello World') // Should resolve using block name }) - it('should handle the special "start" alias for trigger block', () => { + it('should resolve the runtime trigger block through ', () => { const block: SerializedBlock = { id: 'test-block', metadata: { id: 'generic', name: 'Test Block' }, @@ -1338,6 +1339,7 @@ describe('InputResolver', () => { contextWithConnections = { workflowId: 'test-workflow', workspaceId: 'test-workspace-id', + triggerBlockId: 'trigger-1', blockStates: new Map([ ['trigger-1', { output: { input: 'Hello World' }, executed: true, executionTime: 0 }], ['agent-1', { output: { content: 'Agent response' }, executed: true, executionTime: 0 }], @@ -1446,6 +1448,43 @@ describe('InputResolver', () => { expect(result.code).toBe('return "Hello World"') // Should be quoted for function blocks }) + it('resolves start references from the runtime trigger block', () => { + const workflow: SerializedWorkflow = { + ...workflowWithConnections, + blocks: [ + ...workflowWithConnections.blocks, + { + id: 'schedule-trigger', + metadata: { id: 'schedule', name: 'Schedule', category: 'triggers' }, + position: { x: 0, y: 0 }, + config: { tool: 'schedule', params: {} }, + inputs: {}, + outputs: {}, + enabled: true, + }, + ], + } + const resolver = createInputResolver(workflow) + const context = { + ...contextWithConnections, + workflow, + triggerBlockId: 'schedule-trigger', + blockStates: new Map([ + ...contextWithConnections.blockStates, + ['trigger-1', { output: { symbol: 'WRONG' }, executed: true, executionTime: 0 }], + ['schedule-trigger', { output: { symbol: 'AAPL' }, executed: true, executionTime: 0 }], + ]), + } + + const result = resolver.resolveBlockReferences( + 'return ', + context, + workflow.blocks.find((block) => block.id === 'function-1')! + ) + + expect(result).toBe('return AAPL') + }) + it('should format start.input properly for different block types', () => { // Test function block - should quote strings const functionBlock: SerializedBlock = { @@ -2611,7 +2650,7 @@ describe('InputResolver', () => { expect(result.deep4).toBe('12') }) - it.concurrent('should handle start block with 2D array access', () => { + it.concurrent('should handle trigger input with 2D array access', () => { arrayContext.blockStates.set('trigger-block', { output: { input: 'Hello World', @@ -3135,6 +3174,7 @@ describe('InputResolver', () => { workflowId: 'test-parallel-workflow', workspaceId: 'test-workspace-id', workflow: parallelWorkflow, + triggerBlockId: 'start-block', blockStates: new Map([ [ 'function1-block', diff --git a/apps/tradinggoose/executor/resolver/resolver.ts b/apps/tradinggoose/executor/resolver/resolver.ts index 85d0144ab..82a3786f1 100644 --- a/apps/tradinggoose/executor/resolver/resolver.ts +++ b/apps/tradinggoose/executor/resolver/resolver.ts @@ -1,8 +1,7 @@ import { createLogger } from '@/lib/logs/console/logger' import { VariableManager } from '@/lib/variables/variable-manager' -import { evaluateSubBlockConditionValues } from '@/lib/workflows/sub-block-conditions' import { extractReferencePrefixes, SYSTEM_REFERENCE_PREFIXES } from '@/lib/workflows/references' -import { TRIGGER_REFERENCE_ALIAS_MAP } from '@/lib/workflows/triggers' +import { evaluateSubBlockConditionValues } from '@/lib/workflows/sub-block-conditions' import { getBlock } from '@/blocks/index' import type { LoopManager } from '@/executor/loops/loops' import type { ExecutionContext } from '@/executor/types' @@ -40,11 +39,6 @@ export class InputResolver { ]) ) - const startAliasBlock = this.findStartAliasBlock() - if (startAliasBlock) { - this.blockByNormalizedName.set('start', startAliasBlock) - } - // Create efficient loop lookup map this.loopsByBlockId = new Map() for (const [loopId, loop] of Object.entries(workflow.loops || {})) { @@ -62,18 +56,6 @@ export class InputResolver { } } - private findStartAliasBlock(): SerializedBlock | undefined { - const preferredTypes = ['input_trigger', 'api_trigger', 'manual_trigger'] - for (const type of preferredTypes) { - const candidate = this.workflow.blocks.find((block) => block.metadata?.id === type) - if (candidate) { - return candidate - } - } - - return this.workflow.blocks.find((block) => block.metadata?.category === 'triggers') - } - /** * Filters inputs based on sub-block conditions * @param block - Block to filter inputs for @@ -409,149 +391,140 @@ export class InputResolver { // System references (start, loop, parallel, variable) and regular block references are both processed // Accessibility validation happens later in validateBlockReference - // Special case for trigger block references (start, api, chat, manual) + // Special case for the runtime trigger reference. const blockRefLower = blockRef.toLowerCase() - const triggerType = - TRIGGER_REFERENCE_ALIAS_MAP[blockRefLower as keyof typeof TRIGGER_REFERENCE_ALIAS_MAP] - if (triggerType) { - const triggerBlock = this.workflow.blocks.find( - (block) => block.metadata?.id === triggerType - ) - if (triggerBlock) { - const blockState = context.blockStates.get(triggerBlock.id) - if (blockState) { - // For trigger blocks, start directly with the flattened output - // This enables direct access to , , etc - let replacementValue: any = blockState.output - - for (const part of pathParts) { - if (!replacementValue || typeof replacementValue !== 'object') { - logger.warn( - `[resolveBlockReferences] Invalid path "${part}" - replacementValue is not an object:`, - replacementValue + if (blockRefLower === 'start') { + const triggerBlock = context.triggerBlockId + ? this.blockById.get(context.triggerBlockId) + : undefined + if (!triggerBlock) { + throw new Error('Runtime trigger block is not available for reference.') + } + const blockState = context.blockStates.get(triggerBlock.id) + if (blockState) { + // Runtime trigger outputs are exposed through . + let replacementValue: any = blockState.output + + for (const part of pathParts) { + if (!replacementValue || typeof replacementValue !== 'object') { + logger.warn( + `[resolveBlockReferences] Invalid path "${part}" - replacementValue is not an object:`, + replacementValue + ) + throw new Error(`Invalid path "${part}" in "${path}" for trigger block.`) + } + + // Handle array indexing syntax like "files[0]" or "items[1]" + const arrayMatch = part.match(/^([^[]+)\[(\d+)\]$/) + if (arrayMatch) { + const [, arrayName, indexStr] = arrayMatch + const index = Number.parseInt(indexStr, 10) + + // First access the array property + const arrayValue = replacementValue[arrayName] + if (!Array.isArray(arrayValue)) { + throw new Error( + `Property "${arrayName}" is not an array in path "${path}" for trigger block.` ) - throw new Error(`Invalid path "${part}" in "${path}" for trigger block.`) } - // Handle array indexing syntax like "files[0]" or "items[1]" - const arrayMatch = part.match(/^([^[]+)\[(\d+)\]$/) - if (arrayMatch) { - const [, arrayName, indexStr] = arrayMatch - const index = Number.parseInt(indexStr, 10) - - // First access the array property - const arrayValue = replacementValue[arrayName] - if (!Array.isArray(arrayValue)) { - throw new Error( - `Property "${arrayName}" is not an array in path "${path}" for trigger block.` - ) - } - - // Then access the array element - if (index < 0 || index >= arrayValue.length) { - throw new Error( - `Array index ${index} is out of bounds for "${arrayName}" (length: ${arrayValue.length}) in path "${path}" for trigger block.` - ) - } - - replacementValue = arrayValue[index] - } else if (/^(?:[^[]+(?:\[\d+\])+|(?:\[\d+\])+)$/.test(part)) { - // Enhanced: support multiple indices like "values[0][0]" - replacementValue = this.resolvePartWithIndices( - replacementValue, - part, - path, - 'trigger block' + // Then access the array element + if (index < 0 || index >= arrayValue.length) { + throw new Error( + `Array index ${index} is out of bounds for "${arrayName}" (length: ${arrayValue.length}) in path "${path}" for trigger block.` ) - } else { - if (Array.isArray(replacementValue)) { - throw new Error( - `Array path "${path}" in trigger block must use an explicit index.` - ) - } - replacementValue = resolvePropertyAccess(replacementValue, part) } - if (replacementValue === undefined) { - logger.warn( - `[resolveBlockReferences] No value found at path "${part}" in trigger block.` - ) - throw new Error(`No value found at path "${path}" in trigger block.`) + replacementValue = arrayValue[index] + } else if (/^(?:[^[]+(?:\[\d+\])+|(?:\[\d+\])+)$/.test(part)) { + // Enhanced: support multiple indices like "values[0][0]" + replacementValue = this.resolvePartWithIndices( + replacementValue, + part, + path, + 'trigger block' + ) + } else { + if (Array.isArray(replacementValue)) { + throw new Error(`Array path "${path}" in trigger block must use an explicit index.`) } + replacementValue = resolvePropertyAccess(replacementValue, part) } - // Format the value based on block type and path - let formattedValue: string - - // Special handling for all blocks referencing trigger input - // For start and chat triggers, check for 'input' field. For API trigger, any field access counts - const isTriggerInputRef = - (blockRefLower === 'start' && pathParts.join('.').includes('input')) || - (blockRefLower === 'chat' && pathParts.join('.').includes('input')) || - (blockRefLower === 'api' && pathParts.length > 0) - if (isTriggerInputRef) { - const blockType = currentBlock.metadata?.id - - // Format based on which block is consuming this value - if (typeof replacementValue === 'object' && replacementValue !== null) { - // For function blocks, preserve the object structure for code usage - if (blockType === 'function') { - formattedValue = JSON.stringify(replacementValue) - } - // For API blocks, handle body special case - else if (blockType === 'api') { - formattedValue = JSON.stringify(replacementValue) - } - // For condition blocks, ensure proper formatting - else if (blockType === 'condition') { - formattedValue = this.stringifyForCondition(replacementValue) - } - // For response blocks, preserve object structure as-is for proper JSON response - else if (blockType === 'response') { - formattedValue = replacementValue - } - // For all other blocks, stringify objects - else { - // Preserve full JSON structure for objects - formattedValue = JSON.stringify(replacementValue) - } - } else { - // For primitive values, format based on target block type - if (blockType === 'function') { - formattedValue = this.formatValueForCodeContext( - replacementValue, - currentBlock, - isInTemplateLiteral - ) - } else if (blockType === 'condition') { - formattedValue = this.stringifyForCondition(replacementValue) - } else { - formattedValue = String(replacementValue) - } + if (replacementValue === undefined) { + logger.warn( + `[resolveBlockReferences] No value found at path "${part}" in trigger block.` + ) + throw new Error(`No value found at path "${path}" in trigger block.`) + } + } + + // Format the value based on block type and path + let formattedValue: string + + const isTriggerInputRef = pathParts.join('.').includes('input') + if (isTriggerInputRef) { + const blockType = currentBlock.metadata?.id + + // Format based on which block is consuming this value + if (typeof replacementValue === 'object' && replacementValue !== null) { + // For function blocks, preserve the object structure for code usage + if (blockType === 'function') { + formattedValue = JSON.stringify(replacementValue) + } + // For API blocks, handle body special case + else if (blockType === 'api') { + formattedValue = JSON.stringify(replacementValue) + } + // For condition blocks, ensure proper formatting + else if (blockType === 'condition') { + formattedValue = this.stringifyForCondition(replacementValue) + } + // For response blocks, preserve object structure as-is for proper JSON response + else if (blockType === 'response') { + formattedValue = replacementValue + } + // For all other blocks, stringify objects + else { + // Preserve full JSON structure for objects + formattedValue = JSON.stringify(replacementValue) } } else { - // Standard handling for non-input references - const blockType = currentBlock.metadata?.id - - if (blockType === 'response') { - // For response blocks, properly quote string values for JSON context - if (typeof replacementValue === 'string') { - // Properly escape and quote the string for JSON - formattedValue = JSON.stringify(replacementValue) - } else { - formattedValue = replacementValue - } + // For primitive values, format based on target block type + if (blockType === 'function') { + formattedValue = this.formatValueForCodeContext( + replacementValue, + currentBlock, + isInTemplateLiteral + ) + } else if (blockType === 'condition') { + formattedValue = this.stringifyForCondition(replacementValue) } else { - formattedValue = - typeof replacementValue === 'object' - ? JSON.stringify(replacementValue) - : String(replacementValue) + formattedValue = String(replacementValue) } } - - resolvedValue = resolvedValue.replace(raw, formattedValue) - continue + } else { + // Standard handling for non-input references + const blockType = currentBlock.metadata?.id + + if (blockType === 'response') { + // For response blocks, properly quote string values for JSON context + if (typeof replacementValue === 'string') { + // Properly escape and quote the string for JSON + formattedValue = JSON.stringify(replacementValue) + } else { + formattedValue = replacementValue + } + } else { + formattedValue = + typeof replacementValue === 'object' + ? JSON.stringify(replacementValue) + : String(replacementValue) + } } + + resolvedValue = resolvedValue.replace(raw, formattedValue) + continue } } @@ -1013,10 +986,7 @@ export class InputResolver { const accessibleBlocks = this.getAccessibleBlocks(currentBlockId) const isAlwaysAccessibleTrigger = sourceBlock.metadata?.category === 'triggers' || - sourceBlock.metadata?.id === 'input_trigger' || - sourceBlock.metadata?.id === 'api_trigger' || - sourceBlock.metadata?.id === 'manual_trigger' || - sourceBlock.metadata?.id === 'chat_trigger' + sourceBlock.config.params.triggerMode === true if ( sourceBlock.id !== currentBlockId && diff --git a/apps/tradinggoose/executor/tests/executor-layer-validation.test.ts b/apps/tradinggoose/executor/tests/executor-layer-validation.test.ts index 839fda7e0..e4af0f305 100644 --- a/apps/tradinggoose/executor/tests/executor-layer-validation.test.ts +++ b/apps/tradinggoose/executor/tests/executor-layer-validation.test.ts @@ -167,7 +167,7 @@ describe('Full Executor Test', () => { try { // Execute the workflow - const result = await executor.execute('test-workflow-id') + const result = await executor.execute('test-workflow-id', 'trigger') // Check if it's an ExecutionResult (not StreamingExecution) if ('success' in result) { @@ -186,7 +186,11 @@ describe('Full Executor Test', () => { it('should test the executor getNextExecutionLayer method directly', async () => { // Create a mock context in the exact state after the condition executes - const context = (executor as any).createExecutionContext('test-workflow', new Date()) + const context = (executor as any).createExecutionContext( + 'test-workflow', + new Date(), + 'bd9f4f7d-8aed-4860-a3be-8bebd1931b19' + ) // Set up the state as it would be after the condition executes context.executedBlocks.add('bd9f4f7d-8aed-4860-a3be-8bebd1931b19') // Start diff --git a/apps/tradinggoose/executor/tests/multi-input-routing.test.ts b/apps/tradinggoose/executor/tests/multi-input-routing.test.ts index 8d3c7663b..818a72789 100644 --- a/apps/tradinggoose/executor/tests/multi-input-routing.test.ts +++ b/apps/tradinggoose/executor/tests/multi-input-routing.test.ts @@ -101,9 +101,9 @@ describe('Multi-Input Routing Scenarios', () => { it('should handle multi-input target when router selects function-1', async () => { // Test scenario: Router selects function-1, agent should still execute with function-1's output - const context = (executor as any).createExecutionContext('test-workflow', new Date()) + const context = (executor as any).createExecutionContext('test-workflow', new Date(), 'start') - // Step 1: Execute start block + // Step 1: Execute trigger block context.executedBlocks.add('start') context.activeExecutionPath.add('start') context.activeExecutionPath.add('router-1') @@ -166,7 +166,7 @@ describe('Multi-Input Routing Scenarios', () => { it('should handle multi-input target when router selects function-2', async () => { // Test scenario: Router selects function-2, agent should still execute with function-2's output - const context = (executor as any).createExecutionContext('test-workflow', new Date()) + const context = (executor as any).createExecutionContext('test-workflow', new Date(), 'start') // Step 1: Execute start and router-1 selecting function-2 context.executedBlocks.add('start') @@ -223,7 +223,7 @@ describe('Multi-Input Routing Scenarios', () => { it('should verify the dependency logic for inactive sources', async () => { // This test specifically validates the multi-input dependency logic - const context = (executor as any).createExecutionContext('test-workflow', new Date()) + const context = (executor as any).createExecutionContext('test-workflow', new Date(), 'start') // Setup: Router executed and selected function-1, function-1 executed context.executedBlocks.add('start') diff --git a/apps/tradinggoose/executor/types.ts b/apps/tradinggoose/executor/types.ts index 4ce4e86b7..eab9cd07f 100644 --- a/apps/tradinggoose/executor/types.ts +++ b/apps/tradinggoose/executor/types.ts @@ -115,6 +115,7 @@ export interface ExecutionContext { workflowLogId?: string submissionSource?: ExecutionSubmissionSource triggerType?: TriggerType + triggerBlockId?: string workflowDepth?: number // Whether this execution is running against deployed state (API/webhook/schedule/chat) // Manual executions in the builder should leave this undefined/false diff --git a/apps/tradinggoose/global-navbar/components/user-menu.test.tsx b/apps/tradinggoose/global-navbar/components/user-menu.test.tsx index 9922f4d26..28fbf3c37 100644 --- a/apps/tradinggoose/global-navbar/components/user-menu.test.tsx +++ b/apps/tradinggoose/global-navbar/components/user-menu.test.tsx @@ -15,6 +15,7 @@ const mockRefresh = vi.fn() const mockReplaceLocaleDocument = vi.fn() const mockSetTheme = vi.fn() const mockUpdateSetting = vi.fn() +const mockOpenSettings = vi.fn() let mockPathname = '/workspace/ws-1/dashboard' let mockSearchParams = '' @@ -91,12 +92,25 @@ vi.mock('@/global-navbar/settings-modal/components/help/help-modal', () => ({ HelpModal: () => null, })) -function renderUserMenu(root: Root, locale: LocaleCode) { +function renderUserMenu( + root: Root, + locale: LocaleCode, + options: { canAccessSystemAdmin?: boolean; sidebarTrigger?: boolean } = {} +) { + const userMenu = ( + + ) + root.render( - - - + {options.sidebarTrigger ? {userMenu} : userMenu} ) } @@ -109,7 +123,9 @@ async function openMenu(button: HTMLButtonElement) { } function getUserMenuButton(container: HTMLElement) { - const button = container.querySelector('button[data-sidebar="menu-button"]') + const button = Array.from(container.querySelectorAll('button')).find((candidate) => + candidate.getAttribute('aria-label')?.startsWith('Ada Lovelace ') + ) if (!(button instanceof HTMLButtonElement)) { throw new Error('Expected user menu trigger to render') } @@ -179,6 +195,7 @@ describe('UserMenu language selector', () => { mockReplaceLocaleDocument.mockReset() mockSetTheme.mockReset() mockUpdateSetting.mockReset() + mockOpenSettings.mockReset() mockUpdateSetting.mockResolvedValue(undefined) mockPathname = '/workspace/ws-1/dashboard' mockSearchParams = '' @@ -212,6 +229,62 @@ describe('UserMenu language selector', () => { expect(getThemeButton('主题:系统')).toBeInTheDocument() }) + it('renders the compact avatar trigger outside a sidebar context', async () => { + await act(async () => { + renderUserMenu(root, 'en') + await flush() + }) + + const button = getUserMenuButton(container) + expect(button.textContent).toBe('AL') + expect(container.querySelector('[data-sidebar="menu"]')).toBeNull() + expect(container.querySelector('button[data-sidebar="menu-button"]')).toBeNull() + }) + + it('renders the sidebar trigger with user details inside the global navbar sidebar', async () => { + await act(async () => { + renderUserMenu(root, 'en', { sidebarTrigger: true }) + await flush() + }) + + const button = getUserMenuButton(container) + expect(button.getAttribute('data-sidebar')).toBe('menu-button') + expect(button.textContent).toContain('Ada Lovelace') + expect(button.textContent).toContain('ada@example.com') + + await act(async () => { + await openMenu(button) + }) + + const menu = document.body.querySelector('[role="menu"]') + expect(menu?.className).toContain('w-[var(--radix-dropdown-menu-trigger-width)]') + }) + + it('owns the system admin menu item for authorized users', async () => { + await act(async () => { + renderUserMenu(root, 'en', { canAccessSystemAdmin: true }) + await flush() + }) + + await act(async () => { + await openMenu(getUserMenuButton(container)) + }) + + const systemAdminItem = Array.from(document.body.querySelectorAll('[role="menuitem"]')).find( + (item) => item.textContent?.includes(getPublicCopy('en').workspace.nav.systemAdmin) + ) + if (!(systemAdminItem instanceof HTMLElement)) { + throw new Error('Expected system admin menu item to render') + } + + await act(async () => { + systemAdminItem.dispatchEvent(new MouseEvent('click', { bubbles: true })) + await flush() + }) + + expect(mockPush).toHaveBeenCalledWith('/admin') + }) + it('switches to zh without dropping the workspace path or query string', async () => { mockSearchParams = 'layout=main' diff --git a/apps/tradinggoose/global-navbar/components/user-menu.tsx b/apps/tradinggoose/global-navbar/components/user-menu.tsx index 245dba5f2..a6b3f8f73 100644 --- a/apps/tradinggoose/global-navbar/components/user-menu.tsx +++ b/apps/tradinggoose/global-navbar/components/user-menu.tsx @@ -20,8 +20,16 @@ import { Users, } from 'lucide-react' import { useSearchParams } from 'next/navigation' -import { useLocale, useMessages, useTranslations } from 'next-intl' +import { useLocale, useTranslations } from 'next-intl' import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuGroup, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu' import { SidebarMenu, SidebarMenuButton, SidebarMenuItem } from '@/components/ui/sidebar' import { widgetHeaderControlClassName, @@ -40,20 +48,11 @@ import { HelpModal } from '@/global-navbar/settings-modal/components/help/help-m import type { SettingsSection } from '@/global-navbar/settings-modal/types' import { useOrganizationBilling, useOrganizations } from '@/hooks/queries/organization' import { useSubscriptionData } from '@/hooks/queries/subscription' -import { formatTemplate } from '@/i18n/utils' import { replaceLocaleDocument, usePathname, useRouter } from '@/i18n/navigation' import { getLocaleDisplayName, isLocaleCode, type LocaleCode, locales } from '@/i18n/utils' import { clearUserData } from '@/stores' import { useGeneralStore } from '@/stores/settings/general/store' import { getInitials } from '../utils' -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuGroup, - DropdownMenuItem, - DropdownMenuSeparator, - DropdownMenuTrigger, -} from './resizable-dropdown' type ThemeOption = { value: 'light' | 'system' | 'dark' @@ -74,19 +73,9 @@ interface UserMenuProps { userAvatar?: string | null userAvatarVersion?: number | string | null userId?: string | null - onOpenSettings?: (section: SettingsSection) => void - systemNavigation?: { - href: string - label: string - } | null -} - -type UserMenuMessages = { - workspace?: { - userMenu?: { - themeLabel?: string - } - } + onOpenSettings: (section: SettingsSection) => void + canAccessSystemAdmin?: boolean + sidebarTrigger?: boolean } export function UserMenu({ @@ -96,7 +85,8 @@ export function UserMenu({ userAvatarVersion, userId, onOpenSettings, - systemNavigation, + canAccessSystemAdmin = false, + sidebarTrigger = false, }: UserMenuProps) { const router = useRouter() const locale = useLocale() as LocaleCode @@ -104,9 +94,10 @@ export function UserMenu({ const searchParams = useSearchParams() const search = searchParams.toString() const tUserMenu = useTranslations('workspace.userMenu') - const messages = useMessages() as UserMenuMessages + const tWorkspaceNav = useTranslations('workspace.nav') const [isSigningOut, setIsSigningOut] = useState(false) const [isOpeningBillingPortal, setIsOpeningBillingPortal] = useState(false) + const [nameOverride, setNameOverride] = useState(null) const [avatarOverride, setAvatarOverride] = useState<{ url: string | null version: number | string | null @@ -146,8 +137,7 @@ export function UserMenu({ const currentThemeOption = THEME_OPTIONS.find((option) => option.value === theme) ?? THEME_OPTIONS[0] const currentThemeLabel = themeOptionLabels[currentThemeOption.value] - const themeLabelTemplate = messages.workspace?.userMenu?.themeLabel ?? 'Theme: {theme}' - const currentThemeAriaLabel = formatTemplate(themeLabelTemplate, { theme: currentThemeLabel }) + const currentThemeAriaLabel = tUserMenu('themeLabel', { theme: currentThemeLabel }) const [isHelpModalOpen, setIsHelpModalOpen] = useState(false) const activeOrganization = organizationsData?.activeOrganization const activeOrganizationId = activeOrganization?.id @@ -179,6 +169,39 @@ export function UserMenu({ const canOpenTeamSettings = organizationAccess.canOpenTeamSettings const canManageSSOSettings = organizationAccess.canConfigureSso + useEffect(() => { + if (!userId || typeof window === 'undefined') { + setNameOverride(null) + return + } + + const key = `user-name-${userId}` + + const readStoredName = () => { + const storedName = window.localStorage.getItem(key) + setNameOverride(storedName !== null ? storedName || null : null) + } + + const handleStorage = (event: StorageEvent) => { + if (event.key === key) { + readStoredName() + } + } + + const handleNameEvent = (event: Event) => { + const detail = (event as CustomEvent<{ name?: string | null }>).detail + setNameOverride(detail && 'name' in detail ? (detail.name ?? null) : null) + } + + readStoredName() + window.addEventListener('storage', handleStorage) + window.addEventListener('user-name-updated', handleNameEvent) + return () => { + window.removeEventListener('storage', handleStorage) + window.removeEventListener('user-name-updated', handleNameEvent) + } + }, [userId]) + useEffect(() => { if (!userId || typeof window === 'undefined') return @@ -228,6 +251,7 @@ export function UserMenu({ return () => window.removeEventListener('user-avatar-updated', handler) }, []) + const displayUserName = nameOverride ?? userName const effectiveAvatar = avatarOverride.url ?? userAvatar const effectiveVersion = avatarOverride.version ?? userAvatarVersion @@ -308,305 +332,290 @@ export function UserMenu({ } } - return ( - <> - - - + const avatar = ( + + {avatarSrc ? ( + + ) : ( + + )} + {getInitials(displayUserName)} + + ) + const triggerLabel = `${displayUserName} ${userMenuCopy.accountDetail}` + + const menuContent = ( + + +
+ - - - {avatarSrc ? ( - - ) : ( - - )} - {getInitials(userName)} - -
- {userName} - {userEmail} -
- -
+ + +
- -
- - - - - - {THEME_OPTIONS.map(({ value, Icon }) => { - const label = themeOptionLabels[value] - const isActive = theme === value - - return ( - { - if (isActive) { - event.preventDefault() - return - } - void handleThemeChange(value) - }} - > - - ) - })} - - - - - - - - {locales.map((code) => { - const isActive = code === locale - - return ( - { - if (isActive) { - event.preventDefault() - return - } - handleLocaleChange(code) - }} - > - {getLocaleDisplayName(code)} - {isActive ? ( - - ) : null} - - ) - })} - - -
-
- - - { - event.preventDefault() - if (onOpenSettings) { - onOpenSettings('account') - } else if (typeof window !== 'undefined') { - window.dispatchEvent( - new CustomEvent('open-settings', { detail: { tab: 'account' } }) - ) - } - }} - > - - {userMenuCopy.accountDetail} - - {isHosted ? ( + {THEME_OPTIONS.map(({ value, Icon }) => { + const label = themeOptionLabels[value] + const isActive = theme === value + + return ( { - event.preventDefault() - if (onOpenSettings) { - onOpenSettings('service') - } else if (typeof window !== 'undefined') { - window.dispatchEvent( - new CustomEvent('open-settings', { detail: { tab: 'service' } }) - ) + if (isActive) { + event.preventDefault() + return } + void handleThemeChange(value) }} > - - {userMenuCopy.serviceApiKeys} + - ) : null} - - {billingEnabled ? ( - <> - - - { - event.preventDefault() - if (onOpenSettings) { - onOpenSettings('subscription') - } else if (typeof window !== 'undefined') { - window.dispatchEvent( - new CustomEvent('open-settings', { detail: { tab: 'subscription' } }) - ) - } - }} - > - - {userMenuCopy.subscription} - - { - event.preventDefault() - void handleOpenBillingPortal() - }} - > - - {isOpeningBillingPortal - ? userMenuCopy.openingBilling - : userMenuCopy.manageBilling} - - - - ) : null} - {canOpenTeamSettings || canManageSSOSettings ? ( - <> - - - {canOpenTeamSettings ? ( - { - event.preventDefault() - if (onOpenSettings) { - onOpenSettings('team') - } else if (typeof window !== 'undefined') { - window.dispatchEvent( - new CustomEvent('open-settings', { detail: { tab: 'team' } }) - ) - } - }} - > - - {userMenuCopy.teamManagement} - - ) : null} - {canManageSSOSettings ? ( - { - event.preventDefault() - if (onOpenSettings) { - onOpenSettings('sso') - } else if (typeof window !== 'undefined') { - window.dispatchEvent( - new CustomEvent('open-settings', { detail: { tab: 'sso' } }) - ) - } - }} - > - - {userMenuCopy.singleSignOn} - - ) : null} - - - ) : null} - {systemNavigation ? ( - <> - - - { + ) + })} +
+
+ + + + + + {locales.map((code) => { + const isActive = code === locale + + return ( + { + if (isActive) { event.preventDefault() - router.push(systemNavigation.href) - }} - > - - {systemNavigation.label} - - - - ) : null} - - - { - event.preventDefault() - setIsHelpModalOpen(true) - }} - > - - {userMenuCopy.helpSupport} - - - + return + } + handleLocaleChange(code) + }} + > + {getLocaleDisplayName(code)} + {isActive ? : null} + + ) + })} + + +
+
+ + + { + event.preventDefault() + onOpenSettings('account') + }} + > + + {userMenuCopy.accountDetail} + + {isHosted ? ( + { + event.preventDefault() + onOpenSettings('service') + }} + > + + {userMenuCopy.serviceApiKeys} + + ) : null} + + {billingEnabled ? ( + <> + + + { + event.preventDefault() + onOpenSettings('subscription') + }} + > + + {userMenuCopy.subscription} + + { + event.preventDefault() + void handleOpenBillingPortal() + }} + > + + {isOpeningBillingPortal ? userMenuCopy.openingBilling : userMenuCopy.manageBilling} + + + + ) : null} + {canOpenTeamSettings || canManageSSOSettings ? ( + <> + + + {canOpenTeamSettings ? ( { event.preventDefault() - void handleSignOut() + onOpenSettings('team') }} - className='text-destructive focus:text-destructive' > - - {isSigningOut ? userMenuCopy.loggingOut : userMenuCopy.logOut} + + {userMenuCopy.teamManagement} -
-
-
-
+ ) : null} + {canManageSSOSettings ? ( + { + event.preventDefault() + onOpenSettings('sso') + }} + > + + {userMenuCopy.singleSignOn} + + ) : null} + + + ) : null} + {canAccessSystemAdmin ? ( + <> + + + { + event.preventDefault() + router.push('/admin') + }} + > + + {tWorkspaceNav('systemAdmin')} + + + + ) : null} + + + { + event.preventDefault() + setIsHelpModalOpen(true) + }} + > + + {userMenuCopy.helpSupport} + + + + { + event.preventDefault() + void handleSignOut() + }} + className='text-destructive focus:text-destructive' + > + + {isSigningOut ? userMenuCopy.loggingOut : userMenuCopy.logOut} + + + ) + + return ( + <> + + {sidebarTrigger ? ( + + + + + {avatar} +
+ {displayUserName} + {userEmail} +
+ +
+
+
+
+ ) : ( + + + + )} + {menuContent} +
) diff --git a/apps/tradinggoose/global-navbar/components/workspace-dialogs.tsx b/apps/tradinggoose/global-navbar/components/workspace-dialogs.tsx index 4065fd783..df543a542 100644 --- a/apps/tradinggoose/global-navbar/components/workspace-dialogs.tsx +++ b/apps/tradinggoose/global-navbar/components/workspace-dialogs.tsx @@ -2,8 +2,8 @@ import React, { type KeyboardEvent, useCallback, useEffect, useMemo, useRef, useState } from 'react' import { Loader2, RotateCw, X } from 'lucide-react' -import { useLocale } from 'next-intl' import { useParams } from 'next/navigation' +import { useLocale } from 'next-intl' import { AlertDialog, AlertDialogAction, @@ -18,20 +18,19 @@ import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' import { Skeleton } from '@/components/ui/skeleton' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip' -import { useSession } from '@/lib/auth-client' import { quickValidateEmail } from '@/lib/email/validation' import { createLogger } from '@/lib/logs/console/logger' import type { PermissionType } from '@/lib/permissions/utils' import { cn } from '@/lib/utils' import { - WorkspacePermissionsProvider, useUserPermissionsContext, useWorkspacePermissionsContext, + WorkspacePermissionsProvider, } from '@/app/workspace/[workspaceId]/providers/workspace-permissions-provider' -import { useOptionalWorkflowRoute } from '@/widgets/widgets/editor_workflow/context/workflow-route-context' import type { WorkspacePermissions } from '@/hooks/use-workspace-permissions' +import type { LocaleCode } from '@/i18n/utils' import { API_ENDPOINTS } from '@/stores/constants' -import { type LocaleCode } from '@/i18n/utils' +import { useOptionalWorkflowRoute } from '@/widgets/widgets/editor_workflow/context/workflow-route-context' import type { Workspace } from '../types' const logger = createLogger('WorkspaceInviteModal') @@ -39,8 +38,11 @@ const logger = createLogger('WorkspaceInviteModal') interface WorkspaceInviteModalProps { open: boolean onOpenChange: (open: boolean) => void + currentUserId: string + currentUserEmail: string | null workspaceName?: string workspaceId?: string + workspaceOwnerId?: string } interface EmailTagProps { @@ -61,6 +63,8 @@ interface UserPermissions { } interface PermissionsTableProps { + currentUserId: string + currentUserEmail: string | null userPermissions: UserPermissions[] onPermissionChange: (userId: string, permissionType: PermissionType) => void onRemoveMember?: (userId: string, email: string) => void @@ -73,6 +77,7 @@ interface PermissionsTableProps { permissionsLoading: boolean pendingInvitations: UserPermissions[] isPendingInvitationsLoading: boolean + workspaceOwnerId?: string resendingInvitationIds?: Record resentInvitationIds?: Record resendCooldowns?: Record @@ -134,9 +139,7 @@ const PermissionSelector = React.memo<{ ) return ( -
+
{permissionOptions.map((option, index) => ( - - -

{isPendingInvitation ? 'Revoke invite' : 'Remove member'}

-
- - )} + + + + + +

{isPendingInvitation ? 'Revoke invite' : 'Remove member'}

+
+
+ )}
@@ -443,8 +456,11 @@ const PermissionsTable = ({ export function WorkspaceInviteModal({ open, onOpenChange, + currentUserId, + currentUserEmail, workspaceName, workspaceId, + workspaceOwnerId, }: WorkspaceInviteModalProps) { const locale = useLocale() as LocaleCode const formRef = useRef(null) @@ -478,7 +494,6 @@ export function WorkspaceInviteModal({ const resolvedWorkspaceId = workspaceId ?? optionalRoute?.workspaceId ?? (params?.workspaceId as string | undefined) ?? null - const { data: session } = useSession() const { workspacePermissions, permissionsLoading, @@ -486,6 +501,13 @@ export function WorkspaceInviteModal({ refetchPermissions, userPermissions: userPerms, } = useWorkspacePermissionsContext() + const currentUserEmailFromPermissions = workspacePermissions?.users?.find( + (user) => user.userId === currentUserId + )?.email + const normalizedCurrentUserEmail = + currentUserEmailFromPermissions?.trim().toLowerCase() ?? + currentUserEmail?.trim().toLowerCase() ?? + null const hasPendingChanges = Object.keys(existingUserPermissionChanges).length > 0 const hasNewInvites = emails.length > 0 || inputValue.trim() @@ -562,7 +584,7 @@ export function WorkspaceInviteModal({ return false } - if (session?.user?.email && session.user.email.toLowerCase() === normalized) { + if (normalizedCurrentUserEmail === normalized) { setErrorMessage('You cannot invite yourself') setInputValue('') return false @@ -588,7 +610,13 @@ export function WorkspaceInviteModal({ setInputValue('') return true }, - [emails, invalidEmails, pendingInvitations, workspacePermissions?.users, session?.user?.email] + [ + emails, + invalidEmails, + pendingInvitations, + workspacePermissions?.users, + normalizedCurrentUserEmail, + ] ) const removeEmail = useCallback( @@ -655,8 +683,12 @@ export function WorkspaceInviteModal({ throw new Error(data.error || 'Failed to update permissions') } - if (data.users && data.total !== undefined) { - updatePermissions({ users: data.users, total: data.total }) + if (data.users && data.total !== undefined && data.currentUserPermission) { + updatePermissions({ + users: data.users, + total: data.total, + currentUserPermission: data.currentUserPermission, + }) } setExistingUserPermissionChanges({}) @@ -732,6 +764,7 @@ export function WorkspaceInviteModal({ (user) => user.userId !== memberToRemove.userId ) updatePermissions({ + ...workspacePermissions, users: updatedUsers, total: workspacePermissions.total - 1, }) @@ -1047,9 +1080,7 @@ export function WorkspaceInviteModal({ - - Invite members to {workspaceName || 'Workspace'} - + Invite members to {workspaceName || 'Workspace'}
@@ -1104,6 +1135,8 @@ export function WorkspaceInviteModal({
formRef.current?.requestSubmit()} disabled={ - !userPerms.canAdmin || isSubmitting || isSaving || !resolvedWorkspaceId || !hasNewInvites + !userPerms.canAdmin || + isSubmitting || + isSaving || + !resolvedWorkspaceId || + !hasNewInvites } className={cn( 'ml-auto flex h-9 items-center justify-center gap-2 rounded-sm px-4 py-2 font-medium transition-all duration-200', @@ -1249,6 +1287,7 @@ export function WorkspaceInviteModal({ } interface WorkspaceDialogsProps { + workspaceUser: { id: string; email: string | null } | null inviteDialogOpen: boolean onInviteDialogChange: (open: boolean) => void inviteWorkspace: Workspace | null @@ -1261,6 +1300,7 @@ interface WorkspaceDialogsProps { } export function WorkspaceDialogs({ + workspaceUser, inviteDialogOpen, onInviteDialogChange, inviteWorkspace, @@ -1271,19 +1311,42 @@ export function WorkspaceDialogs({ isDeletingWorkspace, onConfirmDelete, }: WorkspaceDialogsProps) { + const inviteIdentityMissing = Boolean(inviteWorkspace && !workspaceUser) + return ( <> - {inviteWorkspace ? ( - + {inviteWorkspace && workspaceUser ? ( + ) : null} + {inviteIdentityMissing ? ( + + + + Workspace session unavailable + + Workspace management requires an authenticated workspace session. + + + + onInviteDialogChange(false)}> + Close + + + + + ) : null} + diff --git a/apps/tradinggoose/global-navbar/global-navbar.tsx b/apps/tradinggoose/global-navbar/global-navbar.tsx index 2b69509d3..ed0d3103f 100644 --- a/apps/tradinggoose/global-navbar/global-navbar.tsx +++ b/apps/tradinggoose/global-navbar/global-navbar.tsx @@ -25,7 +25,6 @@ import { UserMenu } from './components/user-menu' import { WorkspaceDialogs } from './components/workspace-dialogs' import { WorkspaceSwitcher } from './components/workspace-switcher' import { GlobalNavbarHeaderProvider } from './header-context' -import { SettingsLoader } from './settings-loader' import { SettingsDialog } from './settings-modal/settings-dialog' import type { SettingsSection } from './settings-modal/types' import type { NavSection } from './types' @@ -41,10 +40,12 @@ import { export function GlobalNavbar({ children, isSystemAdmin = false, + workspaceUser = null, navigationMode = 'workspace', }: { children: React.ReactNode isSystemAdmin?: boolean + workspaceUser?: { id: string; email: string | null } | null navigationMode?: 'workspace' | 'admin' }) { const selectedSegments = useSelectedLayoutSegments() @@ -118,35 +119,20 @@ export function GlobalNavbar({ const [activeSettingsSection, setActiveSettingsSection] = React.useState('account') const [isSettingsModalOpen, setIsSettingsModalOpen] = React.useState(false) - const [userNameOverride, setUserNameOverride] = React.useState(null) - const [userAvatarOverride, setUserAvatarOverride] = React.useState<{ - url: string | null - version: number | string | null - }>({ url: null, version: null }) const userId = sessionData?.user?.id ?? null - const userName = userNameOverride ?? sessionData?.user?.name ?? brand.name + const userName = sessionData?.user?.name ?? brand.name const userEmail = sessionData?.user?.email ?? brand.supportEmail ?? 'support@tradinggoose.ai' - const userAvatar = userAvatarOverride.url ?? sessionData?.user?.image - const userAvatarVersion = - userAvatarOverride.version ?? - (sessionData?.user?.updatedAt ? new Date(sessionData.user.updatedAt).getTime() : null) + const userAvatar = sessionData?.user?.image + const userAvatarVersion = sessionData?.user?.updatedAt + ? new Date(sessionData.user.updatedAt).getTime() + : null const workspaceSwitcher = useWorkspaceSwitcher({ enabled: isAuthenticated && !isSessionLoading, workspaceId, section: workspaceSection, }) const canManageWorkspaces = workspaceSwitcher.canManageWorkspaces - const systemNavigation = React.useMemo(() => { - if (!isSystemAdmin || navigationMode === 'admin') { - return null - } - - return { - href: '/admin', - label: tWorkspaceNav('systemAdmin'), - } - }, [isSystemAdmin, navigationMode, tWorkspaceNav]) const resolveSettingsSection = React.useCallback( (section: SettingsSection): SettingsSection => { @@ -205,84 +191,6 @@ export function GlobalNavbar({ } }, [openSettings]) - React.useEffect(() => { - if (!userId || typeof window === 'undefined') { - setUserNameOverride(null) - return - } - - const key = `user-name-${userId}` - - const readStoredName = () => { - const storedName = window.localStorage.getItem(key) - setUserNameOverride(storedName !== null ? storedName || null : null) - } - - const handleStorage = (event: StorageEvent) => { - if (!event.key || event.key !== key) return - readStoredName() - } - - const handleNameEvent = (event: Event) => { - const customEvent = event as CustomEvent<{ name?: string | null }> - const detail = customEvent.detail - setUserNameOverride(detail && 'name' in detail ? (detail?.name ?? null) : null) - } - - readStoredName() - window.addEventListener('storage', handleStorage) - window.addEventListener('user-name-updated', handleNameEvent) - return () => { - window.removeEventListener('storage', handleStorage) - window.removeEventListener('user-name-updated', handleNameEvent) - } - }, [userId]) - - React.useEffect(() => { - if (!userId || typeof window === 'undefined') return - - const readStoredAvatar = () => { - const storedVersion = window.localStorage.getItem(`user-avatar-version-${userId}`) - const storedUrl = window.localStorage.getItem(`user-avatar-url-${userId}`) - if (storedVersion || storedUrl !== null) { - setUserAvatarOverride((prev) => ({ - url: storedUrl !== null ? storedUrl || null : prev.url, - version: storedVersion ?? prev.version, - })) - } - } - - const handleStorage = (event: StorageEvent) => { - if (!event.key) return - if ( - event.key === `user-avatar-version-${userId}` || - event.key === `user-avatar-url-${userId}` - ) { - readStoredAvatar() - } - } - - const handleAvatarEvent = (event: Event) => { - const customEvent = event as CustomEvent<{ url?: string | null; version?: number }> - const detail = customEvent.detail - setUserAvatarOverride((prev) => ({ - url: detail && 'url' in detail ? (detail?.url ?? null) : prev.url, - version: - detail && 'version' in detail - ? (detail?.version ?? prev.version ?? Date.now()) - : (prev.version ?? Date.now()), - })) - } - - readStoredAvatar() - window.addEventListener('storage', handleStorage) - window.addEventListener('user-avatar-updated', handleAvatarEvent) - return () => { - window.removeEventListener('storage', handleStorage) - window.removeEventListener('user-avatar-updated', handleAvatarEvent) - } - }, [userId]) - if (shouldShowSkeleton) { return ( @@ -337,7 +245,6 @@ export function GlobalNavbar({ return ( -
@@ -385,7 +292,8 @@ export function GlobalNavbar({ userAvatar={userAvatar} userAvatarVersion={userAvatarVersion} onOpenSettings={openSettings} - systemNavigation={systemNavigation} + canAccessSystemAdmin={isSystemAdmin && navigationMode !== 'admin'} + sidebarTrigger /> @@ -407,6 +315,7 @@ export function GlobalNavbar({ {canManageWorkspaces ? ( (null) - - useEffect(() => { - if (!userId || loadedUserRef.current === userId) return - - loadedUserRef.current = userId - void refetch().then(({ data }) => { - const preferredLocale = data?.preferredLocale - if (!preferredLocale) return - - const { locale, pathname } = stripLocaleFromPathname(window.location.pathname) - if (preferredLocale !== locale) { - replaceLocaleDocument(preferredLocale, `${pathname}${window.location.search}`) - } - }) - }, [refetch, userId]) - - return null -} diff --git a/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.test.tsx b/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.test.tsx index fdf77b214..ce7927f27 100644 --- a/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.test.tsx +++ b/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.test.tsx @@ -11,7 +11,6 @@ import type { LocaleCode } from '@/i18n/utils' import { AccountSettings } from './account-settings' const mockUseSession = vi.fn() -const mockUseGeneralSettings = vi.fn() const mockSetTelemetryEnabled = vi.fn() const generalState = { @@ -57,10 +56,6 @@ vi.mock('@/lib/auth-client', () => ({ useSession: () => mockUseSession(), })) -vi.mock('@/hooks/queries/general-settings', () => ({ - useGeneralSettings: () => mockUseGeneralSettings(), -})) - vi.mock('@/stores/settings/general/store', () => ({ useGeneralStore: (selector: (state: typeof generalState) => unknown) => selector(generalState), })) @@ -116,7 +111,6 @@ describe('AccountSettings localization', () => { beforeEach(() => { reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true mockSetTelemetryEnabled.mockReset() - mockUseGeneralSettings.mockReturnValue({ isPending: false }) mockUseSession.mockReturnValue({ data: { user: { diff --git a/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.tsx b/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.tsx index a41c9e41b..8aab32685 100644 --- a/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.tsx +++ b/apps/tradinggoose/global-navbar/settings-modal/components/account/account-settings.tsx @@ -22,7 +22,6 @@ import { useAuthRedirectUrls } from '@/lib/auth/redirect-urls' import { createLogger } from '@/lib/logs/console/logger' import { useSession } from '@/lib/auth-client' import { useProfilePictureUpload } from '@/global-navbar/settings-modal/components/hooks/use-profile-picture-upload' -import { useGeneralSettings } from '@/hooks/queries/general-settings' import { useGeneralStore } from '@/stores/settings/general/store' const logger = createLogger('AccountSettings') const DEFAULT_AVATAR_SRC = '/profile/avatar.png' @@ -41,12 +40,11 @@ export function AccountSettings() { const userId = session?.user?.id ?? null // Telemetry state from general store - const { isPending: isSettingsPending } = useGeneralSettings() const storeIsLoading = useGeneralStore((state) => state.isLoading) const telemetryEnabled = useGeneralStore((state) => state.telemetryEnabled) const isTelemetryLoading = useGeneralStore((state) => state.isTelemetryLoading) const setTelemetryEnabled = useGeneralStore((state) => state.setTelemetryEnabled) - const isTelemetrySettingsLoading = isSettingsPending || storeIsLoading + const isTelemetrySettingsLoading = storeIsLoading const handleTelemetryToggle = (checked: boolean) => { if (checked === telemetryEnabled || isTelemetryLoading) { diff --git a/apps/tradinggoose/global-navbar/settings-modal/components/subscription/subscription.tsx b/apps/tradinggoose/global-navbar/settings-modal/components/subscription/subscription.tsx index 520467e62..b0d033df9 100644 --- a/apps/tradinggoose/global-navbar/settings-modal/components/subscription/subscription.tsx +++ b/apps/tradinggoose/global-navbar/settings-modal/components/subscription/subscription.tsx @@ -13,7 +13,7 @@ import { getBillingStatus, getSubscriptionStatus, getUsage } from '@/lib/subscri import type { BillingUpgradeTarget } from '@/lib/subscription/upgrade' import { useSubscriptionUpgrade } from '@/lib/subscription/upgrade' import { cn } from '@/lib/utils' -import { useGeneralSettings, useUpdateGeneralSetting } from '@/hooks/queries/general-settings' +import { useUpdateGeneralSetting } from '@/hooks/queries/general-settings' import { useOrganizationBilling, useOrganizations } from '@/hooks/queries/organization' import { usePublicBillingCatalog } from '@/hooks/queries/public-billing-catalog' import { useSubscriptionData, useUsageLimitData } from '@/hooks/queries/subscription' @@ -181,8 +181,6 @@ export function Subscription({ onOpenChange }: SubscriptionProps) { const [isPrimaryActionPending, setIsPrimaryActionPending] = useState(false) const usageLimitRef = useRef(null) - useGeneralSettings() - const billingPayload = (subscriptionData as any)?.data ?? subscriptionData const organizationBillingPayload = (organizationBillingData as any)?.data ?? organizationBillingData diff --git a/apps/tradinggoose/global-navbar/types.ts b/apps/tradinggoose/global-navbar/types.ts index 39d1031de..1de46f1a2 100644 --- a/apps/tradinggoose/global-navbar/types.ts +++ b/apps/tradinggoose/global-navbar/types.ts @@ -17,7 +17,7 @@ export interface Workspace { name: string ownerId: string billingOwner?: { type: 'user'; userId: string } | { type: 'organization'; organizationId: string } - role?: string + role: 'owner' | 'member' membershipId?: string - permissions?: 'admin' | 'write' | 'read' | null + permissions: 'admin' | 'write' | 'read' } diff --git a/apps/tradinggoose/global-navbar/use-workspace-switcher.test.ts b/apps/tradinggoose/global-navbar/use-workspace-switcher.test.ts index b9e8807a1..b7bb50be5 100644 --- a/apps/tradinggoose/global-navbar/use-workspace-switcher.test.ts +++ b/apps/tradinggoose/global-navbar/use-workspace-switcher.test.ts @@ -5,6 +5,7 @@ import { createRoot, type Root } from 'react-dom/client' import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' const mockPush = vi.fn() +const mockReplace = vi.fn() let mockSwitchToWorkspace = vi.fn() let fetchMock: ReturnType let originalFetch: typeof globalThis.fetch @@ -29,6 +30,7 @@ afterAll(() => { vi.mock('@/i18n/navigation', () => ({ useRouter: () => ({ push: mockPush, + replace: mockReplace, }), })) @@ -44,6 +46,7 @@ vi.mock('@/stores/workflows/registry/store', () => ({ describe('useWorkspaceSwitcher', () => { beforeEach(() => { mockPush.mockReset() + mockReplace.mockReset() mockSwitchToWorkspace = vi.fn() latestValue = null @@ -100,6 +103,7 @@ describe('useWorkspaceSwitcher', () => { expect(latestValue.canManageWorkspaces).toBe(true) expect(latestValue.activeWorkspace?.id).toBe('ws-1') expect(fetchMock.mock.calls.map(([url]) => String(url))).toContain('/api/workspaces') + expect(mockReplace).not.toHaveBeenCalled() await act(async () => { latestValue.setWorkspaceMenuOpen(true) @@ -115,4 +119,26 @@ describe('useWorkspaceSwitcher', () => { expect(latestValue.inviteDialogOpen).toBe(true) expect(latestValue.deleteDialogOpen).toBe(true) }) + + it('does not redirect during the workspace bootstrap fetch (server owns the root redirect)', async () => { + const { useWorkspaceSwitcher } = await import('@/global-navbar/use-workspace-switcher') + + function Harness() { + latestValue = useWorkspaceSwitcher({ + enabled: true, + section: 'dashboard', + }) + return null + } + + await act(async () => { + root?.render(React.createElement(Harness)) + await flush() + }) + + expect(fetchMock.mock.calls.map(([url]) => String(url))).toContain('/api/workspaces') + expect(latestValue.activeWorkspace?.id).toBe('ws-1') + expect(mockReplace).not.toHaveBeenCalled() + expect(mockPush).not.toHaveBeenCalled() + }) }) diff --git a/apps/tradinggoose/global-navbar/use-workspace-switcher.ts b/apps/tradinggoose/global-navbar/use-workspace-switcher.ts index 124878bac..5dd71e3f0 100644 --- a/apps/tradinggoose/global-navbar/use-workspace-switcher.ts +++ b/apps/tradinggoose/global-navbar/use-workspace-switcher.ts @@ -18,7 +18,7 @@ export function useWorkspaceSwitcher({ workspaceId, section, }: UseWorkspaceSwitcherOptions) { - const router = useRouter() + const { push } = useRouter() const switchToWorkspace = useWorkflowRegistry((state) => state.switchToWorkspace) const canManageWorkspaces = true const [workspaces, setWorkspaces] = React.useState([]) @@ -55,20 +55,19 @@ export function useWorkspaceSwitcher({ return } - const data = await response.json() - const items = ((data.workspaces ?? []) as Workspace[]).map((workspace) => ({ - ...workspace, - permissions: workspace.permissions ?? 'admin', - role: workspace.role ?? (workspace.permissions === 'admin' ? 'owner' : 'member'), - })) + const data = (await response.json()) as { workspaces?: Workspace[] } + const items = data.workspaces ?? [] setWorkspaces(items) + const firstWorkspace = items[0] ?? null + if (workspaceId) { - const match = items.find((workspace) => workspace.id === workspaceId) - setActiveWorkspace(match ?? items[0] ?? null) + setActiveWorkspace( + items.find((workspace) => workspace.id === workspaceId) ?? firstWorkspace + ) } else { - setActiveWorkspace((current) => current ?? items[0] ?? null) + setActiveWorkspace((current) => current ?? firstWorkspace) } } catch (error) { console.error('Error fetching workspaces:', error) @@ -100,9 +99,9 @@ export function useWorkspaceSwitcher({ } } - router.push(getWorkspaceSwitchPath(workspace.id, section)) + push(getWorkspaceSwitchPath(workspace.id, section)) }, - [router, section, switchToWorkspace, workspaceId] + [push, section, switchToWorkspace, workspaceId] ) const handleCreateWorkspace = React.useCallback(async () => { @@ -128,15 +127,11 @@ export function useWorkspaceSwitcher({ throw new Error(error?.error ?? 'Failed to create workspace') } - const data = await response.json() + const data = (await response.json()) as { workspace?: Workspace } await fetchWorkspaces() if (data.workspace) { - await handleSwitchWorkspace({ - ...data.workspace, - permissions: data.workspace.permissions ?? 'admin', - role: data.workspace.role ?? 'owner', - } satisfies Workspace) + await handleSwitchWorkspace(data.workspace) } } catch (error) { console.error('Error creating workspace:', error) diff --git a/apps/tradinggoose/hooks/queries/environment.ts b/apps/tradinggoose/hooks/queries/environment.ts index 47a1ea760..480fc45d5 100644 --- a/apps/tradinggoose/hooks/queries/environment.ts +++ b/apps/tradinggoose/hooks/queries/environment.ts @@ -1,8 +1,10 @@ import { useEffect } from 'react' import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { handleAuthError } from '@/lib/auth/auth-error-handler' import type { WorkspaceEnvironmentData } from '@/lib/environment/api' import { fetchPersonalEnvironment, fetchWorkspaceEnvironment } from '@/lib/environment/api' import { createLogger } from '@/lib/logs/console/logger' +import { usePathname } from '@/i18n/navigation' import { API_ENDPOINTS } from '@/stores/constants' import { useEnvironmentStore } from '@/stores/settings/environment/store' @@ -16,12 +18,26 @@ export const environmentKeys = { export type { WorkspaceEnvironmentData } from '@/lib/environment/api' +async function throwEnvironmentResponseError( + response: Response, + reason: string, + callbackPathname: string, + message: string +): Promise { + if (response.status === 401) { + await handleAuthError(reason, callbackPathname) + } + + throw new Error(`${message}: ${response.statusText}`) +} + export function usePersonalEnvironment() { + const pathname = usePathname() const setVariables = useEnvironmentStore((state) => state.setVariables) const query = useQuery({ queryKey: environmentKeys.personal(), - queryFn: fetchPersonalEnvironment, + queryFn: () => fetchPersonalEnvironment(pathname), staleTime: 60 * 1000, placeholderData: keepPreviousData, }) @@ -39,9 +55,11 @@ export function useWorkspaceEnvironment( workspaceId: string, options?: { select?: (data: WorkspaceEnvironmentData) => TData } ) { + const pathname = usePathname() + return useQuery({ queryKey: environmentKeys.workspace(workspaceId), - queryFn: () => fetchWorkspaceEnvironment(workspaceId), + queryFn: () => fetchWorkspaceEnvironment(workspaceId, pathname), enabled: Boolean(workspaceId), staleTime: 60 * 1000, placeholderData: keepPreviousData, @@ -56,6 +74,7 @@ interface UpsertPersonalEnvironmentParams { export function useUpsertPersonalEnvironment() { const queryClient = useQueryClient() + const pathname = usePathname() return useMutation({ mutationFn: async ({ key, value }: UpsertPersonalEnvironmentParams) => { @@ -66,7 +85,12 @@ export function useUpsertPersonalEnvironment() { }) if (!response.ok) { - throw new Error(`Failed to update personal environment variable: ${response.statusText}`) + await throwEnvironmentResponseError( + response, + 'environment-query:upsert-personal', + pathname, + 'Failed to update personal environment variable' + ) } logger.info(`Upserted personal environment variable: ${key}`) @@ -85,6 +109,7 @@ interface RemovePersonalEnvironmentParams { export function useRemovePersonalEnvironment() { const queryClient = useQueryClient() + const pathname = usePathname() return useMutation({ mutationFn: async ({ key }: RemovePersonalEnvironmentParams) => { @@ -95,7 +120,12 @@ export function useRemovePersonalEnvironment() { }) if (!response.ok) { - throw new Error(`Failed to remove personal environment variable: ${response.statusText}`) + await throwEnvironmentResponseError( + response, + 'environment-query:remove-personal', + pathname, + 'Failed to remove personal environment variable' + ) } logger.info(`Removed personal environment variable: ${key}`) @@ -115,6 +145,7 @@ interface UpsertWorkspaceEnvironmentParams { export function useUpsertWorkspaceEnvironment() { const queryClient = useQueryClient() + const pathname = usePathname() return useMutation({ mutationFn: async ({ workspaceId, variables }: UpsertWorkspaceEnvironmentParams) => { @@ -125,7 +156,12 @@ export function useUpsertWorkspaceEnvironment() { }) if (!response.ok) { - throw new Error(`Failed to update workspace environment: ${response.statusText}`) + await throwEnvironmentResponseError( + response, + 'environment-query:upsert-workspace', + pathname, + 'Failed to update workspace environment' + ) } logger.info(`Upserted workspace environment variables for workspace: ${workspaceId}`) @@ -147,6 +183,7 @@ interface RemoveWorkspaceEnvironmentParams { export function useRemoveWorkspaceEnvironment() { const queryClient = useQueryClient() + const pathname = usePathname() return useMutation({ mutationFn: async ({ workspaceId, keys }: RemoveWorkspaceEnvironmentParams) => { @@ -157,7 +194,12 @@ export function useRemoveWorkspaceEnvironment() { }) if (!response.ok) { - throw new Error(`Failed to remove workspace environment keys: ${response.statusText}`) + await throwEnvironmentResponseError( + response, + 'environment-query:remove-workspace', + pathname, + 'Failed to remove workspace environment keys' + ) } logger.info(`Removed ${keys.length} workspace environment keys for workspace: ${workspaceId}`) diff --git a/apps/tradinggoose/hooks/queries/general-settings.ts b/apps/tradinggoose/hooks/queries/general-settings.ts index d31f1d594..dfafd6b1d 100644 --- a/apps/tradinggoose/hooks/queries/general-settings.ts +++ b/apps/tradinggoose/hooks/queries/general-settings.ts @@ -1,18 +1,20 @@ import { useEffect } from 'react' -import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' +import { useSession } from '@/lib/auth-client' import { createLogger } from '@/lib/logs/console/logger' +import { defaultLocale, isLocaleCode, type LocaleCode } from '@/i18n/utils' import { useGeneralStore } from '@/stores/settings/general/store' const logger = createLogger('GeneralSettingsQuery') export const generalSettingsKeys = { all: ['generalSettings'] as const, - settings: () => [...generalSettingsKeys.all, 'settings'] as const, + settings: (userId: string | null) => [...generalSettingsKeys.all, 'settings', userId] as const, } export interface GeneralSettings { theme: 'light' | 'dark' | 'system' - preferredLocale: 'en' | 'es' | 'zh' + preferredLocale: LocaleCode telemetryEnabled: boolean billingUsageNotificationsEnabled: boolean } @@ -28,7 +30,7 @@ async function fetchGeneralSettings(): Promise { return { theme: data.theme || 'system', - preferredLocale: data.preferredLocale || 'en', + preferredLocale: isLocaleCode(data.preferredLocale) ? data.preferredLocale : defaultLocale, telemetryEnabled: data.telemetryEnabled ?? true, billingUsageNotificationsEnabled: data.billingUsageNotificationsEnabled ?? true, } @@ -39,25 +41,30 @@ function syncSettingsToZustand(settings: GeneralSettings) { setSettings({ theme: settings.theme, - preferredLocale: settings.preferredLocale, telemetryEnabled: settings.telemetryEnabled, isBillingUsageNotificationsEnabled: settings.billingUsageNotificationsEnabled, }) } -export function useGeneralSettings() { +export function useGeneralSettings({ + enabled = true, + userId, +}: { + enabled?: boolean + userId: string | null +}) { const query = useQuery({ - queryKey: generalSettingsKeys.settings(), + queryKey: generalSettingsKeys.settings(userId), queryFn: fetchGeneralSettings, + enabled: enabled && Boolean(userId), staleTime: 60 * 60 * 1000, - placeholderData: keepPreviousData, }) useEffect(() => { - if (query.data) { + if (userId && query.data) { syncSettingsToZustand(query.data) } - }, [query.data]) + }, [query.data, userId]) return query } @@ -69,6 +76,9 @@ interface UpdateSettingParams { export function useUpdateGeneralSetting() { const queryClient = useQueryClient() + const { data: session } = useSession() + const userId = session?.user?.id ?? null + const settingsKey = generalSettingsKeys.settings(userId) return useMutation({ mutationFn: async ({ key, value }: UpdateSettingParams) => { @@ -85,18 +95,20 @@ export function useUpdateGeneralSetting() { return response.json() }, onMutate: async ({ key, value }) => { - await queryClient.cancelQueries({ queryKey: generalSettingsKeys.settings() }) + if (!userId) { + return { previousSettings: undefined } + } + + await queryClient.cancelQueries({ queryKey: settingsKey }) - const previousSettings = queryClient.getQueryData( - generalSettingsKeys.settings() - ) + const previousSettings = queryClient.getQueryData(settingsKey) if (previousSettings) { const newSettings = { ...previousSettings, [key]: value, } - queryClient.setQueryData(generalSettingsKeys.settings(), newSettings) + queryClient.setQueryData(settingsKey, newSettings) syncSettingsToZustand(newSettings) } @@ -104,13 +116,13 @@ export function useUpdateGeneralSetting() { }, onError: (err, _variables, context) => { if (context?.previousSettings) { - queryClient.setQueryData(generalSettingsKeys.settings(), context.previousSettings) + queryClient.setQueryData(settingsKey, context.previousSettings) syncSettingsToZustand(context.previousSettings) } logger.error('Failed to update setting:', err) }, onSuccess: () => { - queryClient.invalidateQueries({ queryKey: generalSettingsKeys.settings() }) + queryClient.invalidateQueries({ queryKey: settingsKey }) }, }) } diff --git a/apps/tradinggoose/hooks/queries/indicators.ts b/apps/tradinggoose/hooks/queries/indicators.ts index 366bddedf..6120786c1 100644 --- a/apps/tradinggoose/hooks/queries/indicators.ts +++ b/apps/tradinggoose/hooks/queries/indicators.ts @@ -142,7 +142,7 @@ export function useIndicators(workspaceId: string) { interface CreateIndicatorParams { workspaceId: string - indicator: Omit + indicator: Pick } export function useCreateIndicator() { @@ -152,19 +152,11 @@ export function useCreateIndicator() { mutationFn: async ({ workspaceId, indicator }: CreateIndicatorParams) => { logger.info(`Creating indicator: ${indicator.name} in workspace ${workspaceId}`) - const resolvedIndicator = { - ...indicator, - color: - typeof indicator.color === 'string' && indicator.color.trim().length > 0 - ? indicator.color.trim() - : undefined, - } - const response = await fetch(API_ENDPOINT, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ - indicators: [resolvedIndicator], + indicators: [indicator], workspaceId, }), }) @@ -192,7 +184,7 @@ interface UpdateIndicatorParams { workspaceId: string indicatorId: string updates: Partial< - Omit + Omit > } @@ -229,7 +221,6 @@ export function useUpdateIndicator() { { id: indicatorId, name: updates.name ?? currentIndicator.name, - color: updates.color ?? currentIndicator.color, pineCode: updates.pineCode ?? currentIndicator.pineCode, inputMeta: resolvedInputMeta, }, diff --git a/apps/tradinggoose/hooks/queries/workflows.ts b/apps/tradinggoose/hooks/queries/workflows.ts index f84959e2a..3783d61c5 100644 --- a/apps/tradinggoose/hooks/queries/workflows.ts +++ b/apps/tradinggoose/hooks/queries/workflows.ts @@ -16,7 +16,6 @@ interface CreateWorkflowVariables { workspaceId: string name?: string description?: string - color?: string folderId?: string | null } @@ -25,7 +24,7 @@ export function useCreateWorkflow() { return useMutation({ mutationFn: async (variables: CreateWorkflowVariables) => { - const { workspaceId, name, description, color, folderId } = variables + const { workspaceId, name, description, folderId } = variables logger.info(`Creating new workflow in workspace: ${workspaceId}`) const requestBody: Record = { @@ -34,9 +33,6 @@ export function useCreateWorkflow() { workspaceId, folderId: folderId || null, } - if (typeof color === 'string' && color.trim().length > 0) { - requestBody.color = color.trim() - } const createResponse = await fetch('/api/workflows', { method: 'POST', diff --git a/apps/tradinggoose/hooks/queries/workspace.ts b/apps/tradinggoose/hooks/queries/workspace.ts index c292fcae4..e36a581ac 100644 --- a/apps/tradinggoose/hooks/queries/workspace.ts +++ b/apps/tradinggoose/hooks/queries/workspace.ts @@ -81,6 +81,7 @@ export interface WorkspaceSettingsResponse { } | null permissions: { users: WorkspaceSettingsUser[] + currentUserPermission: 'admin' | 'write' | 'read' } | null } @@ -176,15 +177,7 @@ async function fetchAdminWorkspaces(userId: string | undefined): Promise - user.id === userId || user.userId === userId - ) - hasAdminAccess = currentUserPermission?.permissionType === 'admin' - } + const hasAdminAccess = permissionData.currentUserPermission === 'admin' const isOwner = workspace.isOwner || workspace.ownerId === userId diff --git a/apps/tradinggoose/hooks/use-user-permissions.test.tsx b/apps/tradinggoose/hooks/use-user-permissions.test.tsx index 7b65ce389..5931d5047 100644 --- a/apps/tradinggoose/hooks/use-user-permissions.test.tsx +++ b/apps/tradinggoose/hooks/use-user-permissions.test.tsx @@ -1,10 +1,9 @@ /** @vitest-environment jsdom */ -import React, { act } from 'react' +import { act } from 'react' import { createRoot, type Root } from 'react-dom/client' -import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from 'vitest' -const mockUseSession = vi.fn() const reactActEnvironment = globalThis as typeof globalThis & { IS_REACT_ACT_ENVIRONMENT?: boolean } @@ -14,18 +13,6 @@ let container: HTMLDivElement | null = null let root: Root | null = null let latestValue: unknown = null -vi.mock('@/lib/auth-client', () => ({ - useSession: () => mockUseSession(), -})) - -vi.mock('@/lib/logs/console/logger', () => ({ - createLogger: () => ({ - error: vi.fn(), - info: vi.fn(), - warn: vi.fn(), - }), -})) - beforeAll(() => { reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true }) @@ -37,12 +24,6 @@ afterAll(() => { describe('useUserPermissions', () => { beforeEach(() => { latestValue = null - mockUseSession.mockReset() - mockUseSession.mockReturnValue({ - data: null, - isPending: true, - error: null, - }) container = document.createElement('div') document.body.appendChild(container) @@ -61,26 +42,11 @@ describe('useUserPermissions', () => { container = null }) - it('keeps permissions loading while the auth session is still pending', async () => { + it('keeps permissions loading while workspace permissions are still pending', async () => { const { useUserPermissions } = await import('@/hooks/use-user-permissions') function Harness() { - latestValue = useUserPermissions( - { - users: [ - { - userId: 'user-1', - email: 'member@example.com', - name: 'Member', - image: null, - permissionType: 'admin', - }, - ], - total: 1, - }, - false, - null - ) + latestValue = useUserPermissions(null, true, null) return null } @@ -97,18 +63,7 @@ describe('useUserPermissions', () => { }) }) - it('returns resolved permissions once the auth session is available', async () => { - mockUseSession.mockReturnValue({ - data: { - user: { - id: 'user-1', - email: 'member@example.com', - }, - }, - isPending: false, - error: null, - }) - + it('returns server-derived current user permissions', async () => { const { useUserPermissions } = await import('@/hooks/use-user-permissions') function Harness() { @@ -124,6 +79,7 @@ describe('useUserPermissions', () => { }, ], total: 1, + currentUserPermission: 'write', }, false, null diff --git a/apps/tradinggoose/hooks/use-user-permissions.ts b/apps/tradinggoose/hooks/use-user-permissions.ts index 5292b7726..13bb68d3b 100644 --- a/apps/tradinggoose/hooks/use-user-permissions.ts +++ b/apps/tradinggoose/hooks/use-user-permissions.ts @@ -1,10 +1,6 @@ import { useMemo } from 'react' -import { useSession } from '@/lib/auth-client' -import { createLogger } from '@/lib/logs/console/logger' import type { PermissionType, WorkspacePermissions } from '@/hooks/use-workspace-permissions' -const logger = createLogger('useUserPermissions') - export interface WorkspaceUserPermissions { // Core permission checks canRead: boolean @@ -31,78 +27,53 @@ export function useUserPermissions( permissionsLoading = false, permissionsError: string | null = null ): WorkspaceUserPermissions { - const { - data: session, - isPending: isSessionPending, - error: sessionError, - } = useSession() - const userPermissions = useMemo((): WorkspaceUserPermissions => { - const sessionEmail = session?.user?.email - const sessionErrorMessage = sessionError?.message ?? null - const resolvedError = permissionsError ?? sessionErrorMessage - - if (permissionsLoading || isSessionPending) { + if (permissionsLoading) { return { canRead: false, canEdit: false, canAdmin: false, userPermissions: 'read', isLoading: true, - error: resolvedError, + error: permissionsError, } } - if (!sessionEmail) { + if (permissionsError) { return { canRead: false, canEdit: false, canAdmin: false, userPermissions: 'read', isLoading: false, - error: sessionErrorMessage ?? 'Authentication required', + error: permissionsError, } } - // Find current user in workspace permissions (case-insensitive) - const currentUser = workspacePermissions?.users?.find( - (user) => user.email.toLowerCase() === sessionEmail.toLowerCase() - ) - - // If user not found in workspace, they have no permissions - if (!currentUser) { - logger.warn('User not found in workspace permissions', { - userEmail: sessionEmail, - hasPermissions: !!workspacePermissions, - userCount: workspacePermissions?.users?.length || 0, - }) - + const userPerms = workspacePermissions?.currentUserPermission + if (!userPerms) { return { canRead: false, canEdit: false, canAdmin: false, userPermissions: 'read', isLoading: false, - error: resolvedError || 'User not found in workspace', + error: 'User not found in workspace', } } - const userPerms = currentUser.permissionType || 'read' - - // Core permission checks const canAdmin = userPerms === 'admin' const canEdit = userPerms === 'write' || userPerms === 'admin' - const canRead = true // If user is found in workspace permissions, they have read access return { - canRead, + canRead: true, canEdit, canAdmin, userPermissions: userPerms, isLoading: false, - error: resolvedError, + error: null, } - }, [session, workspacePermissions, permissionsLoading, permissionsError, isSessionPending, sessionError]) + }, [workspacePermissions?.currentUserPermission, permissionsLoading, permissionsError]) return userPermissions } diff --git a/apps/tradinggoose/hooks/use-workspace-permissions.test.tsx b/apps/tradinggoose/hooks/use-workspace-permissions.test.tsx new file mode 100644 index 000000000..a76560a91 --- /dev/null +++ b/apps/tradinggoose/hooks/use-workspace-permissions.test.tsx @@ -0,0 +1,180 @@ +/** + * @vitest-environment jsdom + */ + +import { act } from 'react' +import { createRoot, type Root } from 'react-dom/client' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + resetWorkspacePermissionsStore, + useWorkspacePermissions, +} from './use-workspace-permissions' + +const mockHandleAuthError = vi.hoisted(() => vi.fn()) +let latestValue: ReturnType | null = null +let workspaceId = 'workspace-401' +let userId = 'user-1' + +vi.mock('@/lib/auth/auth-error-handler', () => ({ + handleAuthError: mockHandleAuthError, + isAuthErrorStatus: (status?: number | null) => status === 401, +})) + +vi.mock('@/i18n/navigation', () => ({ + usePathname: () => '/workspace/workspace-1/dashboard', +})) + +function WorkspacePermissionsProbe() { + latestValue = useWorkspacePermissions(workspaceId, userId) + return null +} + +describe('useWorkspacePermissions', () => { + let container: HTMLDivElement + let root: Root + const reactActEnvironment = globalThis as typeof globalThis & { + IS_REACT_ACT_ENVIRONMENT?: boolean + } + + beforeEach(() => { + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true + latestValue = null + workspaceId = 'workspace-401' + userId = 'user-1' + mockHandleAuthError.mockResolvedValue(undefined) + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue(new Response(null, { status: 401, statusText: 'Unauthorized' })) + ) + container = document.createElement('div') + document.body.appendChild(container) + root = createRoot(container) + }) + + afterEach(() => { + act(() => { + root.unmount() + }) + container.remove() + vi.unstubAllGlobals() + vi.clearAllMocks() + resetWorkspacePermissionsStore() + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = false + }) + + it('routes workspace permission 401 responses through auth recovery', async () => { + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(mockHandleAuthError).toHaveBeenCalledWith( + 'workspace-permissions', + '/workspace/workspace-1/dashboard' + ) + expect(latestValue).toMatchObject({ + loading: false, + error: 'SESSION_EXPIRED', + permissions: null, + }) + }) + + it('does not refetch a session recovery record after remounting the same user workspace key', async () => { + const fetchMock = fetch as unknown as ReturnType + + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(latestValue).toMatchObject({ + loading: false, + error: 'SESSION_EXPIRED', + permissions: null, + }) + + await act(async () => { + root.unmount() + }) + root = createRoot(container) + + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(mockHandleAuthError).toHaveBeenCalledTimes(1) + expect(latestValue).toMatchObject({ + loading: false, + error: 'SESSION_EXPIRED', + permissions: null, + }) + }) + + it('uses the server-authenticated user id instead of client session state', async () => { + const fetchMock = vi.fn().mockResolvedValue( + Response.json({ + users: [], + total: 0, + currentUserPermission: 'admin', + }) + ) + vi.stubGlobal('fetch', fetchMock) + + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(fetchMock).toHaveBeenCalledWith('/api/workspaces/workspace-401/permissions') + expect(mockHandleAuthError).not.toHaveBeenCalled() + expect(latestValue).toMatchObject({ + loading: false, + error: null, + permissions: { + currentUserPermission: 'admin', + }, + }) + }) + + it('does not reuse a cached workspace permission record after the active user changes', async () => { + workspaceId = 'workspace-1' + const fetchMock = vi + .fn() + .mockResolvedValueOnce( + Response.json({ + users: [], + total: 0, + currentUserPermission: 'admin', + }) + ) + .mockResolvedValueOnce( + Response.json({ + users: [], + total: 0, + currentUserPermission: 'read', + }) + ) + vi.stubGlobal('fetch', fetchMock) + + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(latestValue?.permissions?.currentUserPermission).toBe('admin') + + userId = 'user-2' + + await act(async () => { + root.render() + await new Promise((resolve) => setTimeout(resolve, 0)) + }) + + expect(fetchMock).toHaveBeenCalledTimes(2) + expect(latestValue?.permissions?.currentUserPermission).toBe('read') + }) +}) diff --git a/apps/tradinggoose/hooks/use-workspace-permissions.ts b/apps/tradinggoose/hooks/use-workspace-permissions.ts index 7101eea54..06a95dc8f 100644 --- a/apps/tradinggoose/hooks/use-workspace-permissions.ts +++ b/apps/tradinggoose/hooks/use-workspace-permissions.ts @@ -3,8 +3,10 @@ import { useCallback, useEffect } from 'react' import type { permissionTypeEnum } from '@tradinggoose/db/schema' import { createWithEqualityFn as create } from 'zustand/traditional' -import { handleAuthError } from '@/lib/auth/auth-error-handler' +import { handleAuthError, isAuthErrorStatus } from '@/lib/auth/auth-error-handler' +import { isSessionRecoveryAuthError } from '@/lib/auth/auth-error-copy' import { createLogger } from '@/lib/logs/console/logger' +import { usePathname } from '@/i18n/navigation' import { API_ENDPOINTS } from '@/stores/constants' const logger = createLogger('useWorkspacePermissions') @@ -22,6 +24,7 @@ export interface WorkspaceUser { export interface WorkspacePermissions { users: WorkspaceUser[] total: number + currentUserPermission: PermissionType } interface UseWorkspacePermissionsReturn { @@ -47,8 +50,12 @@ type WorkspacePermissionsRecord = { interface WorkspacePermissionsStoreState { records: Record inFlight: Partial>> - setRecord: (workspaceId: string, partial: Partial) => void - fetchPermissions: (workspaceId: string, options?: { force?: boolean }) => Promise + setRecord: (recordKey: string, partial: Partial) => void + fetchPermissions: ( + recordKey: string, + workspaceId: string, + options: { callbackPathname: string; force?: boolean } + ) => Promise } const createDefaultRecord = (): WorkspacePermissionsRecord => ({ @@ -60,29 +67,32 @@ const createDefaultRecord = (): WorkspacePermissionsRecord => ({ const useWorkspacePermissionsStore = create((set, get) => ({ records: {}, inFlight: {}, - setRecord: (workspaceId, partial) => + setRecord: (recordKey, partial) => set((state) => { - const prev = state.records[workspaceId] ?? createDefaultRecord() + const prev = state.records[recordKey] ?? createDefaultRecord() return { records: { ...state.records, - [workspaceId]: { + [recordKey]: { ...prev, ...partial, }, }, } }), - fetchPermissions: async (workspaceId, options) => { - const { force = false } = options ?? {} + fetchPermissions: async (recordKey, workspaceId, options) => { + const { callbackPathname, force = false } = options const { records, inFlight, setRecord } = get() if (!force) { - if (inFlight[workspaceId]) { - return inFlight[workspaceId] + if (inFlight[recordKey]) { + return inFlight[recordKey] } - const existing = records[workspaceId] + const existing = records[recordKey] + if (isSessionRecoveryAuthError(existing?.error)) { + return + } if (existing?.permissions && !existing?.error) { return } @@ -90,7 +100,7 @@ const useWorkspacePermissionsStore = create((set const fetchPromise = (async () => { try { - setRecord(workspaceId, { loading: true, error: null }) + setRecord(recordKey, { loading: true, error: null }) const response = await fetch(API_ENDPOINTS.WORKSPACE_PERMISSIONS(workspaceId)) @@ -98,9 +108,14 @@ const useWorkspacePermissionsStore = create((set if (response.status === 404) { throw new Error('Workspace not found or access denied') } - if (response.status === 401) { - await handleAuthError('workspace-permissions') - throw new Error('Authentication required') + if (isAuthErrorStatus(response.status)) { + await handleAuthError('workspace-permissions', callbackPathname) + setRecord(recordKey, { + permissions: null, + loading: false, + error: 'SESSION_EXPIRED', + }) + return } throw new Error(`Failed to fetch permissions: ${response.statusText}`) } @@ -113,7 +128,7 @@ const useWorkspacePermissionsStore = create((set users: data.users.map((u) => ({ email: u.email, permissions: u.permissionType })), }) - setRecord(workspaceId, { + setRecord(recordKey, { permissions: data, loading: false, error: null, @@ -124,14 +139,14 @@ const useWorkspacePermissionsStore = create((set workspaceId, error: errorMessage, }) - setRecord(workspaceId, { + setRecord(recordKey, { loading: false, error: errorMessage, }) } finally { set((state) => { const next = { ...state.inFlight } - delete next[workspaceId] + delete next[recordKey] return { inFlight: next } }) } @@ -140,7 +155,7 @@ const useWorkspacePermissionsStore = create((set set((state) => ({ inFlight: { ...state.inFlight, - [workspaceId]: fetchPromise, + [recordKey]: fetchPromise, }, })) @@ -148,44 +163,48 @@ const useWorkspacePermissionsStore = create((set }, })) -export function useWorkspacePermissions(workspaceId: string | null): UseWorkspacePermissionsReturn { - const record = useWorkspacePermissionsStore((state) => - workspaceId ? state.records[workspaceId] : undefined - ) +function getRecordKey(workspaceId: string, userId: string) { + return `${userId}:${workspaceId}` +} + +export function resetWorkspacePermissionsStore() { + useWorkspacePermissionsStore.setState({ records: {}, inFlight: {} }) +} + +export function useWorkspacePermissions( + workspaceId: string, + userId: string +): UseWorkspacePermissionsReturn { + const callbackPathname = usePathname() + const recordKey = getRecordKey(workspaceId, userId) + const record = useWorkspacePermissionsStore((state) => state.records[recordKey]) const fetchPermissions = useWorkspacePermissionsStore((state) => state.fetchPermissions) const setRecord = useWorkspacePermissionsStore((state) => state.setRecord) useEffect(() => { - if (!workspaceId) { - return () => {} - } - fetchPermissions(workspaceId).catch((error) => { + fetchPermissions(recordKey, workspaceId, { callbackPathname }).catch((error) => { logger.error('Failed to load workspace permissions', { workspaceId, error }) }) - }, [workspaceId, fetchPermissions]) + }, [workspaceId, recordKey, callbackPathname, fetchPermissions]) const refetch = useCallback(async () => { - if (!workspaceId) return - await fetchPermissions(workspaceId, { force: true }) - }, [workspaceId, fetchPermissions]) + await fetchPermissions(recordKey, workspaceId, { callbackPathname, force: true }) + }, [workspaceId, recordKey, callbackPathname, fetchPermissions]) const updatePermissions = useCallback( (newPermissions: WorkspacePermissions) => { - if (!workspaceId) return - setRecord(workspaceId, { + setRecord(recordKey, { permissions: newPermissions, loading: false, error: null, }) }, - [workspaceId, setRecord] + [recordKey, setRecord] ) - const isInitialLoad = Boolean(workspaceId) && !record - return { permissions: record?.permissions ?? null, - loading: record?.loading ?? isInitialLoad, + loading: record?.loading ?? true, error: record?.error ?? null, updatePermissions, refetch, diff --git a/apps/tradinggoose/hooks/workflow/use-accessible-reference-prefixes.ts b/apps/tradinggoose/hooks/workflow/use-accessible-reference-prefixes.ts index f98db1ad9..2aa4db375 100644 --- a/apps/tradinggoose/hooks/workflow/use-accessible-reference-prefixes.ts +++ b/apps/tradinggoose/hooks/workflow/use-accessible-reference-prefixes.ts @@ -1,13 +1,13 @@ import { useMemo } from 'react' import { BlockPathCalculator } from '@/lib/block-path-calculator' import { SYSTEM_REFERENCE_PREFIXES } from '@/lib/workflows/references' -import { normalizeBlockName } from '@/stores/workflows/utils' import { useWorkflowBlocks, useWorkflowEdges, useWorkflowLoops, useWorkflowParallels, } from '@/lib/yjs/use-workflow-doc' +import { normalizeBlockName } from '@/stores/workflows/utils' import type { Loop, Parallel } from '@/stores/workflows/workflow/types' export function useAccessibleReferencePrefixes(blockId?: string | null): Set | undefined { @@ -49,10 +49,6 @@ export function useAccessibleReferencePrefixes(blockId?: string | null): Set prefixes.add(prefix)) diff --git a/apps/tradinggoose/hooks/workflow/use-workflow-execution.test.tsx b/apps/tradinggoose/hooks/workflow/use-workflow-execution.test.tsx index 014ccd37b..e8b91d082 100644 --- a/apps/tradinggoose/hooks/workflow/use-workflow-execution.test.tsx +++ b/apps/tradinggoose/hooks/workflow/use-workflow-execution.test.tsx @@ -5,8 +5,12 @@ import { createRoot, type Root } from 'react-dom/client' import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' const mockRunQueuedWorkflowExecution = vi.hoisted(() => vi.fn()) -const mockUseCurrentWorkflow = vi.hoisted(() => vi.fn()) -const mockUseWorkflowVariables = vi.hoisted(() => vi.fn()) +const mockWorkflowDoc = vi.hoisted(() => ({})) +const mockReadWorkflowSnapshot = vi.hoisted(() => vi.fn()) +const mockUseWorkflowSession = vi.hoisted(() => vi.fn()) +const mockGetVariablesSnapshot = vi.hoisted(() => vi.fn()) + +vi.unmock('@/blocks/registry') const mockConsoleState = vi.hoisted(() => ({ cancelRunningEntries: vi.fn(), @@ -29,22 +33,12 @@ vi.mock('@/lib/workflows/queued-execution-client', () => ({ runQueuedWorkflowExecution: mockRunQueuedWorkflowExecution, })) -vi.mock('@/lib/workflows/triggers', () => ({ - TriggerUtils: { - findStartBlock: vi.fn(() => ({ blockId: 'chat-trigger', block: {} })), - getTriggerValidationMessage: vi.fn(() => 'Missing chat trigger'), - findTriggersByType: vi.fn((blocks, type) => - type === 'manual' - ? Object.values(blocks as Record).filter( - (block: any) => block.type === 'manual_trigger' - ) - : [] - ), - }, +vi.mock('@/lib/yjs/workflow-session', () => ({ + getVariablesSnapshot: mockGetVariablesSnapshot, })) -vi.mock('@/lib/yjs/use-workflow-doc', () => ({ - useWorkflowVariables: mockUseWorkflowVariables, +vi.mock('@/lib/yjs/workflow-session-host', () => ({ + useWorkflowSession: mockUseWorkflowSession, })) vi.mock('@/stores/console/store', () => { @@ -65,42 +59,45 @@ vi.mock('@/stores/execution/store', () => { } }) -vi.mock('@/stores/workflows/registry/store', () => ({ - useWorkflowRegistry: vi.fn((selector) => - selector({ - workflows: { - 'workflow-1': { - workspaceId: 'workspace-1', - }, - }, - getActiveWorkflowId: () => null, - }) - ), -})) - -vi.mock('@/stores/workflows/workflow/utils', () => ({ - generateLoopBlocks: vi.fn(() => ({})), - generateParallelBlocks: vi.fn(() => ({})), -})) - vi.mock('@/widgets/widgets/editor_workflow/context/workflow-route-context', () => ({ useWorkflowRoute: vi.fn(() => ({ workflowId: 'workflow-1', + workspaceId: 'workspace-1', channelId: 'channel-1', })), })) -vi.mock('./use-current-workflow', () => ({ - useCurrentWorkflow: mockUseCurrentWorkflow, -})) +import { useWorkflowExecution } from './use-workflow-execution' describe('useWorkflowExecution', () => { let container: HTMLDivElement | null = null let root: Root | null = null const previousActEnvironment = (globalThis as any).IS_REACT_ACT_ENVIRONMENT + const agentBlock = { + id: 'agent-1', + type: 'agent', + name: 'Agent', + enabled: true, + subBlocks: {}, + outputs: {}, + } + + function mockSingleTriggerSnapshot( + triggerId: string, + type: string, + name: string, + subBlocks: Record = {} + ) { + mockReadWorkflowSnapshot.mockReturnValue({ + blocks: { + [triggerId]: { id: triggerId, type, name, enabled: true, subBlocks, outputs: {} }, + 'agent-1': agentBlock, + }, + edges: [{ id: 'edge-1', source: triggerId, target: 'agent-1' }], + }) + } async function renderExecutionHook() { - const { useWorkflowExecution } = await import('./use-workflow-execution') const state: { execution: ReturnType | null } = { execution: null, } @@ -133,8 +130,12 @@ describe('useWorkflowExecution', () => { output: {}, logs: [], }) - mockUseWorkflowVariables.mockReturnValue([]) - mockUseCurrentWorkflow.mockReturnValue({ + mockUseWorkflowSession.mockReturnValue({ + doc: mockWorkflowDoc, + readWorkflowSnapshot: mockReadWorkflowSnapshot, + }) + mockGetVariablesSnapshot.mockReturnValue({}) + mockReadWorkflowSnapshot.mockReturnValue({ blocks: { 'chat-trigger': { id: 'chat-trigger', @@ -152,14 +153,7 @@ describe('useWorkflowExecution', () => { subBlocks: {}, outputs: {}, }, - 'agent-1': { - id: 'agent-1', - type: 'agent', - name: 'Agent', - enabled: true, - subBlocks: {}, - outputs: {}, - }, + 'agent-1': agentBlock, }, edges: [ { id: 'edge-1', source: 'chat-trigger', target: 'agent-1' }, @@ -213,7 +207,20 @@ describe('useWorkflowExecution', () => { ) }) + it('does not run chat-only workflows through editor Run', async () => { + mockSingleTriggerSnapshot('chat-trigger', 'chat_trigger', 'Chat Trigger') + + const execution = await renderExecutionHook() + + await act(async () => { + await execution.handleRunWorkflow({ triggerBlockId: 'chat-trigger' }) + }) + + expect(mockRunQueuedWorkflowExecution).not.toHaveBeenCalled() + }) + it('forwards queued execution events to the workflow caller', async () => { + mockSingleTriggerSnapshot('schedule-trigger', 'schedule', 'Schedule') const streamEvent = { type: 'stream:chunk', executionId: 'execution-1', @@ -237,7 +244,7 @@ describe('useWorkflowExecution', () => { const execution = await renderExecutionHook() await act(async () => { - await execution.handleRunWorkflow({ onEvent }) + await execution.handleRunWorkflow({ triggerBlockId: 'schedule-trigger', onEvent }) }) expect(onEvent).toHaveBeenCalledWith(streamEvent) @@ -245,7 +252,7 @@ describe('useWorkflowExecution', () => { expect(mockRunQueuedWorkflowExecution).toHaveBeenCalledWith( expect.objectContaining({ triggerType: 'manual', - startBlockId: 'manual-trigger', + triggerBlockId: 'schedule-trigger', selectedOutputs: undefined, stream: true, }), diff --git a/apps/tradinggoose/hooks/workflow/use-workflow-execution.ts b/apps/tradinggoose/hooks/workflow/use-workflow-execution.ts index c53551675..c6673ede4 100644 --- a/apps/tradinggoose/hooks/workflow/use-workflow-execution.ts +++ b/apps/tradinggoose/hooks/workflow/use-workflow-execution.ts @@ -2,16 +2,14 @@ import { useCallback, useRef, useState } from 'react' import { createLogger } from '@/lib/logs/console/logger' import type { WorkflowExecutionEvent } from '@/lib/workflows/execution-events' import { runQueuedWorkflowExecution } from '@/lib/workflows/queued-execution-client' -import { TriggerUtils } from '@/lib/workflows/triggers' -import { useWorkflowVariables } from '@/lib/yjs/use-workflow-doc' +import { resolveWorkflowRunTrigger, TriggerUtils } from '@/lib/workflows/triggers' +import { getVariablesSnapshot } from '@/lib/yjs/workflow-session' +import { useWorkflowSession } from '@/lib/yjs/workflow-session-host' import type { ExecutionResult } from '@/executor/types' -import { useLatestRef } from '@/hooks/use-latest-ref' import { useConsoleStore } from '@/stores/console/store' import { useExecutionStore } from '@/stores/execution/store' -import { useWorkflowRegistry } from '@/stores/workflows/registry/store' -import { generateLoopBlocks, generateParallelBlocks } from '@/stores/workflows/workflow/utils' +import { buildExecutableWorkflowData } from '@/stores/workflows/workflow/utils' import { useWorkflowRoute } from '@/widgets/widgets/editor_workflow/context/workflow-route-context' -import { useCurrentWorkflow } from './use-current-workflow' const logger = createLogger('useWorkflowExecution') const WORKFLOW_EXECUTION_FAILURE_MESSAGE = 'Workflow execution failed' @@ -19,6 +17,7 @@ type WorkflowExecutionTriggerType = 'chat' | 'manual' type WorkflowExecutionRequest = { input?: unknown triggerType?: WorkflowExecutionTriggerType + triggerBlockId?: string selectedOutputs?: string[] onEvent?: (event: WorkflowExecutionEvent) => void | Promise } @@ -64,35 +63,15 @@ function createExecutionId() { return globalThis.crypto.randomUUID() } -function getInputFormatTestValues(inputFormatValue: unknown): Record { - const testInput: Record = {} - if (!Array.isArray(inputFormatValue)) return testInput - - for (const field of inputFormatValue) { - if (field && typeof field === 'object' && 'name' in field && 'value' in field) { - const name = (field as { name?: unknown }).name - if (typeof name === 'string' && name.length > 0) { - testInput[name] = (field as { value?: unknown }).value - } - } - } - - return testInput -} - export function useWorkflowExecution() { - const currentWorkflow = useCurrentWorkflow() - const { workflowId: routeWorkflowId, channelId } = useWorkflowRoute() - const workflows = useWorkflowRegistry((state) => state.workflows) - const registryWorkflowId = useWorkflowRegistry((state) => state.getActiveWorkflowId(channelId)) - const activeWorkflowId = routeWorkflowId ?? registryWorkflowId + const { workflowId: activeWorkflowId, workspaceId } = useWorkflowRoute() + const { doc, error, isLoading, readWorkflowSnapshot } = useWorkflowSession() const { cancelRunningEntries } = useConsoleStore() - const yjsVariables = useWorkflowVariables() - const yjsVariablesRef = useLatestRef(yjsVariables) const abortControllerRef = useRef(null) const { isExecuting, setIsExecuting, setIsDebugging, setPendingBlocks, setActiveBlocks } = useExecutionStore() const [executionResult, setExecutionResult] = useState(null) + const isWorkflowSessionReady = Boolean(doc) && !isLoading && !error const applyExecutionEvent = useCallback( (event: WorkflowExecutionEvent) => { @@ -169,89 +148,54 @@ export function useWorkflowExecution() { ) const buildExecutionRequest = useCallback( - async (workflowInput: unknown, triggerType: WorkflowExecutionTriggerType) => { - if (!activeWorkflowId) throw new Error('Workflow target is required') - - const workspaceId = workflows[activeWorkflowId]?.workspaceId + async ( + workflowInput: unknown, + triggerType: WorkflowExecutionTriggerType, + requestedTriggerBlockId?: string + ) => { + const workflowSnapshot = readWorkflowSnapshot() + if (!workflowSnapshot || !doc) { + throw new Error('Workflow session is not ready') + } if (!workspaceId) { throw new Error('Cannot execute workflow without workspaceId') } - const validBlocks = Object.entries(currentWorkflow.blocks).reduce( - (acc, [blockId, block]) => { - if (block?.type && block.enabled !== false) { - acc[blockId] = block - } - return acc - }, - {} as typeof currentWorkflow.blocks + const workflowData = buildExecutableWorkflowData( + workflowSnapshot.blocks, + workflowSnapshot.edges ) - const isChatExecution = triggerType === 'chat' - let startBlockId: string | undefined + let triggerBlockId: string | undefined let finalWorkflowInput = workflowInput + let finalTriggerType = triggerType - if (isChatExecution) { - const startBlock = TriggerUtils.findStartBlock(validBlocks, 'chat') - if (!startBlock) { - throw new Error(TriggerUtils.getTriggerValidationMessage('chat', 'missing')) + if (triggerType === 'chat') { + const chatTrigger = TriggerUtils.findTriggerBlock(workflowData.blocks, 'chat') + if (!chatTrigger) { + throw new Error('Chat execution requires a Chat Trigger block') } - startBlockId = startBlock.blockId + triggerBlockId = chatTrigger.blockId } else { - const entries = Object.entries(validBlocks) - const apiTriggers = TriggerUtils.findTriggersByType(validBlocks, 'api') - const manualTriggers = TriggerUtils.findTriggersByType(validBlocks, 'manual') - - if (apiTriggers.length > 1) { - throw new Error('Multiple API Trigger blocks found. Keep only one.') + if (!requestedTriggerBlockId) { + throw new Error('Run requires choosing a configured trigger block') } - - let selectedTrigger: any = null - let selectedBlockId: string | null = null - - if (apiTriggers.length === 1) { - selectedTrigger = apiTriggers[0] - selectedBlockId = entries.find(([, block]) => block === selectedTrigger)?.[0] ?? null - - const testInput = getInputFormatTestValues(selectedTrigger.subBlocks?.inputFormat?.value) - if (Object.keys(testInput).length > 0) { - finalWorkflowInput = testInput - } - } else if (manualTriggers.length > 0) { - selectedTrigger = - manualTriggers.find((trigger) => trigger.type === 'manual_trigger') ?? - manualTriggers.find((trigger) => trigger.type === 'input_trigger') ?? - manualTriggers[0] - selectedBlockId = entries.find(([, block]) => block === selectedTrigger)?.[0] ?? null - - if (selectedTrigger.type === 'input_trigger') { - const testInput = getInputFormatTestValues( - selectedTrigger.subBlocks?.inputFormat?.value - ) - if (Object.keys(testInput).length > 0) { - finalWorkflowInput = testInput - } + const editorTestTrigger = resolveWorkflowRunTrigger( + workflowData.blocks, + workflowData.edges, + { + surface: 'editor', + workflowInput, + triggerBlockId: requestedTriggerBlockId, } - } else { - throw new Error('Manual run requires a Manual, Input Form, or API Trigger block') - } - - if (!selectedBlockId || !selectedTrigger) { - throw new Error('No valid trigger block found to start execution') - } - - const outgoingConnections = currentWorkflow.edges.filter( - (edge) => edge.source === selectedBlockId ) - if (outgoingConnections.length === 0) { - const triggerName = selectedTrigger.name || selectedTrigger.type - throw new Error(`${triggerName} must be connected to other blocks to execute`) - } - - startBlockId = selectedBlockId + triggerBlockId = editorTestTrigger.blockId + workflowData.blocks = editorTestTrigger.blocks + finalWorkflowInput = editorTestTrigger.input + finalTriggerType = editorTestTrigger.triggerType } - const workflowVariables = Object.values(yjsVariablesRef.current ?? {}).reduce( + const workflowVariables = Object.values(getVariablesSnapshot(doc)).reduce( (acc, variable: any) => { if (variable?.id) acc[variable.id] = variable return acc @@ -262,18 +206,13 @@ export function useWorkflowExecution() { return { workspaceId, input: finalWorkflowInput, - startBlockId, - triggerType, + triggerBlockId, + triggerType: finalTriggerType, workflowVariables, - workflowData: { - blocks: validBlocks, - edges: currentWorkflow.edges, - loops: generateLoopBlocks(validBlocks), - parallels: generateParallelBlocks(validBlocks), - }, + workflowData, } }, - [activeWorkflowId, currentWorkflow.blocks, currentWorkflow.edges, workflows] + [doc, readWorkflowSnapshot, workspaceId] ) const uploadChatFiles = useCallback( @@ -350,10 +289,14 @@ export function useWorkflowExecution() { abortControllerRef.current = abortController try { - const triggerType = request.triggerType ?? 'manual' - const executionRequest = await buildExecutionRequest(request.input, triggerType) + const requestedTriggerType = request.triggerType ?? 'manual' + const executionRequest = await buildExecutionRequest( + request.input, + requestedTriggerType, + request.triggerBlockId + ) const input = - triggerType === 'chat' + executionRequest.triggerType === 'chat' ? await uploadChatFiles( executionRequest.input, executionId, @@ -366,11 +309,11 @@ export function useWorkflowExecution() { workflowId: activeWorkflowId, executionId, input, - triggerType, + triggerType: executionRequest.triggerType, executionTarget: 'live', workflowData: executionRequest.workflowData, workflowVariables: executionRequest.workflowVariables, - startBlockId: executionRequest.startBlockId, + triggerBlockId: executionRequest.triggerBlockId, selectedOutputs: request.selectedOutputs, stream: true, signal: abortController.signal, @@ -424,6 +367,7 @@ export function useWorkflowExecution() { return { isExecuting, + isWorkflowSessionReady, executionResult, handleRunWorkflow, handleCancelExecution, diff --git a/apps/tradinggoose/i18n/messages/en.json b/apps/tradinggoose/i18n/messages/en.json index 6151514f7..bfa86d33f 100644 --- a/apps/tradinggoose/i18n/messages/en.json +++ b/apps/tradinggoose/i18n/messages/en.json @@ -43,6 +43,7 @@ "docs": "Docs", "blog": "Blog", "login": "Login", + "goToDashboard": "Go to dashboard", "menu": "Menu", "homeLabel": "Home", "languageLabel": "Language", @@ -237,7 +238,6 @@ "noAccount": "No account found with this email. Please sign up first.", "missingCredentials": "Please enter both email and password.", "emailPasswordDisabled": "Email and password login is disabled.", - "failedToCreateSession": "Failed to create session. Please try again later.", "tooManyAttempts": "Too many login attempts. Please try again later or reset your password.", "accountLocked": "Your account has been locked for security. Please reset your password.", "network": "Network error. Please check your connection and try again.", @@ -370,16 +370,8 @@ "waitlist": "Honk! Introducing TradingGoose-Studio", "open": "Honk! TradingGoose-Studio is here!" }, - "leadWords": [ - "Build", - "Test", - "Run" - ], - "highlightWords": [ - "Trading Analysis", - "Signal Detection", - "Risk Assessment" - ], + "leadWords": ["Build", "Test", "Run"], + "highlightWords": ["Trading Analysis", "Signal Detection", "Risk Assessment"], "titleConnector": "your", "suffix": "with TradingGoose", "description": "Connect your own data providers, write custom indicators to monitor market prices, and wire them into workflows that trigger trade, sell, buy, or any action you define.", @@ -713,13 +705,7 @@ "experience": { "label": "Years of Experience *", "placeholder": "Select experience level", - "options": [ - "0-1 years", - "1-3 years", - "3-5 years", - "5-10 years", - "10+ years" - ] + "options": ["0-1 years", "1-3 years", "3-5 years", "5-10 years", "10+ years"] }, "location": { "label": "Location *", @@ -3682,10 +3668,7 @@ }, "headers": { "title": "Response Headers", - "columns": [ - "Key", - "Value" - ], + "columns": ["Key", "Value"], "description": "Additional HTTP headers to include in the response" } }, @@ -17931,10 +17914,7 @@ }, "variables": { "title": "Variables", - "columns": [ - "Key", - "Value" - ] + "columns": ["Key", "Value"] }, "apiKey": { "title": "Anthropic API Key", @@ -21443,6 +21423,23 @@ "older": "Older" } }, + "mentions": { + "workflowBlocks": "Workflow Blocks", + "untitledChat": "Untitled chat", + "untitled": "Untitled", + "matches": "Matches", + "noMatches": "No matches", + "noPastChats": "No past chats", + "noWorkflows": "No workflows found", + "noSkills": "No skills found", + "noIndicators": "No indicators found", + "noCustomTools": "No custom tools found", + "noMcpServers": "No MCP servers found", + "noKnowledgeBases": "No knowledge bases found", + "noBlocksFound": "No blocks found", + "noBlocksInWorkflow": "No blocks in this workflow", + "noExecutionsFound": "No executions found" + }, "accessLevel": { "limited": { "label": "Limited", @@ -22590,4 +22587,4 @@ "sentLine": "This confirmation was sent on {dateTime}." } } -} \ No newline at end of file +} diff --git a/apps/tradinggoose/i18n/messages/es.json b/apps/tradinggoose/i18n/messages/es.json index adf6fbc61..90ca61b2e 100644 --- a/apps/tradinggoose/i18n/messages/es.json +++ b/apps/tradinggoose/i18n/messages/es.json @@ -43,6 +43,7 @@ "docs": "Documentación", "blog": "Blog", "login": "Iniciar sesión", + "goToDashboard": "Ir al panel", "menu": "Menú", "homeLabel": "Inicio", "languageLabel": "Idioma", @@ -237,7 +238,6 @@ "noAccount": "No se encontró ninguna cuenta con este correo electrónico. Regístrese primero.", "missingCredentials": "Por favor, ingrese tanto el correo electrónico como la contraseña.", "emailPasswordDisabled": "El inicio de sesión con correo electrónico y contraseña está deshabilitado.", - "failedToCreateSession": "Error al crear la sesión. Intente nuevamente más tarde.", "tooManyAttempts": "Demasiados intentos de inicio de sesión. Intente nuevamente más tarde o restablezca su contraseña.", "accountLocked": "Su cuenta ha sido bloqueada por seguridad. Restablezca su contraseña.", "network": "Error de red. Verifique su conexión e intente nuevamente.", @@ -20873,7 +20873,7 @@ "noSkillsAvailableYet": "Aún no hay skills disponibles.", "noSkillsFound": "No se encontraron skills.", "retry": "Reintentar", - "untitledSkill": "Skill sin título" + "untitledSkill": "Habilidad sin título" }, "customToolDropdown": { "placeholder": "Seleccionar herramientas personalizadas", @@ -21423,6 +21423,23 @@ "older": "Anteriores" } }, + "mentions": { + "workflowBlocks": "Bloques del flujo de trabajo", + "untitledChat": "Chat sin título", + "untitled": "Sin título", + "matches": "Coincidencias", + "noMatches": "No hay coincidencias", + "noPastChats": "No hay chats anteriores", + "noWorkflows": "No se encontraron flujos de trabajo", + "noSkills": "No se encontraron habilidades", + "noIndicators": "No se encontraron indicadores", + "noCustomTools": "No se encontraron herramientas personalizadas", + "noMcpServers": "No se encontraron servidores MCP", + "noKnowledgeBases": "No se encontraron bases de conocimiento", + "noBlocksFound": "No se encontraron bloques", + "noBlocksInWorkflow": "No hay bloques en este flujo de trabajo", + "noExecutionsFound": "No se encontraron ejecuciones" + }, "accessLevel": { "limited": { "label": "Limitado", diff --git a/apps/tradinggoose/i18n/messages/zh.json b/apps/tradinggoose/i18n/messages/zh.json index a1d25d2c7..98b8a11e5 100644 --- a/apps/tradinggoose/i18n/messages/zh.json +++ b/apps/tradinggoose/i18n/messages/zh.json @@ -43,6 +43,7 @@ "docs": "文档", "blog": "博客", "login": "登录", + "goToDashboard": "前往仪表盘", "menu": "菜单", "homeLabel": "首页", "languageLabel": "语言", @@ -237,7 +238,6 @@ "noAccount": "此邮箱未找到账户,请先注册。", "missingCredentials": "请输入邮箱和密码。", "emailPasswordDisabled": "邮箱和密码登录已禁用。", - "failedToCreateSession": "创建会话失败,请稍后重试。", "tooManyAttempts": "登录尝试次数过多,请稍后重试或重置密码。", "accountLocked": "出于安全原因,您的账户已被锁定,请重置密码。", "network": "网络错误,请检查连接后重试。", @@ -21410,6 +21410,23 @@ "older": "更早" } }, + "mentions": { + "workflowBlocks": "工作流模块", + "untitledChat": "未命名聊天", + "untitled": "未命名", + "matches": "匹配项", + "noMatches": "未找到匹配项", + "noPastChats": "没有历史聊天", + "noWorkflows": "未找到工作流", + "noSkills": "未找到技能", + "noIndicators": "未找到指标", + "noCustomTools": "未找到自定义工具", + "noMcpServers": "未找到 MCP 服务器", + "noKnowledgeBases": "未找到知识库", + "noBlocksFound": "未找到模块", + "noBlocksInWorkflow": "此工作流中没有模块", + "noExecutionsFound": "未找到执行记录" + }, "accessLevel": { "limited": { "label": "受限", diff --git a/apps/tradinggoose/i18n/navigation.ts b/apps/tradinggoose/i18n/navigation.ts index e0f945dd1..ce48943c7 100644 --- a/apps/tradinggoose/i18n/navigation.ts +++ b/apps/tradinggoose/i18n/navigation.ts @@ -1,5 +1,5 @@ import { createNavigation } from 'next-intl/navigation' -import { localizeUrl } from './utils' +import { LOCALE_COOKIE, LOCALE_COOKIE_MAX_AGE, localizeUrl } from './utils' import { routing } from './routing' // These navigation helpers localize canonical internal paths like `/verify`. @@ -13,6 +13,6 @@ export function replaceLocaleDocument(locale: typeof routing.locales[number], pa return } - document.cookie = `NEXT_LOCALE=${encodeURIComponent(locale)}; path=/; max-age=31536000; samesite=lax` + document.cookie = `${LOCALE_COOKIE}=${encodeURIComponent(locale)}; path=/; max-age=${LOCALE_COOKIE_MAX_AGE}; samesite=lax` window.location.replace(localizeUrl(window.location.origin, locale, pathname)) } diff --git a/apps/tradinggoose/i18n/public-copy.ts b/apps/tradinggoose/i18n/public-copy.ts index c9957af29..583286c14 100644 --- a/apps/tradinggoose/i18n/public-copy.ts +++ b/apps/tradinggoose/i18n/public-copy.ts @@ -25,5 +25,12 @@ export function getClientMessages( if (scope === 'workspace') return { nav: messages.nav, workspace } if (scope === 'admin') return { admin, nav: messages.nav, registration: messages.registration, workspace } - return messages + return { + ...messages, + workspace: { + nav: workspace.nav, + userMenu: workspace.userMenu, + settingsModal: workspace.settingsModal, + }, + } } diff --git a/apps/tradinggoose/i18n/routing.ts b/apps/tradinggoose/i18n/routing.ts index dba6a3c0d..828f6ec4f 100644 --- a/apps/tradinggoose/i18n/routing.ts +++ b/apps/tradinggoose/i18n/routing.ts @@ -7,8 +7,10 @@ export const defaultLocale = 'en' satisfies (typeof locales)[number] export const routing = defineRouting({ locales, defaultLocale, - localePrefix: 'as-needed', + localePrefix: 'always', + // Proxy owns locale negotiation so every renderable page URL stays explicitly prefixed. localeDetection: false, + alternateLinks: false, }) export type AppLocale = (typeof locales)[number] diff --git a/apps/tradinggoose/i18n/utils.test.ts b/apps/tradinggoose/i18n/utils.test.ts index 33179c60b..6bd8ea40c 100644 --- a/apps/tradinggoose/i18n/utils.test.ts +++ b/apps/tradinggoose/i18n/utils.test.ts @@ -3,6 +3,7 @@ import { buildLocalizedAlternates, getLocaleDisplayName, getOpenGraphLocale, + localizeDocsUrl, localizeSiteUrl, localizeUrl, normalizeCallbackUrl, @@ -44,6 +45,17 @@ describe('i18n utils', () => { expect(normalizeCallbackUrl('workspace/ws-1/dashboard')).toBeNull() }) + it('rejects localized callback URLs', () => { + expect(normalizeCallbackUrl('/en?source=nav')).toBeNull() + expect(normalizeCallbackUrl('/es/workspace/ws-1/dashboard?layoutId=layout-1')).toBeNull() + expect( + normalizeCallbackUrl( + 'https://tradinggoose.ai/zh/invite/invitation-1?token=abc', + 'https://tradinggoose.ai' + ) + ).toBeNull() + }) + it('builds localized site URLs and alternate hreflang mappings', () => { const previousAppUrl = process.env.NEXT_PUBLIC_APP_URL process.env.NEXT_PUBLIC_APP_URL = 'https://preview.example.com' @@ -54,10 +66,10 @@ describe('i18n utils', () => { expect(buildLocalizedAlternates('es', '/blog')).toEqual({ canonical: 'https://preview.example.com/es/blog', languages: { - en: 'https://preview.example.com/blog', + en: 'https://preview.example.com/en/blog', es: 'https://preview.example.com/es/blog', zh: 'https://preview.example.com/zh/blog', - 'x-default': 'https://preview.example.com/blog', + 'x-default': 'https://preview.example.com/en/blog', }, }) } finally { @@ -74,13 +86,18 @@ describe('i18n utils', () => { 'https://tradinggoose.ai/es/reset-password?token=abc' ) expect(localizeUrl('https://tradinggoose.ai', 'en', '/workspace')).toBe( - 'https://tradinggoose.ai/workspace' + 'https://tradinggoose.ai/en/workspace' ) expect(localizeUrl('https://tradinggoose.ai', 'invalid', '/login')).toBe( - 'https://tradinggoose.ai/login' + 'https://tradinggoose.ai/en/login' ) }) + it('keeps docs URLs aligned with the docs app locale contract', () => { + expect(localizeDocsUrl('en')).toBe('https://docs.tradinggoose.ai/') + expect(localizeDocsUrl('zh', '/widgets')).toBe('https://docs.tradinggoose.ai/zh/widgets') + }) + it('rejects non-canonical app URL inputs', () => { expect(() => localizeUrl('https://tradinggoose.ai', 'es', '/zh/login')).toThrow( 'Expected an unlocalized internal pathname' diff --git a/apps/tradinggoose/i18n/utils.ts b/apps/tradinggoose/i18n/utils.ts index 872f24a01..8c9e6f72f 100644 --- a/apps/tradinggoose/i18n/utils.ts +++ b/apps/tradinggoose/i18n/utils.ts @@ -7,8 +7,9 @@ export type LocaleInput = LocaleCode | string | null | undefined export { defaultLocale, isLocaleCode, locales } -export const SITE_BASE_URL = getBaseUrl() export const CANONICAL_CALLBACK_PATH_HEADER = 'x-tradinggoose-callback-path' +export const LOCALE_COOKIE = 'NEXT_LOCALE' +export const LOCALE_COOKIE_MAX_AGE = 60 * 60 * 24 * 365 const LOCALE_DISPLAY_NAMES: Record = { en: 'English', es: 'Español', @@ -26,6 +27,15 @@ export function normalizeLocaleCode(locale: LocaleInput): LocaleCode { return locale && isLocaleCode(locale) ? locale : defaultLocale } +export function requireCanonicalCallbackPath(headers: Headers, routeName: string) { + const callbackUrl = headers.get(CANONICAL_CALLBACK_PATH_HEADER) + if (!callbackUrl) { + throw new Error(`Missing canonical callback path for ${routeName} reauth redirect`) + } + + return callbackUrl +} + export function getLocaleDisplayName(locale: LocaleCode) { return LOCALE_DISPLAY_NAMES[locale] } @@ -51,23 +61,27 @@ export function stripLocaleFromPathname(pathname: string): { } } -function prefixLocalePathname(locale: LocaleCode, pathname: string) { +function prefixLocalePathname(locale: LocaleCode, pathname: string, includeDefaultLocale = true) { const normalized = pathname === '/' ? '/' : pathname.replace(/\/+$/, '') - if (locale === defaultLocale) { + if (!includeDefaultLocale && locale === defaultLocale) { return normalized } return normalized === '/' ? `/${locale}` : `/${locale}${normalized}` } +function isLocalizedInternalPathname(pathname: string) { + const firstSegment = pathname.split(/[?#]/, 1)[0].split('/').filter(Boolean)[0] + return Boolean(firstSegment && isLocaleCode(firstSegment)) +} + function assertCanonicalInternalPathname(pathname: string) { if (!pathname.startsWith('/') || pathname.startsWith('//')) { throw new Error(`Expected a canonical internal pathname, received "${pathname}"`) } - const firstSegment = pathname.split(/[?#]/, 1)[0].split('/').filter(Boolean)[0] - if (firstSegment && isLocaleCode(firstSegment)) { + if (isLocalizedInternalPathname(pathname)) { throw new Error(`Expected an unlocalized internal pathname, received "${pathname}"`) } } @@ -88,6 +102,9 @@ export function normalizeCallbackUrl( if (trimmedHref.startsWith('/')) { const parsedUrl = new URL(trimmedHref, 'http://tradinggoose.local') + if (isLocalizedInternalPathname(parsedUrl.pathname)) { + return null + } return `${parsedUrl.pathname}${parsedUrl.search}${parsedUrl.hash}` } @@ -102,6 +119,9 @@ export function normalizeCallbackUrl( return null } + if (isLocalizedInternalPathname(parsedUrl.pathname)) { + return null + } return `${parsedUrl.pathname}${parsedUrl.search}${parsedUrl.hash}` } catch { return null @@ -118,7 +138,8 @@ export function localizeSiteUrl(locale: LocaleCode, pathname: string) { } export function localizeDocsUrl(locale: LocaleCode, pathname = '/') { - return localizeUrl(DOCS_BASE_URL, locale, pathname) + assertCanonicalInternalPathname(pathname) + return `${DOCS_BASE_URL}${prefixLocalePathname(locale, pathname, false)}` } export function getOpenGraphLocale(locale: LocaleCode) { diff --git a/apps/tradinggoose/instrumentation-edge.ts b/apps/tradinggoose/instrumentation-edge.ts deleted file mode 100644 index 103cceabf..000000000 --- a/apps/tradinggoose/instrumentation-edge.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * TradingGoose Telemetry - Edge Runtime Instrumentation - * - * This file contains Edge Runtime-compatible instrumentation logic. - * No Node.js APIs (like process.on, crypto, fs, etc.) are allowed here. - */ - -import { createLogger } from './lib/logs/console/logger' - -const logger = createLogger('EdgeInstrumentation') - -export async function register() { - try { - logger.info('Edge Runtime instrumentation initialized') - } catch (error) { - logger.error('Failed to initialize Edge Runtime instrumentation', error) - } -} diff --git a/apps/tradinggoose/instrumentation-node.ts b/apps/tradinggoose/instrumentation-node.ts deleted file mode 100644 index a41cd8ea3..000000000 --- a/apps/tradinggoose/instrumentation-node.ts +++ /dev/null @@ -1,117 +0,0 @@ -/** - * TradingGoose OpenTelemetry - Server-side Instrumentation - */ - -import { DiagConsoleLogger, DiagLogLevel, diag } from '@opentelemetry/api' -import { env } from './lib/env' -import { createLogger } from './lib/logs/console/logger' - -diag.setLogger(new DiagConsoleLogger(), DiagLogLevel.ERROR) - -const logger = createLogger('OTelInstrumentation') - -const DEFAULT_TELEMETRY_CONFIG = { - endpoint: env.TELEMETRY_ENDPOINT || 'https://telemetry.tradinggoose.ai/v1/traces', - serviceName: 'tradinggoose-studio', - serviceVersion: '0.1.0', - serverSide: { enabled: true }, - batchSettings: { - maxQueueSize: 2048, - maxExportBatchSize: 512, - scheduledDelayMillis: 5000, - exportTimeoutMillis: 30000, - }, -} - -/** - * Initialize OpenTelemetry SDK with proper configuration - */ -async function initializeOpenTelemetry() { - try { - if (env.NEXT_TELEMETRY_DISABLED === '1') { - logger.info('OpenTelemetry disabled via NEXT_TELEMETRY_DISABLED=1') - return - } - - let telemetryConfig - try { - telemetryConfig = (await import('./telemetry.config')).default - } catch { - telemetryConfig = DEFAULT_TELEMETRY_CONFIG - } - - if (telemetryConfig.serverSide?.enabled === false) { - logger.info('Server-side OpenTelemetry disabled in config') - return - } - - const { NodeSDK } = await import('@opentelemetry/sdk-node') - const { defaultResource, resourceFromAttributes } = await import('@opentelemetry/resources') - const { ATTR_SERVICE_NAME, ATTR_SERVICE_VERSION, ATTR_DEPLOYMENT_ENVIRONMENT } = await import( - '@opentelemetry/semantic-conventions/incubating' - ) - const { OTLPTraceExporter } = await import('@opentelemetry/exporter-trace-otlp-http') - const { BatchSpanProcessor } = await import('@opentelemetry/sdk-trace-node') - const { ParentBasedSampler, TraceIdRatioBasedSampler } = await import( - '@opentelemetry/sdk-trace-base' - ) - - const exporter = new OTLPTraceExporter({ - url: telemetryConfig.endpoint, - headers: {}, - timeoutMillis: telemetryConfig.batchSettings.exportTimeoutMillis, - }) as any - - const spanProcessor = new BatchSpanProcessor(exporter, { - maxQueueSize: telemetryConfig.batchSettings.maxQueueSize, - maxExportBatchSize: telemetryConfig.batchSettings.maxExportBatchSize, - scheduledDelayMillis: telemetryConfig.batchSettings.scheduledDelayMillis, - exportTimeoutMillis: telemetryConfig.batchSettings.exportTimeoutMillis, - }) as any - - const resource = defaultResource().merge( - resourceFromAttributes({ - [ATTR_SERVICE_NAME]: telemetryConfig.serviceName, - [ATTR_SERVICE_VERSION]: telemetryConfig.serviceVersion, - [ATTR_DEPLOYMENT_ENVIRONMENT]: env.NODE_ENV || 'development', - 'service.namespace': 'tradinggoose-ai-platform', - 'telemetry.sdk.name': 'opentelemetry', - 'telemetry.sdk.language': 'nodejs', - 'telemetry.sdk.version': '1.0.0', - }) - ) - - const sampler = new ParentBasedSampler({ - root: new TraceIdRatioBasedSampler(0.1), // 10% sampling for root spans - }) - - const sdk = new NodeSDK({ - resource, - spanProcessor, - sampler, - traceExporter: exporter, - }) - - sdk.start() - - const shutdownHandler = async () => { - try { - await sdk.shutdown() - logger.info('OpenTelemetry SDK shut down successfully') - } catch (err) { - logger.error('Error shutting down OpenTelemetry SDK', err) - } - } - - process.on('SIGTERM', shutdownHandler) - process.on('SIGINT', shutdownHandler) - - logger.info('OpenTelemetry instrumentation initialized') - } catch (error) { - logger.error('Failed to initialize OpenTelemetry instrumentation', error) - } -} - -export async function register() { - await initializeOpenTelemetry() -} diff --git a/apps/tradinggoose/instrumentation.node.ts b/apps/tradinggoose/instrumentation.node.ts new file mode 100644 index 000000000..e35de857e --- /dev/null +++ b/apps/tradinggoose/instrumentation.node.ts @@ -0,0 +1,101 @@ +import { DiagConsoleLogger, DiagLogLevel, diag } from '@opentelemetry/api' +import { env } from './lib/env' +import { createLogger } from './lib/logs/console/logger' + +const logger = createLogger('OTelInstrumentation') + +const TELEMETRY_ENDPOINT = 'https://telemetry.tradinggoose.ai/v1/traces' +const SERVICE_NAME = 'tradinggoose-studio' +const SERVICE_VERSION = '0.1.0' +const SAMPLE_RATE = 0.1 + +const batchSettings = { + maxQueueSize: 2048, + maxExportBatchSize: 512, + scheduledDelayMillis: 5000, + exportTimeoutMillis: 30000, +} + +const telemetryState = globalThis as typeof globalThis & { + __TRADINGGOOSE_OTEL__?: { + initialized: boolean + shutdownRegistered: boolean + shutdown?: () => Promise + } +} + +export async function register() { + if (env.NEXT_TELEMETRY_DISABLED === '1') { + logger.info('OpenTelemetry disabled via NEXT_TELEMETRY_DISABLED=1') + return + } + + const state = (telemetryState.__TRADINGGOOSE_OTEL__ ??= { + initialized: false, + shutdownRegistered: false, + }) + + if (state.initialized) { + return + } + + try { + diag.setLogger(new DiagConsoleLogger(), DiagLogLevel.ERROR) + + const [ + { OTLPTraceExporter }, + { resourceFromAttributes }, + { BatchSpanProcessor, ParentBasedSampler, TraceIdRatioBasedSampler }, + { NodeTracerProvider }, + ] = await Promise.all([ + import('@opentelemetry/exporter-trace-otlp-http'), + import('@opentelemetry/resources'), + import('@opentelemetry/sdk-trace-base'), + import('@opentelemetry/sdk-trace-node'), + ]) + + const exporter = new OTLPTraceExporter({ + url: env.TELEMETRY_ENDPOINT || TELEMETRY_ENDPOINT, + timeoutMillis: batchSettings.exportTimeoutMillis, + }) + + const provider = new NodeTracerProvider({ + resource: resourceFromAttributes({ + 'service.name': SERVICE_NAME, + 'service.version': SERVICE_VERSION, + 'service.namespace': 'tradinggoose-ai-platform', + 'deployment.environment': env.NODE_ENV || 'development', + 'telemetry.sdk.name': 'opentelemetry', + 'telemetry.sdk.language': 'nodejs', + }), + sampler: new ParentBasedSampler({ + root: new TraceIdRatioBasedSampler(SAMPLE_RATE), + }), + spanProcessors: [new BatchSpanProcessor(exporter, batchSettings)], + }) + + provider.register() + + state.initialized = true + state.shutdown = async () => { + try { + await provider.shutdown() + logger.info('OpenTelemetry SDK shut down successfully') + } catch (error) { + logger.error('Error shutting down OpenTelemetry SDK', error) + } finally { + state.initialized = false + } + } + + if (!state.shutdownRegistered) { + process.once('SIGTERM', () => void state.shutdown?.()) + process.once('SIGINT', () => void state.shutdown?.()) + state.shutdownRegistered = true + } + + logger.info('OpenTelemetry instrumentation initialized') + } catch (error) { + logger.error('Failed to initialize OpenTelemetry instrumentation', error) + } +} diff --git a/apps/tradinggoose/instrumentation.ts b/apps/tradinggoose/instrumentation.ts index a561b258f..2e14bb255 100644 --- a/apps/tradinggoose/instrumentation.ts +++ b/apps/tradinggoose/instrumentation.ts @@ -1,28 +1,6 @@ -/** - * OpenTelemetry Instrumentation Entry Point - * - * This is the main entry point for OpenTelemetry instrumentation. - * It delegates to runtime-specific instrumentation modules. - */ export async function register() { - // Load Node.js-specific instrumentation if (process.env.NEXT_RUNTIME === 'nodejs') { - const nodeInstrumentation = await import('./instrumentation-node') - if (nodeInstrumentation.register) { - await nodeInstrumentation.register() - } - } - - // Load Edge Runtime-specific instrumentation - if (process.env.NEXT_RUNTIME === 'edge') { - const edgeInstrumentation = await import('./instrumentation-edge') - if (edgeInstrumentation.register) { - await edgeInstrumentation.register() - } - } - - // Load client instrumentation if we're on the client - if (typeof window !== 'undefined') { - await import('./instrumentation-client') + const { register } = await import('./instrumentation.node') + await register() } } diff --git a/apps/tradinggoose/lib/admin/access.ts b/apps/tradinggoose/lib/admin/access.ts index 83f3aac2c..cdce464c1 100644 --- a/apps/tradinggoose/lib/admin/access.ts +++ b/apps/tradinggoose/lib/admin/access.ts @@ -38,8 +38,8 @@ export async function claimFirstSystemAdmin(userId: string) { }) } -export async function getSystemAdminAccess() { - const session = await getSession() +export async function getSystemAdminAccess(headersOverride?: Headers) { + const session = await getSession(headersOverride) const user = session?.user ?? null const userId = user?.id ?? null diff --git a/apps/tradinggoose/lib/auth-client.test.ts b/apps/tradinggoose/lib/auth-client.test.ts new file mode 100644 index 000000000..c1df48b23 --- /dev/null +++ b/apps/tradinggoose/lib/auth-client.test.ts @@ -0,0 +1,53 @@ +/** @vitest-environment jsdom */ + +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const mocks = vi.hoisted(() => ({ + createAuthClient: vi.fn((config: unknown) => ({ config })), +})) + +vi.mock('better-auth/react', () => ({ + createAuthClient: mocks.createAuthClient, +})) + +vi.mock('@better-auth/sso/client', () => ({ + ssoClient: vi.fn(() => 'sso-client'), +})) + +vi.mock('@better-auth/stripe/client', () => ({ + stripeClient: vi.fn(() => 'stripe-client'), +})) + +vi.mock('better-auth/client/plugins', () => ({ + customSessionClient: vi.fn(() => 'custom-session-client'), + emailOTPClient: vi.fn(() => 'email-otp-client'), + genericOAuthClient: vi.fn(() => 'generic-oauth-client'), + organizationClient: vi.fn(() => 'organization-client'), +})) + +vi.mock('@/lib/env', () => ({ + env: { + NEXT_PUBLIC_SSO_ENABLED: false, + }, +})) + +vi.mock('@/lib/session/session-context', () => ({ + SessionContext: null, +})) + +describe('auth client', () => { + beforeEach(() => { + vi.resetModules() + mocks.createAuthClient.mockClear() + }) + + it('uses the browser origin for same-origin auth requests', async () => { + await import('./auth-client') + + expect(mocks.createAuthClient).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: window.location.origin, + }) + ) + }) +}) diff --git a/apps/tradinggoose/lib/auth-client.ts b/apps/tradinggoose/lib/auth-client.ts index 615c3b86c..e1dc6a8fd 100644 --- a/apps/tradinggoose/lib/auth-client.ts +++ b/apps/tradinggoose/lib/auth-client.ts @@ -11,10 +11,9 @@ import { createAuthClient } from 'better-auth/react' import type { auth } from '@/lib/auth' import { env } from '@/lib/env' import { SessionContext, type SessionHookResult } from '@/lib/session/session-context' -import { getBaseUrl } from '@/lib/urls/utils' export const client = createAuthClient({ - baseURL: getBaseUrl(), + ...(typeof window === 'undefined' ? {} : { baseURL: window.location.origin }), plugins: [ emailOTPClient(), genericOAuthClient(), diff --git a/apps/tradinggoose/lib/auth.ts b/apps/tradinggoose/lib/auth.ts index e1f6e5c72..d3ebc5404 100644 --- a/apps/tradinggoose/lib/auth.ts +++ b/apps/tradinggoose/lib/auth.ts @@ -28,16 +28,12 @@ import { renderPasswordResetEmail, } from '@/components/emails/render-email' import { sendBillingTierWelcomeEmail } from '@/lib/billing' -import { localizeUrl } from '@/i18n/utils' import { authorizeSubscriptionReference } from '@/lib/billing/authorization' import { ensureDefaultUserSubscription, getEffectiveSubscription, } from '@/lib/billing/core/subscription' -import { - handleNewUser, - resetUserDefaultUsageToOnboardingAllowanceBalance, -} from '@/lib/billing/core/usage' +import { handleNewUser } from '@/lib/billing/core/usage' import { ensureOrganizationForOrganizationSubscription, syncSubscriptionUsageLimits, @@ -55,11 +51,11 @@ import { handleInvoicePaymentSucceeded, } from '@/lib/billing/webhooks/invoices' import { + handleStripeSubscriptionDeleted, handleSubscriptionCreated, - handleSubscriptionDeleted, } from '@/lib/billing/webhooks/subscription' -import { addVerifiedUserEmailToAudience, sendEmail } from '@/lib/email/mailer' import { resolveEmailLocale } from '@/lib/email/locale' +import { addVerifiedUserEmailToAudience, sendEmail } from '@/lib/email/mailer' import { quickValidateEmail } from '@/lib/email/validation' import { env, getEnv } from '@/lib/env' import { isEmailVerificationEnabled } from '@/lib/environment' @@ -86,6 +82,7 @@ import { } from '@/lib/system-services/stripe-runtime' import { getResolvedSystemSettings } from '@/lib/system-settings/service' import { getBaseUrl } from '@/lib/urls/utils' +import { localizeUrl } from '@/i18n/utils' import { resolveAlpacaTradingBaseUrl } from '@/providers/trading/alpaca/config' import { resolveTradierBaseUrl } from '@/providers/trading/tradier/client' import { SSO_TRUSTED_PROVIDERS } from './sso/consts' @@ -428,15 +425,14 @@ export const auth = betterAuth({ getBaseUrl(), ...(env.NEXT_PUBLIC_SOCKET_URL ? [env.NEXT_PUBLIC_SOCKET_URL] : []), ], + advanced: { + crossSubDomainCookies: { enabled: false }, + }, database: drizzleAdapter(db, { provider: 'pg', schema, }), session: { - cookieCache: { - enabled: true, - maxAge: 24 * 60 * 60, // 24 hours in seconds - }, expiresIn: 30 * 24 * 60 * 60, // 30 days (how long a session can last overall) updateAge: 24 * 60 * 60, // 24 hours (how often to refresh the expiry) freshAge: 60 * 60, // 1 hour (or set to 0 to disable completely) @@ -646,7 +642,6 @@ export const auth = betterAuth({ }), }, plugins: [ - nextCookies(), oneTimeToken({ expiresIn: 24 * 60 * 60, // 24 hours - Socket.IO handles connection persistence with heartbeats }), @@ -689,19 +684,10 @@ export const auth = betterAuth({ to: data.email, subject: getEmailSubject(data.type, locale), html, + text: `${getEmailSubject(data.type, locale)}\n\n${data.otp}`, emailType: 'transactional', }) - if (!result.success && result.message.includes('no email service configured')) { - logger.info('🔑 VERIFICATION CODE FOR LOGIN/SIGNUP', { - email: data.email, - otp: data.otp, - type: data.type, - validation: validation.checks, - }) - return - } - if (!result.success) { throw new Error(`Failed to send verification code: ${result.message}`) } @@ -1688,65 +1674,6 @@ export const auth = betterAuth({ }) } }, - onSubscriptionDeleted: async ({ - event, - stripeSubscription, - subscription, - }: { - event: Stripe.Event - stripeSubscription: Stripe.Subscription - subscription: any - }) => { - logger.info('[onSubscriptionDeleted] Subscription deleted', { - subscriptionId: subscription.id, - referenceType: subscription.referenceType, - referenceId: subscription.referenceId, - }) - - try { - await syncSubscriptionBillingTierFromStripeSubscription( - subscription.id, - stripeSubscription || (event.data.object as Stripe.Subscription | undefined) - ) - - const hydratedSubscription = await getHydratedSubscriptionById(subscription.id) - const subscriptionRecord = hydratedSubscription ?? { ...subscription, tier: null } - - await handleSubscriptionDeleted(subscriptionRecord) - - const { billingEnabled } = await getBillingGateState() - const nextSubscriptionRecord = - billingEnabled && subscriptionRecord.referenceType === 'user' - ? await ensureDefaultUserSubscription(subscriptionRecord.referenceId) - : subscriptionRecord - - if ( - nextSubscriptionRecord.referenceType === 'user' && - nextSubscriptionRecord.tier?.isDefault && - !nextSubscriptionRecord.stripeSubscriptionId - ) { - await resetUserDefaultUsageToOnboardingAllowanceBalance( - nextSubscriptionRecord.referenceId - ) - } - - await syncSubscriptionUsageLimits(nextSubscriptionRecord) - - logger.info('[onSubscriptionDeleted] Reconciled subscription usage limits', { - subscriptionId: subscription.id, - referenceType: subscription.referenceType, - referenceId: subscription.referenceId, - }) - } catch (error) { - logger.error('[onSubscriptionDeleted] Failed to handle subscription deletion', { - subscriptionId: subscription.id, - referenceType: subscription.referenceType, - referenceId: subscription.referenceId, - error, - }) - throw error - } - }, }, onEvent: async (event: Stripe.Event) => { logger.info('[onEvent] Received Stripe webhook', { @@ -1772,7 +1699,10 @@ export const auth = betterAuth({ await handleManualEnterpriseSubscription(event) break } - // Note: customer.subscription.deleted is handled by better-auth's onSubscriptionDeleted callback above + case 'customer.subscription.deleted': { + await handleStripeSubscriptionDeleted(event) + break + } default: logger.info('[onEvent] Ignoring unsupported webhook event', { eventId: event.id, @@ -1871,6 +1801,7 @@ export const auth = betterAuth({ }, }, }), + nextCookies(), ], onAPIError: { errorURL: '/error', @@ -1883,15 +1814,11 @@ export const auth = betterAuth({ }, }) -export async function getSession( - headersOverride?: Headers, - options?: { disableCookieCache?: boolean } -) { +export async function getSession(headersOverride?: Headers) { const hdrs = headersOverride ?? (await headers()) try { return await auth.api.getSession({ headers: hdrs, - ...(options ? { query: options } : {}), }) } catch (error) { logger.warn('Failed to fetch session', { error }) diff --git a/apps/tradinggoose/lib/auth/auth-error-copy.test.ts b/apps/tradinggoose/lib/auth/auth-error-copy.test.ts index 0237ba3b7..e777f7d0c 100644 --- a/apps/tradinggoose/lib/auth/auth-error-copy.test.ts +++ b/apps/tradinggoose/lib/auth/auth-error-copy.test.ts @@ -1,5 +1,12 @@ import { describe, expect, it } from 'vitest' -import { getAuthErrorContent, normalizeAuthErrorCode } from '@/lib/auth/auth-error-copy' +import { + getAuthErrorCallbackPath, + getAuthErrorContent, + isSessionRecoveryAuthError, + normalizeAuthErrorCallbackSegments, + normalizeAuthErrorCode, + resolveSsoAuthErrorMessage, +} from '@/lib/auth/auth-error-copy' import { REGISTRATION_DISABLED_REASON, REGISTRATION_WAITLIST_REASON, @@ -25,7 +32,7 @@ describe('getAuthErrorContent', () => { expect(code).toBe('UNABLE_TO_CREATE_USER') expect(content.title).toBe("We couldn't create your account") expect(content.primaryAction.href).toBe('/signup') - expect(content.secondaryAction.href).toBe('/login?reauth=1') + expect(content.secondaryAction.href).toBe('/login') }) it('falls back to the default auth error copy for unknown codes', () => { @@ -33,7 +40,59 @@ describe('getAuthErrorContent', () => { expect(code).toBe('TOTALLY_UNKNOWN_ERROR') expect(content.title).toBe(copy.auth.error.default.title) + expect(content.primaryAction.href).toBe('/login') + }) + + it.each([ + 'UNABLE_TO_CREATE_SESSION', + 'FAILED_TO_CREATE_SESSION', + 'FAILED_TO_GET_SESSION', + 'SESSION_EXPIRED', + ])('routes %s through reauth cleanup', (errorCode) => { + const { code, content } = getAuthErrorContent(copy, errorCode) + + expect(code).toBe(errorCode) expect(content.primaryAction.href).toBe('/login?reauth=1') + expect(content.secondaryAction.href).toBe('/') + }) + + it.each([ + 'UNABLE_TO_CREATE_SESSION', + 'FAILED_TO_CREATE_SESSION', + 'FAILED_TO_GET_SESSION', + 'SESSION_EXPIRED', + ])('classifies %s as a session recovery auth error', (errorCode) => { + expect(isSessionRecoveryAuthError(errorCode)).toBe(true) + }) + + it.each([ + ['ACCOUNT_NOT_FOUND', 'accountNotFound'], + ['SSO_FAILED', 'ssoFailed'], + ['INVALID_PROVIDER', 'providerNotConfigured'], + ['NO_PROVIDER_FOUND', 'providerNotConfigured'], + ] as const)('uses SSO-specific guidance for %s callback failures', (errorCode, copyKey) => { + const { content } = getAuthErrorContent(copy, errorCode) + + expect(resolveSsoAuthErrorMessage(copy, errorCode)).toBe(copy.auth.sso.errors[copyKey]) + expect(content.description).toBe(copy.auth.sso.errors[copyKey]) + expect(content.primaryAction.href).toBe('/login') + }) + + it('keeps the stored canonical destination on session recovery actions', () => { + const callbackPath = getAuthErrorCallbackPath('/invite/invitation-1?token=workspace-token') + const callback = normalizeAuthErrorCallbackSegments( + callbackPath?.split('/').filter(Boolean).slice(1) + ) + const { content } = getAuthErrorContent(copy, 'UNABLE_TO_CREATE_SESSION', null, callback) + + expect(callbackPath).toMatch(/^\/error\/callback\//) + expect(callback).toBe('/invite/invitation-1?token=workspace-token') + expect(getAuthErrorCallbackPath('/en/workspace')).toBeNull() + expect(getAuthErrorCallbackPath('https://evil.example/workspace')).toBeNull() + expect(normalizeAuthErrorCallbackSegments(['callback', 'not-valid-base64*'])).toBeNull() + expect(content.primaryAction.href).toBe( + '/login?reauth=1&callbackUrl=%2Finvite%2Finvitation-1%3Ftoken%3Dworkspace-token' + ) }) it('maps the waitlist registration reason to waitlist recovery copy', () => { @@ -51,6 +110,6 @@ describe('getAuthErrorContent', () => { expect(code).toBe('REGISTRATION_DISABLED') expect(content.title).toBe(copy.auth.error.groups.registrationDisabled.title) expect(content.description).toBe(copy.auth.error.groups.registrationDisabled.description) - expect(content.primaryAction.href).toBe('/login?reauth=1') + expect(content.primaryAction.href).toBe('/login') }) }) diff --git a/apps/tradinggoose/lib/auth/auth-error-copy.ts b/apps/tradinggoose/lib/auth/auth-error-copy.ts index 52f5736da..085134656 100644 --- a/apps/tradinggoose/lib/auth/auth-error-copy.ts +++ b/apps/tradinggoose/lib/auth/auth-error-copy.ts @@ -3,6 +3,7 @@ import { REGISTRATION_WAITLIST_REASON, } from '@/lib/registration/shared' import type { PublicCopy } from '@/i18n/public-copy' +import { normalizeCallbackUrl } from '@/i18n/utils' export interface AuthErrorAction { href: string @@ -18,11 +19,14 @@ export interface AuthErrorContent { type AuthErrorGroupKey = keyof PublicCopy['auth']['error']['groups'] -const LOGIN_HREF = '/login?reauth=1' +const LOGIN_HREF = '/login' +const REAUTH_LOGIN_HREF = '/login?reauth=1' const SIGNUP_HREF = '/signup' const HOME_HREF = '/' const VERIFY_HREF = '/verify' const WAITLIST_HREF = '/waitlist' +const AUTH_ERROR_CALLBACK_SEGMENT = 'callback' +type SsoErrorCopyKey = keyof PublicCopy['auth']['sso']['errors'] const AUTH_ERROR_GROUP_BY_CODE: Partial> = { UNABLE_TO_CREATE_USER: 'accountCreation', @@ -37,11 +41,16 @@ const AUTH_ERROR_GROUP_BY_CODE: Partial> = { CALLBACK_URL_REQUIRED: 'invalidCallback', INVALID_TOKEN: 'invalidToken', TOKEN_EXPIRED: 'expiredToken', + UNABLE_TO_CREATE_SESSION: 'sessionCreation', FAILED_TO_CREATE_SESSION: 'sessionCreation', FAILED_TO_GET_SESSION: 'sessionRestore', SESSION_EXPIRED: 'sessionExpired', + ACCOUNT_NOT_FOUND: 'userInfo', + SSO_FAILED: 'userInfo', FAILED_TO_GET_USER_INFO: 'userInfo', USER_EMAIL_NOT_FOUND: 'userInfo', + INVALID_PROVIDER: 'providerUnavailable', + NO_PROVIDER_FOUND: 'providerUnavailable', PROVIDER_NOT_FOUND: 'providerUnavailable', SOCIAL_ACCOUNT_ALREADY_LINKED: 'linkedAccount', LINKED_ACCOUNT_ALREADY_EXISTS: 'linkedAccount', @@ -52,6 +61,18 @@ const AUTH_ERROR_GROUP_BY_CODE: Partial> = { EMAIL_PASSWORD_DISABLED: 'providerUnavailable', } +const SSO_ERROR_COPY_BY_CODE: Partial> = { + ACCOUNT_NOT_FOUND: 'accountNotFound', + SSO_FAILED: 'ssoFailed', + INVALID_PROVIDER: 'providerNotConfigured', + NO_PROVIDER_FOUND: 'providerNotConfigured', + INVALID_EMAIL_DOMAIN: 'invalidEmailDomain', + NETWORK_ERROR: 'network', + RATE_LIMIT: 'rateLimit', + TOO_MANY_REQUESTS: 'rateLimit', + SSO_DISABLED: 'ssoDisabled', +} + export function normalizeAuthErrorCode(error: string | null | undefined) { if (!error) { return null @@ -66,14 +87,55 @@ export function normalizeAuthErrorCode(error: string | null | undefined) { return normalized || null } -function getAuthErrorActionCopy(localeCopy: PublicCopy) { +export function getAuthErrorCallbackPath(value: string | null | undefined) { + const callback = normalizeCallbackUrl(value) + if (!callback) { + return null + } + + const encoded = btoa(encodeURIComponent(callback)) + .replaceAll('+', '-') + .replaceAll('/', '_') + .replace(/=+$/g, '') + return `/error/${AUTH_ERROR_CALLBACK_SEGMENT}/${encoded}` +} + +export function normalizeAuthErrorCallbackSegments(value: string[] | undefined) { + if (!value || value.length !== 2 || value[0] !== AUTH_ERROR_CALLBACK_SEGMENT) { + return null + } + + try { + const base64 = value[1].replaceAll('-', '+').replaceAll('_', '/') + const paddedBase64 = base64.padEnd(Math.ceil(base64.length / 4) * 4, '=') + + return normalizeCallbackUrl(decodeURIComponent(atob(paddedBase64))) + } catch { + return null + } +} + +function appendCallbackUrl(href: string, callbackUrl: string | null | undefined) { + if (!callbackUrl) { + return href + } + + const separator = href.includes('?') ? '&' : '?' + return `${href}${separator}callbackUrl=${encodeURIComponent(callbackUrl)}` +} + +function getAuthErrorActionCopy(localeCopy: PublicCopy, callbackUrl?: string | null) { return { login: { - href: LOGIN_HREF, + href: appendCallbackUrl(LOGIN_HREF, callbackUrl), + label: localeCopy.auth.common.backToLogin, + }, + reauthLogin: { + href: appendCallbackUrl(REAUTH_LOGIN_HREF, callbackUrl), label: localeCopy.auth.common.backToLogin, }, signup: { - href: SIGNUP_HREF, + href: appendCallbackUrl(SIGNUP_HREF, callbackUrl), label: localeCopy.auth.common.backToSignup, }, home: { @@ -99,16 +161,40 @@ function resolveAuthErrorGroupKey(errorCode: string | null): AuthErrorGroupKey | return AUTH_ERROR_GROUP_BY_CODE[errorCode] ?? null } +function isSessionRecoveryGroup(groupKey: AuthErrorGroupKey) { + return ( + groupKey === 'sessionCreation' || groupKey === 'sessionRestore' || groupKey === 'sessionExpired' + ) +} + +export function isSessionRecoveryAuthError( + error: string | null | undefined, + errorDescription?: string | null +) { + const groupKey = resolveAuthErrorGroup(error, errorDescription) + return Boolean(groupKey && isSessionRecoveryGroup(groupKey)) +} + +export function resolveSsoAuthErrorMessage(copy: PublicCopy, error: string | null | undefined) { + const code = normalizeAuthErrorCode(error) + const copyKey = code ? SSO_ERROR_COPY_BY_CODE[code] : null + + return copyKey ? copy.auth.sso.errors[copyKey] : null +} + export function getAuthErrorContent( copy: PublicCopy, error: string | null | undefined, - errorDescription?: string | null + errorDescription?: string | null, + callbackUrl?: string | null ) { const code = normalizeAuthErrorCode(error) const descriptionCode = normalizeAuthErrorCode(errorDescription) const groupKey = resolveAuthErrorGroupKey(code) ?? resolveAuthErrorGroupKey(descriptionCode) - const actionCopy = getAuthErrorActionCopy(copy) + const actionCopy = getAuthErrorActionCopy(copy, callbackUrl) const normalizedDescription = errorDescription?.trim() || null + const ssoErrorDescription = + resolveSsoAuthErrorMessage(copy, code) ?? resolveSsoAuthErrorMessage(copy, descriptionCode) if (groupKey) { const group = copy.auth.error.groups[groupKey] @@ -119,7 +205,9 @@ export function getAuthErrorContent( ? actionCopy.verify : groupKey === 'waitlistLimited' ? actionCopy.waitlist - : actionCopy.login + : isSessionRecoveryGroup(groupKey) + ? actionCopy.reauthLogin + : actionCopy.login const secondaryAction = groupKey === 'accountCreation' || groupKey === 'emailVerification' || @@ -131,9 +219,10 @@ export function getAuthErrorContent( const content: AuthErrorContent = { title: group.title, description: - normalizedDescription && descriptionCode && !resolveAuthErrorGroupKey(descriptionCode) + ssoErrorDescription ?? + (normalizedDescription && descriptionCode && !resolveAuthErrorGroupKey(descriptionCode) ? normalizedDescription - : group.description, + : group.description), primaryAction, secondaryAction, } @@ -148,7 +237,8 @@ export function getAuthErrorContent( code, content: { title: copy.auth.error.default.title, - description: normalizedDescription ?? copy.auth.error.default.description, + description: + ssoErrorDescription ?? normalizedDescription ?? copy.auth.error.default.description, primaryAction: actionCopy.login, secondaryAction: actionCopy.home, }, diff --git a/apps/tradinggoose/lib/auth/auth-error-handler.test.ts b/apps/tradinggoose/lib/auth/auth-error-handler.test.ts new file mode 100644 index 000000000..0c96bcbbb --- /dev/null +++ b/apps/tradinggoose/lib/auth/auth-error-handler.test.ts @@ -0,0 +1,39 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' + +const replaceMock = vi.fn() + +function stubWindow(url: string) { + const parsedUrl = new URL(url) + + vi.stubGlobal('window', { + location: { + href: parsedUrl.href, + pathname: parsedUrl.pathname, + search: parsedUrl.search, + hash: parsedUrl.hash, + replace: replaceMock, + }, + sessionStorage: { + getItem: vi.fn(() => '0'), + setItem: vi.fn(), + }, + }) +} + +describe('handleAuthError', () => { + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('routes login-page auth recovery through login reauth cleanup', async () => { + vi.resetModules() + vi.clearAllMocks() + stubWindow('https://app.tradinggoose.ai/login?callbackUrl=%2Fworkspace#credentials') + + const { handleAuthError } = await import('./auth-error-handler') + + await handleAuthError('workspace-permissions', '/login') + + expect(replaceMock).toHaveBeenCalledWith('/login?callbackUrl=%2Fworkspace&reauth=1#credentials') + }) +}) diff --git a/apps/tradinggoose/lib/auth/auth-error-handler.ts b/apps/tradinggoose/lib/auth/auth-error-handler.ts index 11279fffd..041f4ee98 100644 --- a/apps/tradinggoose/lib/auth/auth-error-handler.ts +++ b/apps/tradinggoose/lib/auth/auth-error-handler.ts @@ -1,40 +1,17 @@ 'use client' import { createLogger } from '@/lib/logs/console/logger' -import { localizeUrl, stripLocaleFromPathname } from '@/i18n/utils' +import { isLocaleCode, normalizeCallbackUrl } from '@/i18n/utils' const logger = createLogger('AuthErrorHandler') let isHandlingAuthError = false const LAST_RECOVERY_KEY = 'tradinggoose-auth-recovery-ts' -const AUTH_COOKIE_NAMES = [ - 'better-auth.session_token', - 'better-auth.session_data', - 'better-auth.dont_remember', - '__Secure-better-auth.session_token', - '__Secure-better-auth.session_data', - '__Secure-better-auth.dont_remember', -] - -function deleteBrowserAuthCookies() { - if (typeof document === 'undefined' || typeof window === 'undefined') return - - const baseDomain = window.location.hostname - const domains = [undefined, baseDomain, `.${baseDomain}`].filter(Boolean) - - AUTH_COOKIE_NAMES.forEach((name) => { - domains.forEach((domain) => { - document.cookie = `${name}=; Max-Age=0; Path=/; SameSite=Lax${ - domain ? `; Domain=${domain}` : '' - }` - }) - }) -} function shouldRateLimitRecovery(reason?: string) { if (typeof window === 'undefined') return false // Avoid infinite reload loops on the login page by rate limiting recovery attempts - const isOnLoginPage = stripLocaleFromPathname(window.location.pathname).pathname === '/login' + const isOnLoginPage = isLoginPathname(window.location.pathname) if (!isOnLoginPage) return false const now = Date.now() @@ -48,48 +25,39 @@ function shouldRateLimitRecovery(reason?: string) { return false } -async function safeServerSignOut() { - try { - await fetch('/api/auth/sign-out', { - method: 'POST', - credentials: 'include', - headers: { 'cache-control': 'no-store' }, - }) - } catch (error) { - logger.warn('Fallback sign-out failed', { error }) - } +function isLoginPathname(pathname: string) { + const segments = pathname.split('/').filter(Boolean) + return segments[0] === 'login' || (segments[1] === 'login' && isLocaleCode(segments[0])) } /** - * Clears the current auth session when we detect an unauthorized response. - * This removes any stale tokens/cookies and forces a navigation to login so - * the user can authenticate again. + * Routes stale auth state to the login reauth flow. */ -export async function handleAuthError(reason?: string) { +export async function handleAuthError(reason: string, callbackPathname: string) { if (typeof window === 'undefined') return if (isHandlingAuthError) return if (shouldRateLimitRecovery(reason)) return + const canonicalCallbackPathname = normalizeCallbackUrl(callbackPathname) + if (!canonicalCallbackPathname) { + throw new Error('Expected a canonical auth recovery callback pathname') + } + isHandlingAuthError = true - deleteBrowserAuthCookies() - await safeServerSignOut() - if (stripLocaleFromPathname(window.location.pathname).pathname === '/login') { - logger.warn('Cleared stale auth state on login page', { reason }) - isHandlingAuthError = false + if (isLoginPathname(window.location.pathname)) { + const loginUrl = new URL(window.location.href) + loginUrl.searchParams.set('reauth', '1') + logger.warn('Routing login page through reauth cleanup', { reason }) + window.location.replace(`${loginUrl.pathname}${loginUrl.search}${loginUrl.hash}`) return } - const { locale, pathname } = stripLocaleFromPathname(window.location.pathname) - const callbackUrl = `${pathname}${window.location.search}` + const callbackUrl = `${canonicalCallbackPathname}${window.location.search}${window.location.hash}` + const loginPath = `/login?reauth=1&callbackUrl=${encodeURIComponent(callbackUrl)}` + logger.warn('Handling authentication error', { reason, callbackUrl }) - window.location.replace( - localizeUrl( - window.location.origin, - locale, - `/login?reauth=1&callbackUrl=${encodeURIComponent(callbackUrl)}` - ) - ) + window.location.replace(loginPath) } export function isAuthErrorStatus(status?: number | null): boolean { diff --git a/apps/tradinggoose/lib/auth/redirect-urls.ts b/apps/tradinggoose/lib/auth/redirect-urls.ts index af273814c..d41106969 100644 --- a/apps/tradinggoose/lib/auth/redirect-urls.ts +++ b/apps/tradinggoose/lib/auth/redirect-urls.ts @@ -1,6 +1,7 @@ 'use client' import { useLocale } from 'next-intl' +import { getAuthErrorCallbackPath } from '@/lib/auth/auth-error-copy' import { getBaseUrl } from '@/lib/urls/utils' import { localizeUrl, normalizeCallbackUrl } from '@/i18n/utils' @@ -12,8 +13,8 @@ export function useAuthRedirectUrls() { const canonicalFallbackPath = normalizeCallbackUrl(fallbackPath) ?? '/workspace' return normalizeCallbackUrl(callbackPath) ?? canonicalFallbackPath }, - providerErrorPath(path: string) { - return normalizeCallbackUrl(path) ?? '/sso' + providerErrorPath(callbackPath: string | null | undefined) { + return getAuthErrorCallbackPath(callbackPath) ?? '/error' }, passwordResetUrl() { return localizeUrl(getBaseUrl(), locale, '/reset-password') diff --git a/apps/tradinggoose/lib/billing/core/billing.ts b/apps/tradinggoose/lib/billing/core/billing.ts index de0e631d3..1fd2dd64c 100644 --- a/apps/tradinggoose/lib/billing/core/billing.ts +++ b/apps/tradinggoose/lib/billing/core/billing.ts @@ -205,13 +205,21 @@ export async function calculateSubscriptionOverage(sub: { }) } } else if (sub.tier) { - const usage = await getUserUsageData(sub.referenceId) + const [stats] = await db + .select({ + currentPeriodCost: userStats.currentPeriodCost, + totalCost: userStats.totalCost, + }) + .from(userStats) + .where(eq(userStats.userId, sub.referenceId)) + .limit(1) + const currentUsage = parseDecimal(stats?.currentPeriodCost ?? stats?.totalCost) const { usageAllowance } = getBillingTierPricing(sub.tier) - totalOverage = Math.max(0, usage.currentUsage - usageAllowance) + totalOverage = Math.max(0, currentUsage - usageAllowance) logger.info('Calculated individual-tier overage', { subscriptionId: sub.id, - totalIndividualUsage: usage.currentUsage, + totalIndividualUsage: currentUsage, usageAllowance, totalOverage, }) diff --git a/apps/tradinggoose/lib/billing/core/subscription.test.ts b/apps/tradinggoose/lib/billing/core/subscription.test.ts index c158517d9..e185a6d26 100644 --- a/apps/tradinggoose/lib/billing/core/subscription.test.ts +++ b/apps/tradinggoose/lib/billing/core/subscription.test.ts @@ -47,6 +47,7 @@ vi.mock('@tradinggoose/db/schema', () => ({ id: 'subscription.id', referenceType: 'subscription.referenceType', referenceId: 'subscription.referenceId', + stripeSubscriptionId: 'subscription.stripeSubscriptionId', status: 'subscription.status', }, user: { @@ -201,6 +202,30 @@ describe('subscription billing helpers', () => { ).rejects.toBeInstanceOf(MissingBillingSubscriptionError) }) + it('returns the exact local subscription for a Stripe subscription id', async () => { + const row = { + id: 'sub_123', + stripeSubscriptionId: 'stripe_sub_123', + tier: { id: 'tier_default' }, + } + mockDb.select.mockImplementationOnce(() => createSelectQueryMock([row], 'limit')) + + const { getSubscriptionByStripeSubscriptionId } = await import('./subscription') + + await expect(getSubscriptionByStripeSubscriptionId('stripe_sub_123')).resolves.toBe(row) + expect(mockEq).toHaveBeenCalledWith('subscription.stripeSubscriptionId', 'stripe_sub_123') + expect(mockHydrateSubscriptionsWithTiers).toHaveBeenCalledWith([row]) + }) + + it('returns null for an untracked Stripe subscription id', async () => { + mockDb.select.mockImplementationOnce(() => createSelectQueryMock([], 'limit')) + + const { getSubscriptionByStripeSubscriptionId } = await import('./subscription') + + await expect(getSubscriptionByStripeSubscriptionId('stripe_sub_missing')).resolves.toBe(null) + expect(mockHydrateSubscriptionsWithTiers).toHaveBeenCalledWith([]) + }) + it('seeds onboarding allowance into user stats on billing-enable backfill', async () => { const insertCalls: Array<{ values: Record diff --git a/apps/tradinggoose/lib/billing/core/subscription.ts b/apps/tradinggoose/lib/billing/core/subscription.ts index 1eef95fec..888ba0ed0 100644 --- a/apps/tradinggoose/lib/billing/core/subscription.ts +++ b/apps/tradinggoose/lib/billing/core/subscription.ts @@ -119,6 +119,19 @@ export async function requireActiveSubscriptionForReference( return activeSubscription } +export async function getSubscriptionByStripeSubscriptionId( + stripeSubscriptionId: string +): Promise { + const rows = await db + .select() + .from(subscription) + .where(eq(subscription.stripeSubscriptionId, stripeSubscriptionId)) + .limit(1) + + const hydratedSubscriptions = await hydrateSubscriptionsWithTiers(rows) + return hydratedSubscriptions[0] ?? null +} + export async function getEffectiveSubscription( userId: string ): Promise { diff --git a/apps/tradinggoose/lib/billing/core/usage.ts b/apps/tradinggoose/lib/billing/core/usage.ts index dd797da05..9ab7f47d9 100644 --- a/apps/tradinggoose/lib/billing/core/usage.ts +++ b/apps/tradinggoose/lib/billing/core/usage.ts @@ -20,8 +20,8 @@ import { toBillingTierSummary, } from '@/lib/billing/tiers' import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types' -import { sendEmail } from '@/lib/email/mailer' import { resolveEmailLocale } from '@/lib/email/locale' +import { sendEmail } from '@/lib/email/mailer' import { getEmailPreferences } from '@/lib/email/unsubscribe' import { createLogger } from '@/lib/logs/console/logger' import { getBaseUrl } from '@/lib/urls/utils' @@ -121,11 +121,12 @@ export async function decrementGrantedOnboardingAllowanceByCurrentPeriodUsage( } export async function resetUserDefaultUsageToOnboardingAllowanceBalance( - userId: string + userId: string, + dbClient: Pick = db ): Promise { const [{ onboardingAllowanceUsd }, statsRecords] = await Promise.all([ getResolvedBillingSettings(), - db + dbClient .select({ grantedOnboardingAllowanceUsd: userStats.grantedOnboardingAllowanceUsd }) .from(userStats) .where(eq(userStats.userId, userId)) @@ -141,7 +142,7 @@ export async function resetUserDefaultUsageToOnboardingAllowanceBalance( parseNonNegativeBillingAmount(statsRecords[0].grantedOnboardingAllowanceUsd) ?? 0 const usedOnboardingAllowance = Math.max(onboardingAllowance - grantedAllowance, 0) - await db + await dbClient .update(userStats) .set({ customUsageLimit: onboardingAllowance.toString(), diff --git a/apps/tradinggoose/lib/billing/webhooks/enterprise.ts b/apps/tradinggoose/lib/billing/webhooks/enterprise.ts index 4957f9958..8a1581f44 100644 --- a/apps/tradinggoose/lib/billing/webhooks/enterprise.ts +++ b/apps/tradinggoose/lib/billing/webhooks/enterprise.ts @@ -12,8 +12,8 @@ import { requireBillingTierById, } from '@/lib/billing/tiers' import { resolveBillingTierForPersistence } from '@/lib/billing/tiers/persistence' -import { sendEmail } from '@/lib/email/mailer' import { resolveEmailLocale } from '@/lib/email/locale' +import { sendEmail } from '@/lib/email/mailer' import { createLogger } from '@/lib/logs/console/logger' import type { EnterpriseSubscriptionMetadata } from '../types' @@ -137,16 +137,12 @@ export async function handleManualEnterpriseSubscription(event: Stripe.Event) { metadata: metadataJson, } - const existing = await db - .select({ id: subscription.id }) - .from(subscription) - .where(eq(subscription.stripeSubscriptionId, stripeSubscription.id)) - .limit(1) - - if (existing.length > 0) { - await db - .update(subscription) - .set({ + const [persistedSubscription] = await db + .insert(subscription) + .values(subscriptionRow) + .onConflictDoUpdate({ + target: subscription.stripeSubscriptionId, + set: { plan: subscriptionRow.plan, billingTierId: subscriptionRow.billingTierId, referenceType: subscriptionRow.referenceType, @@ -160,11 +156,10 @@ export async function handleManualEnterpriseSubscription(event: Stripe.Event) { trialStart: subscriptionRow.trialStart, trialEnd: subscriptionRow.trialEnd, metadata: subscriptionRow.metadata, - }) - .where(eq(subscription.stripeSubscriptionId, stripeSubscription.id)) - } else { - await db.insert(subscription).values(subscriptionRow) - } + }, + }) + .returning({ id: subscription.id }) + const subscriptionId = persistedSubscription?.id || subscriptionRow.id if (billingTierRecord.usageScope === 'pooled') { const organizationUsageLimit = getTierIncludedUsageLimit(billingTierRecord) || monthlyPrice @@ -193,7 +188,7 @@ export async function handleManualEnterpriseSubscription(event: Stripe.Event) { } logger.info('[subscription.created] Upserted enterprise subscription', { - subscriptionId: existing[0]?.id || subscriptionRow.id, + subscriptionId, referenceType: subscriptionRow.referenceType, referenceId: subscriptionRow.referenceId, subscriptionKey: subscriptionRow.plan, @@ -230,7 +225,11 @@ export async function handleManualEnterpriseSubscription(event: Stripe.Event) { const org = orgDetails[0] const locale = await resolveEmailLocale({ userId: user.id, email: user.email }) - const html = await renderEnterpriseSubscriptionEmail(user.name || user.email, user.email, locale) + const html = await renderEnterpriseSubscriptionEmail( + user.name || user.email, + user.email, + locale + ) const emailResult = await sendEmail({ to: user.email, diff --git a/apps/tradinggoose/lib/billing/webhooks/invoices.ts b/apps/tradinggoose/lib/billing/webhooks/invoices.ts index 4635c6fca..7251c1fa1 100644 --- a/apps/tradinggoose/lib/billing/webhooks/invoices.ts +++ b/apps/tradinggoose/lib/billing/webhooks/invoices.ts @@ -3,7 +3,6 @@ import { member, organizationBillingLedger, organizationMemberBillingLedger, - subscription as subscriptionTable, user, userStats, } from '@tradinggoose/db/schema' @@ -12,15 +11,15 @@ import type Stripe from 'stripe' import { getEmailSubject, renderPaymentFailedEmail } from '@/components/emails/render-email' import { calculateSubscriptionOverage } from '@/lib/billing/core/billing' import { getOrganizationBillingLedger } from '@/lib/billing/core/organization' +import { getSubscriptionByStripeSubscriptionId } from '@/lib/billing/core/subscription' import { requireStripeClient } from '@/lib/billing/stripe-client' import { type BillingTierRecord, - hydrateSubscriptionsWithTiers, isOrganizationSubscription, usesIndividualBillingLedger, } from '@/lib/billing/tiers' -import { sendEmail } from '@/lib/email/mailer' import { resolveEmailLocale } from '@/lib/email/locale' +import { sendEmail } from '@/lib/email/mailer' import { quickValidateEmail } from '@/lib/email/validation' import { createLogger } from '@/lib/logs/console/logger' import { getBaseUrl } from '@/lib/urls/utils' @@ -44,17 +43,6 @@ type SubscriptionUsageScope = { tier?: BillingTierRecord | null } -async function getHydratedSubscriptionByStripeSubscriptionId(stripeSubscriptionId: string) { - const records = await db - .select() - .from(subscriptionTable) - .where(eq(subscriptionTable.stripeSubscriptionId, stripeSubscriptionId)) - .limit(1) - - const hydratedSubscriptions = await hydrateSubscriptionsWithTiers(records) - return hydratedSubscriptions[0] ?? null -} - /** * Create a billing portal URL for a Stripe customer */ @@ -354,7 +342,7 @@ export async function handleInvoicePaymentSucceeded(event: Stripe.Event) { }) return } - const sub = await getHydratedSubscriptionByStripeSubscriptionId(stripeSubscriptionId) + const sub = await getSubscriptionByStripeSubscriptionId(stripeSubscriptionId) if (!sub) return // Only reset usage here if the tenant was previously blocked; otherwise invoice.created already reset it @@ -457,7 +445,7 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) { stripeSubscriptionId, }) - const sub = await getHydratedSubscriptionByStripeSubscriptionId(stripeSubscriptionId) + const sub = await getSubscriptionByStripeSubscriptionId(stripeSubscriptionId) if (sub) { if (isOrganizationSubscription(sub)) { @@ -530,7 +518,7 @@ export async function handleInvoiceFinalized(event: Stripe.Event) { } if (invoice.billing_reason && invoice.billing_reason !== 'subscription_cycle') return - const sub = await getHydratedSubscriptionByStripeSubscriptionId(stripeSubscriptionId) + const sub = await getSubscriptionByStripeSubscriptionId(stripeSubscriptionId) if (!sub) return const stripe = requireStripeClient() diff --git a/apps/tradinggoose/lib/billing/webhooks/subscription.test.ts b/apps/tradinggoose/lib/billing/webhooks/subscription.test.ts index d5cf60551..79a7e7827 100644 --- a/apps/tradinggoose/lib/billing/webhooks/subscription.test.ts +++ b/apps/tradinggoose/lib/billing/webhooks/subscription.test.ts @@ -6,25 +6,46 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' const { mockAnd, + mockCalculateSubscriptionOverage, mockDb, mockDecrementGrantedOnboardingAllowanceByCurrentPeriodUsage, + mockEnsureDefaultUserSubscription, mockEq, + mockGetBilledOverageForSubscription, + mockGetResolvedBillingSettings, + mockGetSubscriptionByStripeSubscriptionId, mockIsPaidBillingTier, mockNe, + mockRequireStripeClient, mockResetUsageForSubscription, + mockResetUserDefaultUsageToOnboardingAllowanceBalance, + mockSyncSubscriptionBillingTierFromStripeSubscription, + mockSyncSubscriptionUsageLimits, } = vi.hoisted(() => ({ mockAnd: vi.fn(), + mockCalculateSubscriptionOverage: vi.fn(), mockDb: { select: vi.fn(), + transaction: vi.fn(), + update: vi.fn(), }, mockDecrementGrantedOnboardingAllowanceByCurrentPeriodUsage: vi.fn(), + mockEnsureDefaultUserSubscription: vi.fn(), mockEq: vi.fn((field: unknown, value: unknown) => ({ field, value })), + mockGetBilledOverageForSubscription: vi.fn(), + mockGetResolvedBillingSettings: vi.fn(), + mockGetSubscriptionByStripeSubscriptionId: vi.fn(), mockIsPaidBillingTier: vi.fn(), mockNe: vi.fn((field: unknown, value: unknown) => ({ field, value })), + mockRequireStripeClient: vi.fn(), mockResetUsageForSubscription: vi.fn(), + mockResetUserDefaultUsageToOnboardingAllowanceBalance: vi.fn(), + mockSyncSubscriptionBillingTierFromStripeSubscription: vi.fn(), + mockSyncSubscriptionUsageLimits: vi.fn(), })) let otherActiveSubscriptions: Array> = [] +let updateCalls: Array> = [] vi.mock('@tradinggoose/db', () => ({ db: mockDb, @@ -34,6 +55,7 @@ vi.mock('@tradinggoose/db/schema', () => ({ subscription: { referenceType: 'subscription.referenceType', referenceId: 'subscription.referenceId', + stripeSubscriptionId: 'subscription.stripeSubscriptionId', status: 'subscription.status', id: 'subscription.id', }, @@ -48,23 +70,43 @@ vi.mock('drizzle-orm', () => ({ vi.mock('@/lib/billing/core/usage', () => ({ decrementGrantedOnboardingAllowanceByCurrentPeriodUsage: mockDecrementGrantedOnboardingAllowanceByCurrentPeriodUsage, + resetUserDefaultUsageToOnboardingAllowanceBalance: + mockResetUserDefaultUsageToOnboardingAllowanceBalance, +})) + +vi.mock('@/lib/billing/core/subscription', () => ({ + ensureDefaultUserSubscription: mockEnsureDefaultUserSubscription, + getSubscriptionByStripeSubscriptionId: mockGetSubscriptionByStripeSubscriptionId, +})) + +vi.mock('@/lib/billing/settings', () => ({ + getResolvedBillingSettings: mockGetResolvedBillingSettings, })) vi.mock('@/lib/billing/tiers', () => ({ isPaidBillingTier: mockIsPaidBillingTier, })) +vi.mock('@/lib/billing/tiers/persistence', () => ({ + syncSubscriptionBillingTierFromStripeSubscription: + mockSyncSubscriptionBillingTierFromStripeSubscription, +})) + +vi.mock('@/lib/billing/organization', () => ({ + syncSubscriptionUsageLimits: mockSyncSubscriptionUsageLimits, +})) + vi.mock('@/lib/billing/webhooks/invoices', () => ({ - getBilledOverageForSubscription: vi.fn(), + getBilledOverageForSubscription: mockGetBilledOverageForSubscription, resetUsageForSubscription: mockResetUsageForSubscription, })) vi.mock('@/lib/billing/core/billing', () => ({ - calculateSubscriptionOverage: vi.fn(), + calculateSubscriptionOverage: mockCalculateSubscriptionOverage, })) vi.mock('@/lib/billing/stripe-client', () => ({ - requireStripeClient: vi.fn(), + requireStripeClient: mockRequireStripeClient, })) vi.mock('@/lib/logs/console/logger', () => ({ @@ -75,22 +117,102 @@ vi.mock('@/lib/logs/console/logger', () => ({ }), })) -function createSelectQueryMock(result: unknown) { +function createSelectQueryMock(result: unknown, terminal: 'where' | 'limit' = 'where') { const query = { from: vi.fn(() => query), - where: vi.fn(() => Promise.resolve(result)), + where: vi.fn(() => (terminal === 'where' ? Promise.resolve(result) : query)), + limit: vi.fn(() => Promise.resolve(result)), + } + + return query +} + +function createUpdateQueryMock() { + const query = { + set: vi.fn((values: Record) => { + updateCalls.push(values) + return query + }), + where: vi.fn(() => Promise.resolve()), } return query } +function createDeletedStripeSubscription( + overrides: Partial<{ + id: string + metadata: Record + }> = {} +) { + return { + id: overrides.id ?? 'sub_stripe_123', + cancel_at_period_end: true, + metadata: overrides.metadata ?? { + referenceId: 'user-1', + subscriptionId: 'metadata_is_not_identity', + userId: 'user-1', + }, + items: { + data: [ + { + current_period_start: 1778910924, + current_period_end: 1781589324, + }, + ], + }, + } +} + +function createDeletedSubscriptionEvent(stripeSubscription = createDeletedStripeSubscription()) { + return { + id: 'evt_deleted', + data: { + object: stripeSubscription, + }, + } +} + +function createDefaultSubscription( + overrides: Partial<{ + id: string + metadata: Record + referenceId: string + referenceType: 'user' | 'organization' + status: string + stripeSubscriptionId: string | null + tier: Record + }> = {} +) { + return { + id: overrides.id ?? 'sub_default_user-1', + referenceType: overrides.referenceType ?? 'user', + referenceId: overrides.referenceId ?? 'user-1', + status: overrides.status ?? 'active', + stripeSubscriptionId: overrides.stripeSubscriptionId ?? null, + metadata: overrides.metadata ?? { source: 'default-tier' }, + tier: overrides.tier ?? { + id: 'tier_default', + isDefault: true, + displayName: 'Pay As You Go', + }, + } +} + describe('handleSubscriptionCreated', () => { beforeEach(() => { vi.resetModules() vi.clearAllMocks() + updateCalls = [] otherActiveSubscriptions = [] mockDb.select.mockImplementation(() => createSelectQueryMock(otherActiveSubscriptions)) + mockDb.transaction.mockImplementation(async (callback) => callback(mockDb)) + mockDb.update.mockImplementation(() => createUpdateQueryMock()) + mockCalculateSubscriptionOverage.mockResolvedValue(0) + mockGetBilledOverageForSubscription.mockResolvedValue(0) + mockGetResolvedBillingSettings.mockResolvedValue({ billingEnabled: true }) + mockRequireStripeClient.mockReturnValue({}) mockIsPaidBillingTier.mockReturnValue(false) }) @@ -155,3 +277,189 @@ describe('handleSubscriptionCreated', () => { expect(mockResetUsageForSubscription).not.toHaveBeenCalled() }) }) + +describe('handleStripeSubscriptionDeleted', () => { + beforeEach(() => { + vi.resetModules() + vi.clearAllMocks() + + updateCalls = [] + mockDb.select.mockImplementation(() => createSelectQueryMock([], 'limit')) + mockDb.transaction.mockImplementation(async (callback) => callback(mockDb)) + mockDb.update.mockImplementation(() => createUpdateQueryMock()) + mockCalculateSubscriptionOverage.mockResolvedValue(0) + mockGetBilledOverageForSubscription.mockResolvedValue(0) + mockGetResolvedBillingSettings.mockResolvedValue({ billingEnabled: true }) + mockGetSubscriptionByStripeSubscriptionId.mockResolvedValue(null) + mockRequireStripeClient.mockReturnValue({}) + mockSyncSubscriptionBillingTierFromStripeSubscription.mockResolvedValue(undefined) + mockSyncSubscriptionUsageLimits.mockResolvedValue(undefined) + mockResetUserDefaultUsageToOnboardingAllowanceBalance.mockResolvedValue(undefined) + mockResetUsageForSubscription.mockResolvedValue(undefined) + }) + + it('settles a deleted Stripe PAYG subscription by Stripe subscription id before restoring default PAYG', async () => { + const stripeBackedSubscription = createDefaultSubscription({ + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + const defaultSubscription = createDefaultSubscription() + mockGetSubscriptionByStripeSubscriptionId + .mockResolvedValueOnce(stripeBackedSubscription) + .mockResolvedValueOnce(stripeBackedSubscription) + mockEnsureDefaultUserSubscription.mockResolvedValue(defaultSubscription) + + const { handleStripeSubscriptionDeleted } = await import('./subscription') + await handleStripeSubscriptionDeleted(createDeletedSubscriptionEvent() as any) + + expect(mockGetSubscriptionByStripeSubscriptionId).toHaveBeenCalledWith('sub_stripe_123') + expect(mockEq).not.toHaveBeenCalledWith('subscription.id', 'metadata_is_not_identity') + expect(mockSyncSubscriptionBillingTierFromStripeSubscription).toHaveBeenCalledWith( + 'sub_default_user-1', + expect.objectContaining({ id: 'sub_stripe_123' }) + ) + expect(mockCalculateSubscriptionOverage).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'sub_default_user-1', + referenceId: 'user-1', + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + ) + expect(mockEnsureDefaultUserSubscription).toHaveBeenCalledWith('user-1', mockDb) + expect(mockResetUserDefaultUsageToOnboardingAllowanceBalance).toHaveBeenCalledWith( + 'user-1', + mockDb + ) + expect(mockSyncSubscriptionUsageLimits).toHaveBeenCalledWith(defaultSubscription) + expect(mockDb.update.mock.invocationCallOrder[0]).toBeLessThan( + mockCalculateSubscriptionOverage.mock.invocationCallOrder[0] + ) + expect(mockCalculateSubscriptionOverage.mock.invocationCallOrder[0]).toBeLessThan( + mockEnsureDefaultUserSubscription.mock.invocationCallOrder[0] + ) + expect(mockEnsureDefaultUserSubscription.mock.invocationCallOrder[0]).toBeLessThan( + mockSyncSubscriptionUsageLimits.mock.invocationCallOrder[0] + ) + expect(updateCalls).toContainEqual( + expect.objectContaining({ + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + ) + }) + + it('skips deleted Stripe subscription events without a local subscription row', async () => { + const { handleStripeSubscriptionDeleted } = await import('./subscription') + await handleStripeSubscriptionDeleted(createDeletedSubscriptionEvent() as any) + + expect(mockSyncSubscriptionBillingTierFromStripeSubscription).not.toHaveBeenCalled() + expect(mockCalculateSubscriptionOverage).not.toHaveBeenCalled() + expect(mockResetUsageForSubscription).not.toHaveBeenCalled() + expect(mockEnsureDefaultUserSubscription).not.toHaveBeenCalled() + expect(mockResetUserDefaultUsageToOnboardingAllowanceBalance).not.toHaveBeenCalled() + expect(mockSyncSubscriptionUsageLimits).not.toHaveBeenCalled() + expect(updateCalls).toEqual([]) + }) + + it('rejects when a matched subscription disappears during deletion settlement', async () => { + const stripeBackedSubscription = createDefaultSubscription({ + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + mockGetSubscriptionByStripeSubscriptionId + .mockResolvedValueOnce(stripeBackedSubscription) + .mockResolvedValueOnce(null) + + const { handleStripeSubscriptionDeleted } = await import('./subscription') + await expect( + handleStripeSubscriptionDeleted(createDeletedSubscriptionEvent() as any) + ).rejects.toThrow( + 'Local subscription disappeared while settling deleted Stripe subscription sub_stripe_123' + ) + + expect(mockSyncSubscriptionBillingTierFromStripeSubscription).toHaveBeenCalled() + expect(updateCalls).toContainEqual( + expect.objectContaining({ + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + ) + expect(mockCalculateSubscriptionOverage).not.toHaveBeenCalled() + expect(mockEnsureDefaultUserSubscription).not.toHaveBeenCalled() + expect(mockSyncSubscriptionUsageLimits).not.toHaveBeenCalled() + }) + + it('does not reset onboarding usage when another personal subscription remains entitled', async () => { + const canceledSubscription = createDefaultSubscription({ + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + }) + const replacementSubscription = createDefaultSubscription({ + id: 'sub_replacement', + stripeSubscriptionId: 'sub_stripe_replacement', + tier: { + id: 'tier_pro', + isDefault: false, + displayName: 'Pro', + }, + }) + mockGetSubscriptionByStripeSubscriptionId + .mockResolvedValueOnce(canceledSubscription) + .mockResolvedValueOnce(canceledSubscription) + mockEnsureDefaultUserSubscription.mockResolvedValue(replacementSubscription) + + const { handleStripeSubscriptionDeleted } = await import('./subscription') + await handleStripeSubscriptionDeleted(createDeletedSubscriptionEvent() as any) + + expect(mockEnsureDefaultUserSubscription).toHaveBeenCalledWith('user-1', mockDb) + expect(mockResetUserDefaultUsageToOnboardingAllowanceBalance).not.toHaveBeenCalled() + expect(mockSyncSubscriptionUsageLimits).toHaveBeenCalledWith(replacementSubscription) + expect(mockCalculateSubscriptionOverage.mock.invocationCallOrder[0]).toBeLessThan( + mockEnsureDefaultUserSubscription.mock.invocationCallOrder[0] + ) + }) + + it('syncs usage limits for organization members after deleting an organization subscription', async () => { + const organizationSubscription = createDefaultSubscription({ + id: 'sub_org', + referenceType: 'organization', + referenceId: 'org-1', + status: 'canceled', + stripeSubscriptionId: 'sub_stripe_123', + tier: { + id: 'tier_team', + isDefault: false, + ownerType: 'organization', + displayName: 'Team', + }, + }) + mockGetSubscriptionByStripeSubscriptionId + .mockResolvedValueOnce(organizationSubscription) + .mockResolvedValueOnce(organizationSubscription) + + const { handleStripeSubscriptionDeleted } = await import('./subscription') + await handleStripeSubscriptionDeleted( + createDeletedSubscriptionEvent( + createDeletedStripeSubscription({ + metadata: { + referenceType: 'organization', + referenceId: 'org-1', + }, + }) + ) as any + ) + + expect(mockEnsureDefaultUserSubscription).not.toHaveBeenCalled() + expect(mockGetSubscriptionByStripeSubscriptionId).toHaveBeenCalledWith('sub_stripe_123') + expect(mockResetUserDefaultUsageToOnboardingAllowanceBalance).not.toHaveBeenCalled() + expect(mockSyncSubscriptionUsageLimits).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'sub_org', + referenceType: 'organization', + referenceId: 'org-1', + status: 'canceled', + }) + ) + }) +}) diff --git a/apps/tradinggoose/lib/billing/webhooks/subscription.ts b/apps/tradinggoose/lib/billing/webhooks/subscription.ts index 4919d4f4a..db9bdd26f 100644 --- a/apps/tradinggoose/lib/billing/webhooks/subscription.ts +++ b/apps/tradinggoose/lib/billing/webhooks/subscription.ts @@ -1,10 +1,21 @@ import { db } from '@tradinggoose/db' import { subscription } from '@tradinggoose/db/schema' import { and, eq, ne } from 'drizzle-orm' +import type Stripe from 'stripe' import { calculateSubscriptionOverage } from '@/lib/billing/core/billing' -import { decrementGrantedOnboardingAllowanceByCurrentPeriodUsage } from '@/lib/billing/core/usage' +import { + ensureDefaultUserSubscription, + getSubscriptionByStripeSubscriptionId, +} from '@/lib/billing/core/subscription' +import { + decrementGrantedOnboardingAllowanceByCurrentPeriodUsage, + resetUserDefaultUsageToOnboardingAllowanceBalance, +} from '@/lib/billing/core/usage' +import { syncSubscriptionUsageLimits } from '@/lib/billing/organization' +import { getResolvedBillingSettings } from '@/lib/billing/settings' import { requireStripeClient } from '@/lib/billing/stripe-client' import { type BillingTierRecord, isPaidBillingTier } from '@/lib/billing/tiers' +import { syncSubscriptionBillingTierFromStripeSubscription } from '@/lib/billing/tiers/persistence' import { getBilledOverageForSubscription, resetUsageForSubscription, @@ -23,6 +34,18 @@ type TieredSubscriptionLifecycleRecord = { tier?: BillingTierRecord | null } +function getStripeSubscriptionPeriod(stripeSubscription: Stripe.Subscription) { + const item = stripeSubscription.items.data[0] + if (!item) { + return {} + } + + return { + periodStart: new Date(item.current_period_start * 1000), + periodEnd: new Date(item.current_period_end * 1000), + } +} + /** * Handle new subscription creation - reset usage if transitioning from free/default to subscribed */ @@ -182,8 +205,8 @@ export async function handleSubscriptionDeleted(subscription: TieredSubscription { idempotencyKey: itemIdemKey } ) - // Finalize the invoice (this will trigger payment collection) - if (overageInvoice.id) { + // Finalize only draft invoices; duplicate webhook deliveries can return the prior invoice. + if (overageInvoice.id && overageInvoice.status === 'draft') { await stripe.invoices.finalizeInvoice(overageInvoice.id) } @@ -215,12 +238,8 @@ export async function handleSubscriptionDeleted(subscription: TieredSubscription }) } - // Reset usage after billing await resetUsageForSubscription(subscription) - // Note: better-auth's Stripe plugin already updates status to 'canceled' before calling this handler - // We only need to handle overage billing and usage reset - logger.info('Successfully processed subscription cancellation', { subscriptionId: subscription.id, stripeSubscriptionId, @@ -235,3 +254,81 @@ export async function handleSubscriptionDeleted(subscription: TieredSubscription throw error } } + +export async function handleStripeSubscriptionDeleted(event: Stripe.Event) { + const stripeSubscription = event.data.object as Stripe.Subscription + const stripeSubscriptionId = stripeSubscription.id + + const resolvedSubscription = await getSubscriptionByStripeSubscriptionId(stripeSubscriptionId) + + if (!resolvedSubscription) { + logger.info('Deleted Stripe subscription has no local subscription row; skipping settlement', { + eventId: event.id, + stripeSubscriptionId, + }) + return + } + + await db + .update(subscription) + .set({ + ...getStripeSubscriptionPeriod(stripeSubscription), + stripeSubscriptionId, + status: 'canceled', + cancelAtPeriodEnd: stripeSubscription.cancel_at_period_end, + }) + .where(eq(subscription.stripeSubscriptionId, stripeSubscriptionId)) + + await syncSubscriptionBillingTierFromStripeSubscription( + resolvedSubscription.id, + stripeSubscription + ) + + const hydratedSubscription = await getSubscriptionByStripeSubscriptionId(stripeSubscriptionId) + if (!hydratedSubscription) { + throw new Error( + `Local subscription disappeared while settling deleted Stripe subscription ${stripeSubscriptionId}` + ) + } + + const subscriptionToSettle = { + ...hydratedSubscription, + stripeSubscriptionId, + status: 'canceled', + } + let subscriptionForUsageLimits: TieredSubscriptionLifecycleRecord = subscriptionToSettle + + await handleSubscriptionDeleted(subscriptionToSettle) + + if (subscriptionToSettle.referenceType === 'user') { + const { billingEnabled } = await getResolvedBillingSettings() + + if (billingEnabled) { + subscriptionForUsageLimits = await db.transaction(async (tx) => { + const nextSubscription = await ensureDefaultUserSubscription( + subscriptionToSettle.referenceId, + tx + ) + + if (nextSubscription.tier?.isDefault && !nextSubscription.stripeSubscriptionId) { + await resetUserDefaultUsageToOnboardingAllowanceBalance( + subscriptionToSettle.referenceId, + tx + ) + } + + return nextSubscription + }) + } + } + + await syncSubscriptionUsageLimits(subscriptionForUsageLimits) + + logger.info('Settled deleted Stripe subscription', { + eventId: event.id, + subscriptionId: subscriptionToSettle.id, + referenceType: subscriptionToSettle.referenceType, + referenceId: subscriptionToSettle.referenceId, + stripeSubscriptionId, + }) +} diff --git a/apps/tradinggoose/lib/block-path-calculator.ts b/apps/tradinggoose/lib/block-path-calculator.ts index 8355fa6c3..a7fa5f37e 100644 --- a/apps/tradinggoose/lib/block-path-calculator.ts +++ b/apps/tradinggoose/lib/block-path-calculator.ts @@ -120,7 +120,7 @@ export class BlockPathCalculator { } names.push(accessibleBlockId) - if (block.metadata?.id === 'input_trigger') { + if (block.metadata?.category === 'triggers' || block.config.params.triggerMode === true) { names.push('start') } } diff --git a/apps/tradinggoose/lib/branding/metadata.ts b/apps/tradinggoose/lib/branding/metadata.ts index 94bfb8b20..620adade9 100644 --- a/apps/tradinggoose/lib/branding/metadata.ts +++ b/apps/tradinggoose/lib/branding/metadata.ts @@ -1,7 +1,8 @@ import type { Metadata } from 'next' import { getBrandConfig } from '@/lib/branding/branding' import { getPublicCopy } from '@/i18n/public-copy' -import { defaultLocale, getOpenGraphLocale, type LocaleCode, SITE_BASE_URL } from '@/i18n/utils' +import { defaultLocale, getOpenGraphLocale, type LocaleCode } from '@/i18n/utils' +import { getBaseUrl } from '@/lib/urls/utils' export const DEFAULT_META_DESCRIPTION = 'Open-source LLM trading platform. Connect data providers, write custom indicators in PineTS, and trigger AI agent workflows on live signals.' @@ -16,6 +17,7 @@ export function generateBrandedMetadata( const brand = getBrandConfig() const copy = getPublicCopy(locale) const landingMeta = copy.meta.landing + const siteBaseUrl = getBaseUrl() const defaultTitle = brand.name @@ -32,7 +34,7 @@ export function generateBrandedMetadata( referrer: 'origin-when-cross-origin', creator: brand.name, publisher: brand.name, - metadataBase: new URL(SITE_BASE_URL), + metadataBase: new URL(siteBaseUrl), robots: { index: true, follow: true, @@ -110,6 +112,8 @@ export function generateBrandedMetadata( * Generate static structured data for SEO */ export function generateStructuredData() { + const siteBaseUrl = getBaseUrl() + return { '@context': 'https://schema.org', '@type': 'SoftwareApplication', @@ -117,7 +121,7 @@ export function generateStructuredData() { alternateName: ['TradingGoose Studio', 'TradingGoose.ai'], description: 'TradingGoose (also known as TradingGoose Studio) is an open-source visual workflow platform for technical LLM-driven trading, maintained at github.com/TradingGoose/TradingGoose-Studio. Connect your own market data providers, write custom indicators in PineTS, monitor live prices, and route signals into AI agent workflows that trigger trades, alerts, portfolio rebalancing, or any action you define. Not affiliated with the older TradingGoose multi-agent LLM research framework.', - url: SITE_BASE_URL, + url: siteBaseUrl, sameAs: [ 'https://github.com/TradingGoose/TradingGoose-Studio', 'https://docs.tradinggoose.ai', @@ -136,7 +140,7 @@ export function generateStructuredData() { '@type': 'Organization', name: 'TradingGoose Studio', alternateName: 'TradingGoose', - url: SITE_BASE_URL, + url: siteBaseUrl, sameAs: [ 'https://github.com/TradingGoose/TradingGoose-Studio', 'https://discord.gg/wavf5JWhuT', diff --git a/apps/tradinggoose/lib/copilot/access-policy.ts b/apps/tradinggoose/lib/copilot/access-policy.ts index 6eb379a84..441c60fd2 100644 --- a/apps/tradinggoose/lib/copilot/access-policy.ts +++ b/apps/tradinggoose/lib/copilot/access-policy.ts @@ -1,6 +1,6 @@ export type CopilotAccessLevel = 'limited' | 'full' -export function shouldAutoExecuteTool(accessLevel: CopilotAccessLevel): boolean { +function shouldAutoExecuteTool(accessLevel: CopilotAccessLevel): boolean { return accessLevel === 'full' } diff --git a/apps/tradinggoose/lib/copilot/agent/utils.test.ts b/apps/tradinggoose/lib/copilot/agent/utils.test.ts index 3e883dcfa..d0f7d9078 100644 --- a/apps/tradinggoose/lib/copilot/agent/utils.test.ts +++ b/apps/tradinggoose/lib/copilot/agent/utils.test.ts @@ -48,6 +48,7 @@ describe('requestCopilotTitle', () => { const title = await requestCopilotTitle({ message: 'Build a momentum screener with RSI filters', + userId: 'user-1', model: 'gpt-5.4', provider: 'openai', }) @@ -60,6 +61,7 @@ describe('requestCopilotTitle', () => { expect(init.headers).toEqual({ 'Content-Type': 'application/json', 'x-api-key': 'test-copilot-key', + 'x-copilot-user-id': 'user-1', }) const payload = JSON.parse(init.body) @@ -98,6 +100,7 @@ describe('requestCopilotTitle', () => { const title = await requestCopilotTitle({ message: 'Review the current skill implementation', + userId: 'user-1', model: 'claude-opus-4.6', }) diff --git a/apps/tradinggoose/lib/copilot/agent/utils.ts b/apps/tradinggoose/lib/copilot/agent/utils.ts index 2d674e9c5..a6db1bd48 100644 --- a/apps/tradinggoose/lib/copilot/agent/utils.ts +++ b/apps/tradinggoose/lib/copilot/agent/utils.ts @@ -14,10 +14,12 @@ const logger = createLogger('CopilotTitle') */ export async function requestCopilotTitle({ message, + userId, model, provider, }: { message: string + userId: string model?: string provider?: ProviderId }): Promise { @@ -43,6 +45,9 @@ export async function requestCopilotTitle({ }, ], }, + headers: { + 'x-copilot-user-id': userId, + }, }) if (!response.ok) { const errorText = await response.text().catch(() => '') diff --git a/apps/tradinggoose/lib/copilot/completion-usage-billing.ts b/apps/tradinggoose/lib/copilot/completion-usage-billing.ts new file mode 100644 index 000000000..9dee047e5 --- /dev/null +++ b/apps/tradinggoose/lib/copilot/completion-usage-billing.ts @@ -0,0 +1,363 @@ +import { sql } from 'drizzle-orm' +import { z } from 'zod' +import { getPersonalEffectiveSubscription } from '@/lib/billing/core/subscription' +import { isBillingEnabledForRuntime } from '@/lib/billing/settings' +import { getTierCopilotCostMultiplier } from '@/lib/billing/tiers' +import { accrueUserUsageCost } from '@/lib/billing/usage-accrual' +import { resolveWorkflowBillingContext } from '@/lib/billing/workspace-billing' +import { commitCopilotUsageReservation } from '@/lib/copilot/usage-reservations' +import { isHosted } from '@/lib/environment' +import { createLogger } from '@/lib/logs/console/logger' +import { hasProcessedMessage, markMessageAsProcessed } from '@/lib/redis' +import { calculateCost } from '@/providers/ai/utils' + +const BILLING_EVENT_TTL_SECONDS = 60 * 60 * 24 * 30 +const DEFAULT_ESTIMATED_RESERVATION_USD = 1 +const logger = createLogger('CopilotUsageAPI') + +const CompletionUsageReportSchema = z.object({ + kind: z.literal('completion'), + model: z.string().min(1, 'model is required'), + usage: z.unknown(), + remoteModel: z.string().nullable().optional(), + completionId: z.string().min(1, 'completionId is required'), + workflowId: z.string().nullable().optional(), +}) + +interface TokenMetrics { + promptTokens: number + completionTokens: number + totalTokens: number +} + +export type UsageBillingResult = + | { + billed: true + duplicate: false + cost: number + tokens: number + model: string + } + | { + billed: false + duplicate: true + } + | { + billed: false + duplicate?: false + reason: 'billing_disabled' | 'no_token_metrics' | 'zero_cost' | 'ledger_not_found' + } + +function readNumber(value: unknown): number | undefined { + if (typeof value === 'number' && Number.isFinite(value)) { + return value + } + if (typeof value === 'string') { + const parsed = Number.parseFloat(value) + return Number.isFinite(parsed) ? parsed : undefined + } + return undefined +} + +function pickNumber(source: any, keys: string[]): number | undefined { + if (!source || typeof source !== 'object') return undefined + for (const key of keys) { + const candidate = readNumber(source[key]) + if (candidate !== undefined) { + return candidate + } + } + return undefined +} + +function extractTokenMetrics(usage: any): TokenMetrics | null { + const sources = [usage, usage?.tokenUsage, usage?.tokens, usage?.usageDetails] + + let promptTokens: number | undefined + let completionTokens: number | undefined + let totalTokens: number | undefined + + for (const src of sources) { + if (promptTokens === undefined) { + promptTokens = pickNumber(src, [ + 'prompt_tokens', + 'promptTokens', + 'input_tokens', + 'inputTokens', + 'prompt', + ]) + } + if (completionTokens === undefined) { + completionTokens = pickNumber(src, [ + 'completion_tokens', + 'completionTokens', + 'output_tokens', + 'outputTokens', + 'completion', + ]) + } + if (totalTokens === undefined) { + totalTokens = pickNumber(src, [ + 'total_tokens', + 'totalTokens', + 'tokens', + 'token_count', + 'total', + ]) + } + } + + if (totalTokens === undefined) { + totalTokens = readNumber(usage?.tokensUsed) ?? readNumber(usage?.usage) + } + + if (completionTokens === undefined) { + completionTokens = 0 + } + + if (totalTokens !== undefined && promptTokens === undefined) { + promptTokens = totalTokens - completionTokens + } + + if (promptTokens === undefined || totalTokens === undefined) { + return null + } + + const normalizedPrompt = Math.max(0, Math.round(promptTokens)) + const normalizedCompletion = Math.max(0, Math.round(completionTokens ?? 0)) + const normalizedTotal = Math.max( + 0, + Math.round(totalTokens ?? normalizedPrompt + normalizedCompletion) + ) + + if (normalizedTotal <= 0 || (normalizedPrompt === 0 && normalizedCompletion === 0)) { + return null + } + + return { + promptTokens: normalizedPrompt, + completionTokens: normalizedCompletion, + totalTokens: normalizedTotal, + } +} + +async function resolveEffectiveCopilotTier(params: { + userId: string + workflowId?: string +}): Promise<{ + effectiveTier: any + billingContext: Awaited> | null +}> { + const billingContext = params.workflowId + ? await resolveWorkflowBillingContext({ + workflowId: params.workflowId, + actorUserId: params.userId, + }) + : null + const effectiveTier = params.workflowId + ? (billingContext?.subscription?.tier ?? null) + : ((await getPersonalEffectiveSubscription(params.userId))?.tier ?? null) + + if (!effectiveTier) { + throw new Error( + params.workflowId + ? `No active workflow subscription tier found for billed copilot usage on workflow ${params.workflowId}` + : `No active personal subscription tier found for billed copilot usage for user ${params.userId}` + ) + } + + return { + effectiveTier, + billingContext, + } +} + +async function calculateCopilotCostUsd(params: { + userId: string + workflowId?: string + billingModel: string + promptTokens: number + completionTokens: number + fallbackUsd?: number +}): Promise<{ + costUsd: number + normalizedModel: string + billingContext: Awaited> | null +}> { + const normalizedModel = params.billingModel.trim().toLowerCase() + const costResult = calculateCost( + normalizedModel, + params.promptTokens, + params.completionTokens, + false + ) + const { effectiveTier, billingContext } = await resolveEffectiveCopilotTier({ + userId: params.userId, + workflowId: params.workflowId, + }) + const rawCostUsd = Number(costResult.total || 0) * getTierCopilotCostMultiplier(effectiveTier) + + return { + costUsd: rawCostUsd > 0 ? rawCostUsd : (params.fallbackUsd ?? 0), + normalizedModel, + billingContext, + } +} + +export async function calculateCopilotReservationUsdFromEstimate(params: { + userId: string + workflowId?: string + model: string + estimatedPromptTokens: number + reservedCompletionTokens: number +}): Promise { + const { costUsd } = await calculateCopilotCostUsd({ + userId: params.userId, + workflowId: params.workflowId, + billingModel: params.model, + promptTokens: params.estimatedPromptTokens, + completionTokens: params.reservedCompletionTokens, + fallbackUsd: DEFAULT_ESTIMATED_RESERVATION_USD, + }) + + return costUsd +} + +export async function recordCopilotCompletionUsage(params: { + userId: string + workflowId?: string + usage: any + billingModel: string + billingKeyId?: string | null +}): Promise { + const metrics = extractTokenMetrics(params.usage) + if (!metrics) { + logger.info('Skipping copilot billing - no token metrics available', { + billingKeyPrefix: 'copilot-completion-billing', + billingKeyId: params.billingKeyId, + reason: 'copilot_completion_usage', + }) + return { billed: false, reason: 'no_token_metrics' } + } + + const billingKey = params.billingKeyId + ? `copilot-completion-billing:${params.billingKeyId}` + : null + if (billingKey && (await hasProcessedMessage(billingKey))) { + logger.info('Copilot billing already processed', { + billingKey, + reason: 'copilot_completion_usage', + }) + return { billed: false, duplicate: true } + } + + const { + costUsd: costToAdd, + normalizedModel, + billingContext, + } = await calculateCopilotCostUsd({ + userId: params.userId, + workflowId: params.workflowId, + billingModel: params.billingModel, + promptTokens: metrics.promptTokens, + completionTokens: metrics.completionTokens, + }) + if (costToAdd <= 0) { + logger.info('Skipping copilot billing - calculated cost is zero', { + userId: params.userId, + workflowId: params.workflowId, + billingKeyId: params.billingKeyId, + model: normalizedModel, + reason: 'copilot_completion_usage', + }) + return { billed: false, reason: 'zero_cost' } + } + + const extraUpdates: Record = { + totalCopilotCost: sql`total_copilot_cost + ${costToAdd}`, + currentPeriodCopilotCost: sql`current_period_copilot_cost + ${costToAdd}`, + totalCopilotCalls: sql`total_copilot_calls + 1`, + } + + if (metrics.totalTokens > 0) { + extraUpdates.totalCopilotTokens = sql`total_copilot_tokens + ${metrics.totalTokens}` + } + + const didAccrue = await accrueUserUsageCost({ + userId: params.userId, + workflowId: params.workflowId, + cost: costToAdd, + extraUpdates, + reason: 'copilot_completion_usage', + }) + + if (!didAccrue) { + logger.warn('Copilot billing skipped - ledger record not found', { + userId: params.userId, + workflowId: params.workflowId, + billingKeyId: params.billingKeyId, + reason: 'copilot_completion_usage', + }) + return { billed: false, reason: 'ledger_not_found' } + } + + if (billingKey) { + await markMessageAsProcessed(billingKey, BILLING_EVENT_TTL_SECONDS) + } + + logger.info('Copilot billing recorded', { + userId: params.userId, + billingUserId: billingContext?.billingUserId ?? params.userId, + workflowId: params.workflowId, + billingKeyId: params.billingKeyId, + cost: costToAdd, + tokens: metrics.totalTokens, + model: normalizedModel, + reason: 'copilot_completion_usage', + }) + + return { + billed: true, + duplicate: false, + cost: costToAdd, + tokens: metrics.totalTokens, + model: normalizedModel, + } +} + +export async function mirrorLocalCopilotCompletionUsageReports(params: { + userId: string + reports: unknown[] +}): Promise { + if (isHosted || params.reports.length === 0) { + return + } + + if (!(await isBillingEnabledForRuntime())) { + return + } + + for (const report of params.reports) { + try { + const payload = CompletionUsageReportSchema.parse(report) + const billing = await commitCopilotUsageReservation({ + userId: params.userId, + workflowId: payload.workflowId ?? undefined, + operation: () => + recordCopilotCompletionUsage({ + userId: params.userId, + workflowId: payload.workflowId ?? undefined, + usage: payload.usage, + billingModel: payload.model, + billingKeyId: payload.completionId, + }), + }) + + if (!billing.billed && !billing.duplicate && billing.reason !== 'zero_cost') { + logger.warn('Local Copilot completion usage mirror skipped', { reason: billing.reason }) + } + } catch (error) { + logger.warn('Failed to mirror local Copilot completion usage report', { error }) + } + } +} diff --git a/apps/tradinggoose/lib/copilot/entity-documents.ts b/apps/tradinggoose/lib/copilot/entity-documents.ts index 055e2f920..a4042fcc3 100644 --- a/apps/tradinggoose/lib/copilot/entity-documents.ts +++ b/apps/tradinggoose/lib/copilot/entity-documents.ts @@ -35,7 +35,6 @@ const CustomToolDocumentSchema = z.object({ const IndicatorDocumentSchema = z.object({ name: z.string(), - color: z.string(), pineCode: z.string(), inputMeta: z.record(z.unknown()).nullable(), }) @@ -87,7 +86,6 @@ function normalizeEntityFields( case 'indicator': return { name: typeof source.name === 'string' ? source.name : '', - color: typeof source.color === 'string' ? source.color : '', pineCode: typeof source.pineCode === 'string' ? source.pineCode : '', inputMeta: source.inputMeta && diff --git a/apps/tradinggoose/lib/copilot/inline-tool-call.test.tsx b/apps/tradinggoose/lib/copilot/inline-tool-call.test.tsx index 725ac0aab..a6ccee09c 100644 --- a/apps/tradinggoose/lib/copilot/inline-tool-call.test.tsx +++ b/apps/tradinggoose/lib/copilot/inline-tool-call.test.tsx @@ -253,7 +253,7 @@ describe('InlineToolCall', () => { ) }) - it('shows review controls for staged workflow edits in full access without generic Allow', async () => { + it('shows review controls for already-staged workflow edits in full access', async () => { const toolCallId = 'tool-workflow-review' mockUseCopilotStoreState.accessLevel = 'full' mockGetToolInterruptDisplays.mockReturnValue({ @@ -299,7 +299,7 @@ describe('InlineToolCall', () => { expect(container.textContent).not.toContain('Allow') }) - it('renders entity review diffs and controls from staged tool results in full access', async () => { + it('renders entity review diffs with controls for already-staged reviews in full access', async () => { mockUseCopilotStoreState.accessLevel = 'full' mockGetToolInterruptDisplays.mockReturnValue({ accept: { text: 'Accept changes' }, diff --git a/apps/tradinggoose/lib/copilot/registry.ts b/apps/tradinggoose/lib/copilot/registry.ts index a6e61df58..4fde5f7d3 100644 --- a/apps/tradinggoose/lib/copilot/registry.ts +++ b/apps/tradinggoose/lib/copilot/registry.ts @@ -6,7 +6,10 @@ import { SKILL_DOCUMENT_FORMAT, } from '@/lib/copilot/entity-documents' import { MONITOR_DOCUMENT_FORMAT } from '@/lib/copilot/monitor/monitor-documents' -import { TG_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/document-format' +import { + TG_MERMAID_DOCUMENT_FORMAT, + WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT, +} from '@/lib/workflows/document-format' import { WORKFLOW_VARIABLE_TYPES, type WorkflowVariableType } from '@/lib/workflows/value-types' import { GetAgentAccessoryCatalogInput, @@ -148,7 +151,6 @@ const CreateWorkflowArgs = z .object({ name: z.string().trim().min(1).optional(), description: z.string().optional(), - color: z.string().optional(), folderId: z.string().nullable().optional(), workspaceId: RequiredId.optional(), }) @@ -167,14 +169,19 @@ const EditWorkflowArgs = z .string() .min(1) .describe( - 'Complete raw `tg-mermaid-v1` Mermaid document for the entire workflow, not a partial patch. Preserve the canonical `%% TG_WORKFLOW`, `%% TG_BLOCK`, and `%% TG_EDGE` metadata returned by `read_workflow`; Studio validates that structure. Use this only for graph or topology changes such as adding, removing, reconnecting, or replacing blocks, loops, parallels, or condition branches.' + 'Minimal Mermaid flowchart for the entire workflow graph, not a partial patch. Include flowchart direction, existing block ids as node/subgraph ids, new block `id:` and `type:` labels, subgraph nesting, and edge arrows. Do not include `%% TG_*` metadata, subBlocks, outputs, enabled, positions, or full block metadata. Existing block ids are stable identities: their type and details are preserved by id, and supplied labels must match current block names. This tool cannot replace an existing block or change its type; new ids create new blocks with generated positions. Use edit_workflow_block for block internals.' + ), + removedBlockIds: z + .array(z.string().trim().min(1)) + .optional() + .describe( + 'Existing block root ids intentionally removed from the workflow graph. Removing a loop or parallel root removes its descendants.' ), - documentFormat: z.literal(TG_MERMAID_DOCUMENT_FORMAT).optional(), entityId: RequiredId, }) .strict() .describe( - "Full workflow document replacement tool. Do not use this to rename one existing block or patch one block's `enabled` or `subBlocks`; use `edit_workflow_block` instead." + "Full workflow topology rewrite tool using minimal Mermaid. Do not use this to replace an existing block, rename one existing block, or patch one block's `enabled` or `subBlocks`; use `edit_workflow_block` instead." ) const EditWorkflowBlockArgs = z @@ -305,6 +312,11 @@ export const ToolArgSchemas = { run_workflow: z.object({ entityId: RequiredId, + triggerBlockId: z + .string() + .trim() + .min(1) + .describe('Exact trigger block id from `read_workflow.workflowSummary.blocks`.'), workflow_input: z.union([z.string(), z.record(z.any())]).optional(), }), @@ -591,6 +603,11 @@ const WorkflowDocumentEnvelope = WorkflowTargetEnvelope.extend({ entityDocument: z.string(), }) +const WorkflowGraphDocumentEnvelope = WorkflowTargetEnvelope.extend({ + documentFormat: z.literal(WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT), + entityDocument: z.string(), +}) + const WorkflowSummaryResult = z.object({ blocks: z.array( z.object({ @@ -640,7 +657,6 @@ const GenericEntityListEntry = z.object({ entityDescription: z.string().optional(), entityTitle: z.string().optional(), entityFunctionName: z.string().optional(), - entityColor: z.string().optional(), entityTransport: z.string().optional(), entityUrl: z.string().optional(), entityEnabled: z.boolean().optional(), @@ -656,7 +672,6 @@ const GenericEntityListResult = z.object({ const IndicatorListEntry = z.object({ name: z.string(), source: z.enum(['default', 'custom']), - color: z.string().optional(), editable: z.boolean(), callableInFunctionBlock: z.boolean(), inputTitles: z.array(z.string()).optional(), @@ -768,7 +783,7 @@ const WorkflowPreviewEdge = z.object({ targetHandle: z.string().optional(), }) -const BuildOrEditWorkflowResult = WorkflowDocumentEnvelope.extend({ +const WorkflowMutationResultShape = { workflowState: z.unknown().optional(), preview: z .object({ @@ -790,7 +805,10 @@ const BuildOrEditWorkflowResult = WorkflowDocumentEnvelope.extend({ edgesCount: z.number(), }) .optional(), -}) +} + +const EditWorkflowResult = WorkflowGraphDocumentEnvelope.extend(WorkflowMutationResultShape) +const EditWorkflowBlockResult = WorkflowDocumentEnvelope.extend(WorkflowMutationResultShape) const ExecutionEntry = z.object({ id: z.string(), @@ -841,8 +859,8 @@ export const ToolResultSchemas = { message: z.string().optional(), }), - edit_workflow: BuildOrEditWorkflowResult, - edit_workflow_block: BuildOrEditWorkflowResult, + edit_workflow: EditWorkflowResult, + edit_workflow_block: EditWorkflowBlockResult, rename_workflow: WorkflowMutationResult, run_workflow: z.object({ executionId: z.string().optional(), diff --git a/apps/tradinggoose/lib/copilot/review-sessions/permissions.test.ts b/apps/tradinggoose/lib/copilot/review-sessions/permissions.test.ts index 423eccfb5..c7f7c36dd 100644 --- a/apps/tradinggoose/lib/copilot/review-sessions/permissions.test.ts +++ b/apps/tradinggoose/lib/copilot/review-sessions/permissions.test.ts @@ -63,7 +63,9 @@ import { loadReviewSessionForUser, loadReviewSessionForUserByConversationId, verifyReviewTargetAccess, + verifyWorkflowAccess, } from '@/lib/copilot/review-sessions/permissions' +import { readWorkflowAccessContext } from '@/lib/workflows/utils' import { resolveEntityWorkspaceId } from '@/lib/yjs/server/entity-loaders' type MockChain = { @@ -77,6 +79,7 @@ type MockChain = { const mockDb = db as unknown as { select: ReturnType } const mockResolveEntityWorkspaceId = vi.mocked(resolveEntityWorkspaceId) +const mockReadWorkflowAccessContext = vi.mocked(readWorkflowAccessContext) function createMockChain(finalResult: any): MockChain { const chain: any = {} @@ -272,4 +275,27 @@ describe('review session permissions', () => { expect(result).toBeNull() }) + + it('treats canonical workspace owners as workflow admins without permission rows', async () => { + mockReadWorkflowAccessContext.mockResolvedValueOnce({ + workflow: { + id: 'workflow-1', + userId: 'member-1', + workspaceId: 'workspace-1', + } as NonNullable>>['workflow'], + workspaceOwnerId: 'owner-1', + workspacePermission: null, + isOwner: false, + isWorkspaceOwner: true, + }) + + const result = await verifyWorkflowAccess('owner-1', 'workflow-1', 'write') + + expect(result).toEqual({ + hasAccess: true, + userPermission: 'admin', + workspaceId: 'workspace-1', + isOwner: false, + }) + }) }) diff --git a/apps/tradinggoose/lib/copilot/review-sessions/permissions.ts b/apps/tradinggoose/lib/copilot/review-sessions/permissions.ts index a9d6351c0..5c47c2070 100644 --- a/apps/tradinggoose/lib/copilot/review-sessions/permissions.ts +++ b/apps/tradinggoose/lib/copilot/review-sessions/permissions.ts @@ -247,7 +247,9 @@ export async function verifyWorkflowAccess( return buildAccessResult({ isOwner: accessContext.isOwner, - userPermission: accessContext.workspacePermission ?? null, + userPermission: accessContext.isWorkspaceOwner + ? 'admin' + : (accessContext.workspacePermission ?? null), workspaceId: accessContext.workflow.workspaceId ?? null, accessMode, }) diff --git a/apps/tradinggoose/lib/copilot/runtime-tool-manifest-enrichment.ts b/apps/tradinggoose/lib/copilot/runtime-tool-manifest-enrichment.ts index 47b0522c6..dfe993cf2 100644 --- a/apps/tradinggoose/lib/copilot/runtime-tool-manifest-enrichment.ts +++ b/apps/tradinggoose/lib/copilot/runtime-tool-manifest-enrichment.ts @@ -12,7 +12,6 @@ import { MonitorDocumentSchema, } from '@/lib/copilot/monitor/monitor-documents' import type { RuntimeToolManifestSemanticValidator } from '@/lib/copilot/workflow-subblock-semantic-contracts' -import { TG_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/document-format' export type { RuntimeToolManifestSemanticValidator } from '@/lib/copilot/workflow-subblock-semantic-contracts' @@ -74,58 +73,6 @@ const JSON_DOCUMENT_SPECS: JsonDocumentSemanticSpec[] = [ }, ] -const TG_WORKFLOW_LINE_PREFIX = '%% TG_WORKFLOW ' -const TG_BLOCK_LINE_PREFIX = '%% TG_BLOCK ' -const TG_EDGE_LINE_PREFIX = '%% TG_EDGE ' - -const TG_WORKFLOW_METADATA_SCHEMA: Record = { - type: 'object', - required: ['version', 'direction'], - additionalProperties: true, - properties: { - version: { const: TG_MERMAID_DOCUMENT_FORMAT }, - direction: { enum: ['TD', 'LR'] }, - }, -} - -const TG_POSITION_SCHEMA: Record = { - type: 'object', - required: ['x', 'y'], - additionalProperties: true, - properties: { - x: { type: 'number' }, - y: { type: 'number' }, - }, -} - -const TG_BLOCK_SCHEMA: Record = { - type: 'object', - required: ['id', 'type', 'name', 'position', 'subBlocks', 'outputs', 'enabled'], - additionalProperties: true, - properties: { - id: { type: 'string' }, - type: { type: 'string' }, - name: { type: 'string' }, - position: TG_POSITION_SCHEMA, - subBlocks: { type: 'object' }, - outputs: { type: 'object' }, - enabled: { type: 'boolean' }, - }, -} - -const TG_EDGE_SCHEMA: Record = { - type: 'object', - required: ['source', 'target'], - additionalProperties: true, - properties: { - id: { type: 'string' }, - source: { type: 'string' }, - target: { type: 'string' }, - sourceHandle: { type: 'string' }, - targetHandle: { type: 'string' }, - }, -} - function getObjectPropertySchema( parameters: Record, propertyName: string @@ -151,87 +98,6 @@ function getConstStringValue(propertySchema: Record | null): st return null } -function buildWorkflowDocumentSemanticValidators( - documentField: string -): RuntimeToolManifestSemanticValidator[] { - return [ - { - path: documentField, - kind: 'string_requires_real_newlines', - description: - 'Use raw Mermaid text with real newlines; Studio validates workflow graph semantics.', - message: - 'Expected raw Mermaid text with real newline characters, not JSON-escaped `\\n` sequences.', - }, - { - path: documentField, - kind: 'string_starts_with', - args: { prefix: 'flowchart ' }, - description: - 'Start with a Mermaid `flowchart` declaration; Studio validates canonical workflow structure.', - message: 'Expected raw Mermaid text that starts with a `flowchart` declaration.', - }, - { - path: documentField, - kind: 'string_requires_line_prefix', - args: { prefix: TG_WORKFLOW_LINE_PREFIX, minMatches: 1 }, - description: 'Include a standalone canonical `%% TG_WORKFLOW {...}` metadata line.', - message: 'Workflow documents must include a standalone `%% TG_WORKFLOW {...}` metadata line.', - }, - { - path: documentField, - kind: 'string_requires_line_prefix', - args: { prefix: TG_BLOCK_LINE_PREFIX, minMatches: 1 }, - description: 'Include standalone canonical `%% TG_BLOCK {...}` metadata lines.', - message: 'Workflow documents must include standalone `%% TG_BLOCK {...}` metadata lines.', - }, - { - path: documentField, - kind: 'string_line_prefix_json_schema', - args: { prefix: TG_WORKFLOW_LINE_PREFIX, schema: TG_WORKFLOW_METADATA_SCHEMA }, - description: 'Validate each `TG_WORKFLOW` metadata JSON payload.', - message: - '`TG_WORKFLOW` metadata must be canonical JSON with `version: "tg-mermaid-v1"` and `direction` of `TD` or `LR`.', - }, - { - path: documentField, - kind: 'string_line_prefix_json_schema', - args: { prefix: TG_BLOCK_LINE_PREFIX, schema: TG_BLOCK_SCHEMA }, - description: 'Validate each `TG_BLOCK` metadata JSON payload.', - message: - '`TG_BLOCK` metadata must be canonical block state with `id`, `type`, `name`, `position`, `subBlocks`, `outputs`, and `enabled`.', - }, - { - path: documentField, - kind: 'string_line_prefix_json_schema', - args: { prefix: TG_EDGE_LINE_PREFIX, schema: TG_EDGE_SCHEMA }, - description: 'Validate each `TG_EDGE` metadata JSON payload when edge lines are present.', - message: '`TG_EDGE` metadata must be canonical edge state with string `source` and `target`.', - }, - { - path: documentField, - kind: 'string_forbids_substring', - args: { substring: '"blockType"' }, - description: 'Use canonical `TG_BLOCK.type`, not simplified block metadata aliases.', - message: 'Use `type` in `TG_BLOCK` metadata, not `blockType`.', - }, - { - path: documentField, - kind: 'string_forbids_substring', - args: { substring: '"blockName"' }, - description: 'Use canonical `TG_BLOCK.name`, not simplified block metadata aliases.', - message: 'Use `name` in `TG_BLOCK` metadata, not `blockName`.', - }, - { - path: documentField, - kind: 'string_forbids_substring', - args: { substring: '"blockDescription"' }, - description: 'Use canonical `TG_BLOCK` state, not simplified block metadata aliases.', - message: '`TG_BLOCK` metadata must not include `blockDescription`.', - }, - ] -} - function buildJsonDocumentSemanticValidators( documentField: string, spec: JsonDocumentSemanticSpec @@ -255,11 +121,6 @@ function buildJsonDocumentSemanticValidators( } const DOCUMENT_SEMANTIC_SPECS = [ - { - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, - preferredDocumentField: 'entityDocument', - buildSemanticValidators: buildWorkflowDocumentSemanticValidators, - }, ...JSON_DOCUMENT_SPECS.map((spec) => ({ documentFormat: spec.documentFormat, preferredDocumentField: 'entityDocument', diff --git a/apps/tradinggoose/lib/copilot/runtime-tool-manifest.test.ts b/apps/tradinggoose/lib/copilot/runtime-tool-manifest.test.ts index 62bfed4d9..2cb512612 100644 --- a/apps/tradinggoose/lib/copilot/runtime-tool-manifest.test.ts +++ b/apps/tradinggoose/lib/copilot/runtime-tool-manifest.test.ts @@ -142,30 +142,19 @@ describe('copilot runtime tool manifest', () => { }), expect.objectContaining({ name: 'edit_workflow', - description: expect.stringContaining( - 'Do not use this for a single existing block `name`, `enabled`, or `subBlocks` change' - ), + description: expect.stringContaining('minimal Mermaid `entityDocument`'), kind: 'edit', entityKind: 'workflow', - semanticValidators: expect.arrayContaining([ - expect.objectContaining({ - path: 'entityDocument', - kind: 'string_requires_real_newlines', - description: expect.stringContaining('Studio validates workflow graph semantics'), - }), - expect.objectContaining({ - path: 'entityDocument', - kind: 'string_starts_with', - args: { prefix: 'flowchart ' }, - }), - ]), parameters: expect.objectContaining({ type: 'object', required: expect.arrayContaining(['entityId', 'entityDocument']), properties: expect.objectContaining({ entityId: expect.any(Object), entityDocument: expect.objectContaining({ - description: expect.stringContaining('%% TG_WORKFLOW'), + description: expect.stringContaining('Minimal Mermaid flowchart'), + }), + removedBlockIds: expect.objectContaining({ + description: expect.stringContaining('intentionally removed'), }), }), }), @@ -283,39 +272,30 @@ describe('copilot runtime tool manifest', () => { ) const editWorkflowValidators = manifest.tools.find((tool) => tool.name === 'edit_workflow')?.semanticValidators ?? [] - const workflowValidatorKinds = editWorkflowValidators.map((validator) => validator.kind) - expect(workflowValidatorKinds).toEqual( - expect.arrayContaining([ - 'string_requires_real_newlines', - 'string_starts_with', - 'string_requires_line_prefix', - 'string_line_prefix_json_schema', - 'string_forbids_substring', - ]) - ) - expect(workflowValidatorKinds).not.toContain('string_document_contract') - expect(editWorkflowValidators).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - kind: 'string_requires_line_prefix', - args: { prefix: '%% TG_WORKFLOW ', minMatches: 1 }, - }), - expect.objectContaining({ - kind: 'string_requires_line_prefix', - args: { prefix: '%% TG_BLOCK ', minMatches: 1 }, - }), - expect.objectContaining({ - kind: 'string_line_prefix_json_schema', - args: expect.objectContaining({ prefix: '%% TG_EDGE ', schema: expect.any(Object) }), - }), - ]) - ) + expect(editWorkflowValidators.map((validator) => validator.kind)).toEqual([ + 'string_requires_real_newlines', + 'string_starts_with', + 'string_forbids_substring', + ]) const editWorkflowProperties = (manifest.tools.find((tool) => tool.name === 'edit_workflow')?.parameters?.properties as | Record | undefined) ?? {} + const createWorkflowProperties = + (manifest.tools.find((tool) => tool.name === 'create_workflow')?.parameters?.properties as + | Record + | undefined) ?? {} + const createIndicatorSchema = manifest.tools + .find((tool) => tool.name === 'create_indicator') + ?.semanticValidators?.find((validator) => validator.kind === 'string_json_schema')?.args + ?.schema as { properties?: Record; required?: string[] } | undefined + expect(createWorkflowProperties).not.toHaveProperty('color') + expect(createIndicatorSchema?.properties ?? {}).not.toHaveProperty('color') + expect(createIndicatorSchema?.required ?? []).not.toContain('color') expect(editWorkflowProperties).toHaveProperty('entityId') expect(editWorkflowProperties).toHaveProperty('entityDocument') + expect(editWorkflowProperties).toHaveProperty('removedBlockIds') + expect(editWorkflowProperties).not.toHaveProperty('documentFormat') expect(editWorkflowProperties).not.toHaveProperty('workflowId') expect(editWorkflowProperties).not.toHaveProperty('workflowDocument') expect( diff --git a/apps/tradinggoose/lib/copilot/runtime-tool-manifest.ts b/apps/tradinggoose/lib/copilot/runtime-tool-manifest.ts index 7aecaa1b2..e9d2428ec 100644 --- a/apps/tradinggoose/lib/copilot/runtime-tool-manifest.ts +++ b/apps/tradinggoose/lib/copilot/runtime-tool-manifest.ts @@ -42,11 +42,34 @@ const buildToolParameterSchema = (toolId: ToolId): Record => { } const TOOL_NAMES = ToolIds.options +const WORKFLOW_GRAPH_VALIDATORS: RuntimeToolManifestSemanticValidator[] = [ + { + path: 'entityDocument', + kind: 'string_requires_real_newlines', + message: 'Workflow graph Mermaid must be raw multi-line Mermaid text with real newlines.', + }, + { + path: 'entityDocument', + kind: 'string_starts_with', + args: { prefix: 'flowchart ' }, + message: 'Workflow graph Mermaid must start with `flowchart TD` or `flowchart LR`.', + }, + { + path: 'entityDocument', + kind: 'string_forbids_substring', + args: { substring: '%% TG_' }, + message: 'Workflow graph Mermaid must not include TG_* metadata comments.', + }, +] function getSemanticValidators( + toolName: ToolId, parameters: Record ): RuntimeToolManifestSemanticValidator[] | undefined { - const semanticValidators = buildAutomaticSemanticValidators(parameters) + const semanticValidators = + toolName === 'edit_workflow' + ? WORKFLOW_GRAPH_VALIDATORS + : buildAutomaticSemanticValidators(parameters) if (semanticValidators.length === 0) { return undefined @@ -60,7 +83,7 @@ export async function getCopilotRuntimeToolManifest(): Promise { const parameters = buildToolParameterSchema(toolName) - const semanticValidators = getSemanticValidators(parameters) + const semanticValidators = getSemanticValidators(toolName, parameters) return { name: toolName, diff --git a/apps/tradinggoose/lib/copilot/server-tool-errors.test.ts b/apps/tradinggoose/lib/copilot/server-tool-errors.test.ts index 7ef23c7ae..f0928a5b6 100644 --- a/apps/tradinggoose/lib/copilot/server-tool-errors.test.ts +++ b/apps/tradinggoose/lib/copilot/server-tool-errors.test.ts @@ -6,88 +6,92 @@ import { } from '@/lib/copilot/server-tool-errors' describe('copilot server tool errors', () => { - it('maps malformed workflow document errors to repairable 422 responses', () => { + it('returns container repair guidance for invalid canonical container edge handles', () => { const response = buildCopilotServerToolErrorResponse( 'edit_workflow', - new Error('Workflow document did not contain any TG_BLOCK entries') + new Error( + 'Invalid container edge: parallel1 container input requires targetHandle "target" for incoming outer edges.' + ) ) expect(response).toEqual({ status: 422, body: expect.objectContaining({ - code: 'invalid_workflow_document_missing_blocks', + code: 'invalid_workflow_document_container_edge', retryable: true, + issues: [ + { + path: 'entityDocument.edges', + message: + 'Invalid container edge: parallel1 container input requires targetHandle "target" for incoming outer edges.', + }, + ], }), }) - expect(response.body.error).toContain('standalone `%% TG_BLOCK') - expect(response.body.hint).toContain('Do not embed `TG_BLOCK` JSON inside node labels') + expect(response.body.hint).toContain('connect outer edges') }) - it('returns container and condition repair guidance for workflow edge mismatches', () => { + it('preserves embedded workflow sub-block paths in structured edit errors', () => { const response = buildCopilotServerToolErrorResponse( 'edit_workflow', new Error( - 'Workflow document edge metadata is inconsistent. Visible Mermaid connections and TG_EDGE payloads must resolve to the same logical workflow edges.' + 'Invalid edited workflow: Document contract is inconsistent: invalid block sub-block values for functionBlock.subBlocks.code.value (Expected valid raw TypeScript function-body code.).' ) ) expect(response).toEqual({ status: 422, body: expect.objectContaining({ - code: 'invalid_workflow_document_edge_mismatch', + code: 'invalid_workflow_state', retryable: true, + issues: [ + { + path: 'entityDocument.functionBlock.subBlocks.code.value', + message: 'Expected valid raw TypeScript function-body code.', + }, + ], }), }) - expect(response.body.hint).toContain('container subgraphs') - expect(response.body.hint).toContain('condition blocks') }) - it('returns container repair guidance for invalid canonical container edge handles', () => { + it('returns explicit removal guidance for omitted workflow blocks', () => { const response = buildCopilotServerToolErrorResponse( 'edit_workflow', new Error( - 'Invalid container edge: parallel1 container input requires targetHandle "target" for incoming outer edges.' + 'Invalid edited workflow: Existing block ids omitted from edit_workflow entityDocument without removedBlockIds: fn1.' ) ) expect(response).toEqual({ status: 422, body: expect.objectContaining({ - code: 'invalid_workflow_document_container_edge', + code: 'invalid_workflow_state', retryable: true, - issues: [ - { - path: 'entityDocument.edges', - message: - 'Invalid container edge: parallel1 container input requires targetHandle "target" for incoming outer edges.', - }, - ], }), }) - expect(response.body.hint).toContain('targetHandle "target"') + expect(response.body.hint).toContain('removedBlockIds') }) - it('preserves embedded workflow sub-block paths in structured edit errors', () => { + it('returns retryable graph-document guidance for malformed edit workflow Mermaid', () => { const response = buildCopilotServerToolErrorResponse( 'edit_workflow', - new Error( - 'Invalid edited workflow: Document contract is inconsistent: invalid block sub-block values for functionBlock.subBlocks.code.value (Expected valid raw TypeScript function-body code.).' - ) + new Error('Workflow graph Mermaid must start with `flowchart TD` or `flowchart LR`.') ) expect(response).toEqual({ status: 422, body: expect.objectContaining({ - code: 'invalid_workflow_state', + code: 'invalid_workflow_graph_document', retryable: true, issues: [ { - path: 'entityDocument.functionBlock.subBlocks.code.value', - message: 'Expected valid raw TypeScript function-body code.', + path: 'entityDocument', + message: 'Workflow graph Mermaid must start with `flowchart TD` or `flowchart LR`.', }, ], }), }) + expect(response.body.hint).toContain('minimal Mermaid graph') }) it('falls back to a generic 500 payload for unknown tool failures', () => { diff --git a/apps/tradinggoose/lib/copilot/server-tool-errors.ts b/apps/tradinggoose/lib/copilot/server-tool-errors.ts index 155a8d7c8..2fdf1ee90 100644 --- a/apps/tradinggoose/lib/copilot/server-tool-errors.ts +++ b/apps/tradinggoose/lib/copilot/server-tool-errors.ts @@ -76,78 +76,21 @@ function buildInvalidToolPayloadError( } function buildEditWorkflowError(message: string): CopilotServerToolErrorResponse | null { - if (message === 'Missing TG_WORKFLOW metadata') { - return { - status: 422, - body: { - code: 'invalid_workflow_document_missing_metadata', - error: 'Workflow document is missing a standalone `%% TG_WORKFLOW {...}` metadata line.', - hint: 'Send raw `tg-mermaid-v1` Mermaid text with real newlines, and keep `%% TG_WORKFLOW {...}` on its own line near the top of the document.', - retryable: true, - }, - } - } - - if (message === 'Workflow document did not contain any TG_BLOCK entries') { - return { - status: 422, - body: { - code: 'invalid_workflow_document_missing_blocks', - error: - 'Workflow document did not contain any standalone `%% TG_BLOCK {...}` block entries.', - hint: 'Emit canonical `%% TG_BLOCK {...}` comment lines for each block. Do not embed `TG_BLOCK` JSON inside node labels or send simplified block metadata.', - retryable: true, - }, - } - } - - if (message.startsWith('Invalid TG_BLOCK payload:')) { - return { - status: 422, - body: { - code: 'invalid_workflow_document_block_payload', - error: message, - hint: 'Each `TG_BLOCK` payload must be canonical workflow state with `id`, `type`, `name`, `position`, `subBlocks`, `outputs`, and `enabled`.', - retryable: true, - }, - } - } - - if (message.startsWith('Invalid TG_EDGE payload')) { - return { - status: 422, - body: { - code: 'invalid_workflow_document_edge_payload', - error: message, - hint: 'Each `TG_EDGE` payload must be a standalone JSON object with string `source` and `target` fields that matches the visible Mermaid connection.', - retryable: true, - }, - } - } - - if ( - message === - 'Workflow document contains Mermaid connection lines but no TG_EDGE entries. Every visible workflow connection must have a matching TG_EDGE payload.' - ) { - return { - status: 422, - body: { - code: 'invalid_workflow_document_missing_edge_metadata', - error: message, - hint: 'When the diagram shows visible Mermaid connections, include matching standalone `%% TG_EDGE {...}` lines for each connection.', - retryable: true, - }, - } - } + const isGraphDocumentError = + message.startsWith('Workflow graph Mermaid ') || + /^New workflow block ".+" is missing a type label\.$/.test(message) || + /^Unknown workflow block type ".+" for new block ".+"\.$/.test(message) || + message === 'entityDocument is required' - if (message.startsWith('Workflow document edge metadata is inconsistent.')) { + if (isGraphDocumentError) { return { status: 422, body: { - code: 'invalid_workflow_document_edge_mismatch', + code: 'invalid_workflow_graph_document', error: message, - hint: 'Keep the visible Mermaid connection lines and the canonical `%% TG_EDGE {...}` payloads in logical sync. Loop and parallel child blocks must stay inside their container subgraphs and cross container boundaries through the container handles, while condition blocks keep their diamond-and-branch structure.', + hint: 'Send a complete minimal Mermaid graph starting with `flowchart TD` or `flowchart LR`. Do not include TG_* metadata or block internals. Every new block needs `id:` and canonical `type:` labels from `get_available_blocks` or `get_blocks_metadata`.', retryable: true, + issues: [{ path: 'entityDocument', message }], }, } } @@ -158,7 +101,7 @@ function buildEditWorkflowError(message: string): CopilotServerToolErrorResponse body: { code: 'invalid_workflow_document_container_edge', error: message, - hint: 'For loop and parallel containers, incoming outer workflow edges must target the container block alias itself with targetHandle "target". Use Start nodes only as sources to child blocks, and End nodes only for child-to-container completion before leaving the container.', + hint: 'For loop and parallel containers, connect outer edges to the container node and internal edges to the generated start/end nodes.', retryable: true, issues: [{ path: 'entityDocument.edges', message }], }, @@ -184,11 +127,15 @@ function buildEditWorkflowError(message: string): CopilotServerToolErrorResponse const hint = details.includes('non-canonical sub-block') ? 'Use only the canonical sub-block ids from `get_blocks_metadata` for that block type. Keep the existing canonical ids and remove invented keys.' + : details.includes('removedBlockIds') + ? 'Keep every existing block id in the Mermaid graph unless the user explicitly asked to remove it; list intentional removals in `removedBlockIds`.' + : details.includes('immutable identities') + ? 'Keep the existing block id/type pair unchanged. `edit_workflow` rewrites topology only; it cannot replace an existing block or change its type.' : details.includes('unknown block type') - ? 'Use block types exactly as returned by `get_available_blocks` or `get_blocks_metadata`. Keep `TG_BLOCK.type` unchanged unless you are intentionally replacing the block with another valid type.' + ? 'Use block types exactly as returned by `get_available_blocks` or `get_blocks_metadata`.' : details.includes('Edge references non-existent') - ? 'Every `TG_EDGE` source and target must match an existing `TG_BLOCK`, `TG_LOOP`, or `TG_PARALLEL` id in the same document.' - : 'Return a complete canonical workflow document that validates as workflow state. Preserve required block fields, canonical ids, and valid edge references.' + ? 'Every edge source and target must match a block id in the same document.' + : 'Return a complete workflow graph that validates as workflow state. Preserve block ids and valid edge references.' return { status: 422, diff --git a/apps/tradinggoose/lib/copilot/tool-prompt-metadata.ts b/apps/tradinggoose/lib/copilot/tool-prompt-metadata.ts index 9435e9f68..c121c6f64 100644 --- a/apps/tradinggoose/lib/copilot/tool-prompt-metadata.ts +++ b/apps/tradinggoose/lib/copilot/tool-prompt-metadata.ts @@ -28,7 +28,7 @@ export const TOOL_PROMPT_METADATA: Record = { }, [CopilotTool.read_workflow]: { description: - 'Read a workflow by exact `entityId` and return Mermaid in `entityDocument`, plus `workflowSummary.blocks[].connections` counts and exact raw `workflowSummary.edges` with external/internal scope. For topology, use only these edges/counts; do not infer graph connections from subBlock text references like `<...>`. `connectionIssues` only reports malformed existing edges.', + 'Read a workflow by exact `entityId` and return full `tg-mermaid-v1` inspection Mermaid in `entityDocument`, plus `workflowSummary.blocks[].connections` counts and exact raw `workflowSummary.edges` with external/internal scope. Do not submit this full document to `edit_workflow`; that tool accepts minimal graph-only Mermaid. For topology, use only these edges/counts; do not infer graph connections from subBlock text references like `<...>`. `connectionIssues` only reports malformed existing edges.', kind: 'read', entityKind: 'workflow', }, @@ -40,7 +40,7 @@ export const TOOL_PROMPT_METADATA: Record = { }, edit_workflow: { description: - 'Replace the full workflow document using exact argument keys `entityId`, full `entityDocument`, and `documentFormat: tg-mermaid-v1`, then return the resulting workflow state. Use this only for graph or topology edits such as adding, removing, reconnecting, or replacing blocks, or changing loop, parallel, or condition structure. Do not use this for a single existing block `name`, `enabled`, or `subBlocks` change; use `edit_workflow_block` instead. If a full-document edit fails and the request only changes one existing block config, stop retrying `edit_workflow` and switch tools.', + 'Rewrite the full workflow graph topology using exact argument keys `entityId` and minimal Mermaid `entityDocument`, then return the resulting workflow state and graph-only Mermaid document. Use this only for graph or topology edits such as adding, deleting, reconnecting blocks, or changing loop/parallel nesting. Do not send `documentFormat`, `TG_BLOCK`, `TG_EDGE`, `subBlocks`, condition branch labels, `outputs`, `enabled`, positions, or full block metadata. Existing block ids are stable identities used directly as node/subgraph ids: their type and details are preserved by exact id, and supplied labels must match current block names. This tool cannot replace an existing block or change its type; new ids create new blocks with generated positions. New blocks need `id:` and canonical `type:` labels. Existing condition edges must use exact `condition--` source handles; use `edit_workflow_block` to define branches. If an existing block subtree is intentionally deleted, include the removed root id in `removedBlockIds`; otherwise every existing block id must remain in the Mermaid graph. Use `edit_workflow_block` for one existing block `name`, `enabled`, `subBlocks`, or condition branch definition change.', kind: 'edit', entityKind: 'workflow', }, @@ -57,7 +57,8 @@ export const TOOL_PROMPT_METADATA: Record = { entityKind: 'workflow', }, run_workflow: { - description: 'Run the target workflow with optional input.', + description: + 'Run the target workflow with optional input and an exact `triggerBlockId` from `read_workflow.workflowSummary.blocks`.', kind: 'run', entityKind: 'workflow', }, diff --git a/apps/tradinggoose/lib/copilot/tools/client/entities/entity-document-tool-utils.ts b/apps/tradinggoose/lib/copilot/tools/client/entities/entity-document-tool-utils.ts index 44126a3da..e3f20b34c 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/entities/entity-document-tool-utils.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/entities/entity-document-tool-utils.ts @@ -18,7 +18,6 @@ type EntityListEntry = { entityDescription?: string entityTitle?: string entityFunctionName?: string - entityColor?: string entityTransport?: string entityUrl?: string entityEnabled?: boolean @@ -28,7 +27,6 @@ type EntityListEntry = { export type CopilotIndicatorListEntry = { name: string source: 'default' | 'custom' - color?: string editable: boolean callableInFunctionBlock: boolean inputTitles?: string[] @@ -110,7 +108,6 @@ const ENTITY_API_CONFIG: Record = { extractList: (data) => (Array.isArray(data?.data) ? data.data : []), toFields: (item) => ({ name: item?.name ?? '', - color: item?.color ?? '', pineCode: item?.pineCode ?? '', inputMeta: item?.inputMeta && typeof item.inputMeta === 'object' && !Array.isArray(item.inputMeta) @@ -120,7 +117,6 @@ const ENTITY_API_CONFIG: Record = { toListEntry: (item) => ({ entityId: String(item?.id ?? ''), entityName: String(item?.name ?? ''), - entityColor: typeof item?.color === 'string' ? item.color : '', }), }, mcp_server: { @@ -196,9 +192,6 @@ function buildEntityCreateRequest( indicators: [ { name: fields.name, - ...(typeof fields.color === 'string' && fields.color.trim() - ? { color: fields.color.trim() } - : {}), pineCode: fields.pineCode, inputMeta: fields.inputMeta ?? undefined, }, @@ -385,7 +378,6 @@ export async function listCopilotIndicators( source, editable: item?.editable === true, callableInFunctionBlock: item?.callableInFunctionBlock === true, - ...(typeof item?.color === 'string' && item.color ? { color: item.color } : {}), ...(Array.isArray(item?.inputTitles) ? { inputTitles: item.inputTitles.filter( @@ -430,7 +422,6 @@ export async function readEntityFieldsFromContext( entityName: indicator.name, fields: { name: indicator.name, - color: '#3972F6', pineCode: indicator.pineCode, inputMeta: indicator.inputMeta ?? null, }, @@ -473,7 +464,6 @@ export function applyEntityFieldsToSession( break case 'indicator': setEntityField(session.doc, 'name', fields.name ?? '') - setEntityField(session.doc, 'color', fields.color ?? '') replaceEntityTextField(session.doc, 'pineCode', String(fields.pineCode ?? '')) setEntityField(session.doc, 'inputMeta', fields.inputMeta ?? null) break diff --git a/apps/tradinggoose/lib/copilot/tools/client/entities/entity-tools.test.ts b/apps/tradinggoose/lib/copilot/tools/client/entities/entity-tools.test.ts index 9044dcdd7..a17bf5217 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/entities/entity-tools.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/entities/entity-tools.test.ts @@ -334,7 +334,6 @@ describe('entity document tools', () => { id: 'RSI', name: 'Relative Strength Index', source: 'default', - color: '#3972F6', editable: false, callableInFunctionBlock: true, inputTitles: ['Length'], @@ -344,7 +343,6 @@ describe('entity document tools', () => { id: 'indicator-1', name: 'My Custom Indicator', source: 'custom', - color: '#ff0000', editable: true, callableInFunctionBlock: false, inputTitles: ['Fast Length'], @@ -394,7 +392,6 @@ describe('entity document tools', () => { { name: 'Relative Strength Index', source: 'default', - color: '#3972F6', editable: false, callableInFunctionBlock: true, inputTitles: ['Length'], @@ -403,7 +400,6 @@ describe('entity document tools', () => { { name: 'My Custom Indicator', source: 'custom', - color: '#ff0000', editable: true, callableInFunctionBlock: false, inputTitles: ['Fast Length'], diff --git a/apps/tradinggoose/lib/copilot/tools/client/other/oauth-request-access.ts b/apps/tradinggoose/lib/copilot/tools/client/other/oauth-request-access.ts index bbd352e5c..1686c8510 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/other/oauth-request-access.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/other/oauth-request-access.ts @@ -126,16 +126,17 @@ export class OAuthRequestAccessClientTool extends BaseClientTool { if (typeof window !== 'undefined') { const pathMatch = window.location.pathname.match(/\/workspace\/([^/]+)/) const workspaceId = pathMatch?.[1] - const callbackURL = workspaceId - ? `${window.location.origin}/workspace/${workspaceId}/integrations` - : window.location.href + if (!workspaceId) { + throw new Error('Missing workspace context for OAuth callback') + } + const callbackURL = `/workspace/${workspaceId}/integrations` try { localStorage.setItem( 'pending_oauth_state', JSON.stringify({ serviceId, scopes: service.scopes }) ) - } catch { } + } catch {} this.setState(ClientToolCallState.success) await this.markToolComplete(200, `Opened ${this.providerName} connection dialog`) diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/create-workflow.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/create-workflow.ts index b876fa2eb..c6a4e1eab 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/create-workflow.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/create-workflow.ts @@ -5,13 +5,12 @@ import { ClientToolCallState, } from '@/lib/copilot/tools/client/base-tool' import { createLogger } from '@/lib/logs/console/logger' -import { useWorkflowRegistry } from '@/stores/workflows/registry/store' import { getCopilotStoreForToolCall } from '@/stores/copilot/store-access' +import { useWorkflowRegistry } from '@/stores/workflows/registry/store' type CreateWorkflowArgs = { name?: string description?: string - color?: string folderId?: string | null workspaceId?: string } @@ -75,7 +74,6 @@ export class CreateWorkflowClientTool extends BaseClientTool { ...(typeof resolvedArgs?.description === 'string' ? { description: resolvedArgs.description } : {}), - ...(typeof resolvedArgs?.color === 'string' ? { color: resolvedArgs.color } : {}), ...(resolvedArgs?.folderId !== undefined ? { folderId: resolvedArgs.folderId } : {}), }) diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.test.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.test.ts index 545f65333..246555447 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.test.ts @@ -16,6 +16,11 @@ const workflowDocument = [ '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', '%% TG_BLOCK {"id":"block-1","type":"trigger","name":"Trigger","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', ].join('\n') +const editWorkflowDocument = [ + 'flowchart TD', + ' n1["Trigger
id: block-1
type: trigger"]', +].join('\n') +const workflowGraphDocumentFormat = 'tg-workflow-graph-mermaid-v1' let persistedToolCalls: Record = {} @@ -26,16 +31,18 @@ vi.mock('@/lib/copilot/tools/client/workflow/workflow-review-tool-utils', () => workflowId, entityName, entityDocument, + documentFormat, }: { workflowId: string entityName?: string entityDocument: string + documentFormat?: string }) => ({ entityKind: 'workflow', entityId: workflowId, ...(entityName ? { entityName } : {}), entityDocument, - documentFormat: 'tg-mermaid-v1', + documentFormat: documentFormat ?? 'tg-mermaid-v1', }), })) @@ -104,10 +111,17 @@ describe('EditWorkflowClientTool approval gating', () => { }) it('stages workflow edits for review through the unified user-action handler', async () => { - const fetchMock = vi.fn(async (input: RequestInfo | URL, _init?: RequestInit) => { + const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => { const url = typeof input === 'string' ? input : input.toString() if (url === '/api/copilot/execute-copilot-server-tool') { + const body = JSON.parse(String(init?.body)) + expect(body.payload).toMatchObject({ + entityId: 'wf-1', + entityDocument: editWorkflowDocument, + removedBlockIds: ['removed-1'], + }) + expect(body.payload).not.toHaveProperty('documentFormat') return { ok: true, status: 200, @@ -130,7 +144,8 @@ describe('EditWorkflowClientTool approval gating', () => { loops: {}, parallels: {}, }, - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, + documentFormat: workflowGraphDocumentFormat, }, }), } @@ -160,7 +175,8 @@ describe('EditWorkflowClientTool approval gating', () => { await tool.handleUserAction({ entityId: 'wf-1', - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, + removedBlockIds: ['removed-1'], }) expect(tool.getState()).toBe(ClientToolCallState.review) @@ -230,7 +246,8 @@ describe('EditWorkflowClientTool approval gating', () => { loops: {}, parallels: {}, }, - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, + documentFormat: workflowGraphDocumentFormat, }, }), } @@ -252,7 +269,7 @@ describe('EditWorkflowClientTool approval gating', () => { await tool.handleUserAction({ entityId: 'wf-1', - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, }) expect(tool.getState()).toBe(ClientToolCallState.review) @@ -289,7 +306,8 @@ describe('EditWorkflowClientTool approval gating', () => { success: true, result: { workflowState: nextWorkflowState, - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, + documentFormat: workflowGraphDocumentFormat, }, }), } @@ -319,7 +337,7 @@ describe('EditWorkflowClientTool approval gating', () => { await tool.execute({ entityId: 'wf-1', - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, }) await tool.handleAccept() @@ -342,8 +360,8 @@ describe('EditWorkflowClientTool approval gating', () => { expect(markCompleteBody.data).toMatchObject({ entityKind: 'workflow', entityId: 'wf-1', - entityDocument: workflowDocument, - documentFormat: 'tg-mermaid-v1', + entityDocument: editWorkflowDocument, + documentFormat: workflowGraphDocumentFormat, }) }) @@ -374,7 +392,7 @@ describe('EditWorkflowClientTool approval gating', () => { }) await tool.execute({ - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, }) expect(tool.getState()).toBe(ClientToolCallState.error) @@ -415,7 +433,7 @@ describe('EditWorkflowClientTool approval gating', () => { state: ClientToolCallState.review, params: { entityId: 'wf-target', - entityDocument: workflowDocument, + entityDocument: editWorkflowDocument, }, result: { entityId: 'wf-target', diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.ts index 10b12ea2e..de6b9c0b3 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/edit-workflow.ts @@ -22,7 +22,7 @@ import { getCopilotStoreForToolCall } from '@/stores/copilot/store-access' interface EditWorkflowArgs { entityDocument: string - documentFormat?: string + removedBlockIds?: string[] entityId?: string } @@ -147,7 +147,7 @@ export class EditWorkflowClientTool extends StagedReviewClientTool { await tool.handleAccept({ entityId: 'wf-explicit-target', + triggerBlockId: 'schedule-trigger', workflow_input: { symbol: 'AAPL' }, }) @@ -117,6 +118,7 @@ describe('RunWorkflowClientTool channel-safe workflow scoping', () => { workflowInput: { symbol: 'AAPL' }, executionId: toolCallId, workflowId: 'wf-explicit-target', + triggerBlockId: 'schedule-trigger', }) expect(tool.getState()).toBe(ClientToolCallState.success) }) diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/run-workflow.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/run-workflow.ts index 81666e9cc..6df23f2ed 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/run-workflow.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/run-workflow.ts @@ -13,6 +13,7 @@ import { useExecutionStore } from '@/stores/execution/store' interface RunWorkflowArgs { entityId: string description?: string + triggerBlockId: string workflow_input?: Record | string } @@ -79,6 +80,13 @@ export class RunWorkflowClientTool extends BaseClientTool { } logger.debug('Using target workflow', { workflowId: activeWorkflowId }) + if (typeof params.triggerBlockId !== 'string' || params.triggerBlockId.length === 0) { + logger.debug('Execution prevented: no trigger block selected') + this.setState(ClientToolCallState.error) + await this.markToolComplete(400, 'triggerBlockId is required') + return + } + let workflowInput: Record | undefined if (params.workflow_input !== undefined) { if (typeof params.workflow_input === 'string') { @@ -116,6 +124,7 @@ export class RunWorkflowClientTool extends BaseClientTool { workflowInput, executionId: this.toolCallId, workflowId: activeWorkflowId, + triggerBlockId: params.triggerBlockId, }) // Determine success for both non-streaming and streaming executions diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-execution-utils.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-execution-utils.ts index 2d46eefdb..04cf1dd67 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-execution-utils.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-execution-utils.ts @@ -1,8 +1,9 @@ import { getReadableWorkflowState } from '@/lib/copilot/tools/client/workflow/workflow-review-tool-utils' import { createLogger } from '@/lib/logs/console/logger' import { runQueuedWorkflowExecution } from '@/lib/workflows/queued-execution-client' -import { TriggerUtils } from '@/lib/workflows/triggers' +import { resolveWorkflowRunTrigger } from '@/lib/workflows/triggers' import type { ExecutionResult } from '@/executor/types' +import { buildExecutableWorkflowData } from '@/stores/workflows/workflow/utils' const logger = createLogger('WorkflowExecutionUtils') @@ -10,21 +11,13 @@ type WorkflowExecutionOptions = { workflowInput?: any executionId?: string workflowId: string + triggerBlockId: string } function createExecutionId() { return globalThis.crypto.randomUUID() } -function resolveWorkflowStart(blocks: Record) { - for (const triggerType of ['chat', 'manual', 'api'] as const) { - const start = TriggerUtils.findStartBlock(blocks, triggerType) - if (start) return { triggerType, startBlockId: start.blockId } - } - - return null -} - export async function executeWorkflowWithFullLogging( options: WorkflowExecutionOptions ): Promise { @@ -44,40 +37,30 @@ export async function executeWorkflowWithFullLogging( throw new Error('Workflow execution context requires workspaceId') } - const blocks = Object.entries(workflowState.blocks).reduce( - (acc, [blockId, block]) => { - if (block?.type && block.enabled !== false) { - acc[blockId] = block - } - return acc - }, - {} as typeof workflowState.blocks - ) - const start = resolveWorkflowStart(blocks) - if (!start) { - throw new Error('Workflow requires a chat, API, or manual trigger block to execute') - } + const workflowData = buildExecutableWorkflowData(workflowState.blocks, workflowState.edges) + const start = resolveWorkflowRunTrigger(workflowData.blocks, workflowData.edges, { + surface: 'copilot', + workflowInput: options.workflowInput, + triggerBlockId: options.triggerBlockId, + }) + workflowData.blocks = start.blocks logger.info('Executing workflow through server route', { workflowId: options.workflowId, triggerType: start.triggerType, - blockCount: Object.keys(blocks).length, - edgeCount: workflowState.edges.length, + triggerBlockId: start.blockId, + blockCount: Object.keys(workflowData.blocks).length, + edgeCount: workflowData.edges.length, }) return runQueuedWorkflowExecution({ workflowId: options.workflowId, executionId: options.executionId ?? createExecutionId(), - input: options.workflowInput, + input: start.input, triggerType: start.triggerType, executionTarget: 'live', - workflowData: { - blocks, - edges: workflowState.edges, - loops: workflowState.loops, - parallels: workflowState.parallels, - }, + workflowData, workflowVariables, - startBlockId: start.startBlockId, + triggerBlockId: start.blockId, }) } diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.test.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.test.ts index 48d323891..6156c5541 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.test.ts @@ -1,7 +1,12 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' +import { + buildWorkflowDocumentToolResult, + buildWorkflowSummary, + getReadableWorkflowState, +} from './workflow-review-tool-utils' -const mockGetRegisteredWorkflowSession = vi.fn() -const mockAcquireWritableWorkflowSessionLease = vi.fn() +const mockGetRegisteredWorkflowSession = vi.hoisted(() => vi.fn()) +const mockAcquireWritableWorkflowSessionLease = vi.hoisted(() => vi.fn()) vi.mock('@/lib/yjs/workflow-session-registry', () => ({ getRegisteredWorkflowSession: (...args: unknown[]) => mockGetRegisteredWorkflowSession(...args), @@ -55,7 +60,6 @@ describe('workflow-review-tool-utils', () => { doc, }) - const { getReadableWorkflowState } = await import('./workflow-review-tool-utils') const result = await getReadableWorkflowState( { toolCallId: 'tool-1', @@ -95,7 +99,6 @@ describe('workflow-review-tool-utils', () => { release, }) - const { getReadableWorkflowState } = await import('./workflow-review-tool-utils') const result = await getReadableWorkflowState( { toolCallId: 'tool-1', @@ -130,7 +133,6 @@ describe('workflow-review-tool-utils', () => { }) it('fails fast when workflow execution context is missing a workflow target', async () => { - const { getReadableWorkflowState } = await import('./workflow-review-tool-utils') await expect( getReadableWorkflowState({ toolCallId: 'tool-1', @@ -142,8 +144,6 @@ describe('workflow-review-tool-utils', () => { }) it('builds workflow document payloads with canonical workflow identity', async () => { - const { buildWorkflowDocumentToolResult } = await import('./workflow-review-tool-utils') - expect( buildWorkflowDocumentToolResult({ workflowId: 'workflow-entity', @@ -160,8 +160,6 @@ describe('workflow-review-tool-utils', () => { }) it('surfaces invalid external edges into container end handles', async () => { - const { buildWorkflowSummary } = await import('./workflow-review-tool-utils') - expect( buildWorkflowSummary({ blocks: { @@ -192,8 +190,6 @@ describe('workflow-review-tool-utils', () => { }) it('surfaces missing outer input handles on incoming container edges', async () => { - const { buildWorkflowSummary } = await import('./workflow-review-tool-utils') - const summary = buildWorkflowSummary({ blocks: { input: block('input', 'input_trigger', 'Input'), @@ -230,8 +226,6 @@ describe('workflow-review-tool-utils', () => { }) it('marks container branch edges as internal so missing outer edges stay visible', async () => { - const { buildWorkflowSummary } = await import('./workflow-review-tool-utils') - const child = { ...block('child', 'function', 'Child'), data: { parentId: 'parallel', extent: 'parent' as const }, diff --git a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.ts b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.ts index 4f09b69c2..87df58714 100644 --- a/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.ts +++ b/apps/tradinggoose/lib/copilot/tools/client/workflow/workflow-review-tool-utils.ts @@ -63,6 +63,7 @@ export function buildWorkflowDocumentToolResult(options: { entityName?: string workspaceId?: string | null entityDocument: string + documentFormat?: string }) { const entityName = normalizeWorkflowTargetValue(options.entityName) @@ -72,7 +73,7 @@ export function buildWorkflowDocumentToolResult(options: { ...(entityName ? { entityName } : {}), ...(options.workspaceId ? { workspaceId: options.workspaceId } : {}), entityDocument: options.entityDocument, - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, + documentFormat: options.documentFormat ?? TG_MERMAID_DOCUMENT_FORMAT, } } diff --git a/apps/tradinggoose/lib/copilot/tools/server/blocks/get-blocks-metadata.test.ts b/apps/tradinggoose/lib/copilot/tools/server/blocks/get-blocks-metadata.test.ts index 4969da244..4e52b9a47 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/blocks/get-blocks-metadata.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/blocks/get-blocks-metadata.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' +import { getBlocksMetadataServerTool } from '@/lib/copilot/tools/server/blocks/get-blocks-metadata' const mockGetOAuthProviderAvailability = vi.hoisted(() => vi.fn()) @@ -131,17 +132,13 @@ vi.mock('@/tools/registry', () => ({ describe('getBlocksMetadataServerTool', () => { beforeEach(() => { - vi.resetModules() + mockGetOAuthProviderAvailability.mockReset() mockGetOAuthProviderAvailability.mockImplementation(async (providerIds: string[]) => Object.fromEntries(providerIds.map((providerId) => [providerId, false])) ) }) it('returns Mermaid profiles and operation variants instead of schema-shaped metadata', async () => { - const { getBlocksMetadataServerTool } = await import( - '@/lib/copilot/tools/server/blocks/get-blocks-metadata' - ) - const result = await getBlocksMetadataServerTool.execute({ blockTypes: [ 'github', diff --git a/apps/tradinggoose/lib/copilot/tools/server/router.test.ts b/apps/tradinggoose/lib/copilot/tools/server/router.test.ts index cf62f7653..43ea96079 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/router.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/router.test.ts @@ -1,11 +1,14 @@ import { beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' -import { TG_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/studio-workflow-mermaid' +import { + TG_MERMAID_DOCUMENT_FORMAT, + WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT, +} from '@/lib/workflows/document-format' const editWorkflowExecute = vi.fn(async () => ({ entityKind: 'workflow', entityId: 'workflow-123', - entityDocument: 'flowchart TD\n%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, + entityDocument: 'flowchart TD\n n1["Input
id: input1
type: input_trigger"]', + documentFormat: WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT, workflowState: { blocks: {} }, })) const readWorkflowLogsExecute = vi.fn(async () => ({ entries: [] })) @@ -237,8 +240,17 @@ describe('copilot contract registry', () => { it('accepts explicit entity ids on workflow execution tools', () => { expect(() => getToolContract('run_workflow')?.args.parse({})).toThrow() expect(() => getToolContract('read_workflow')?.args.parse({})).toThrow() - expect(getToolContract('run_workflow')?.args.parse({ entityId: 'workflow-123' })).toEqual({ + expect(() => + getToolContract('run_workflow')?.args.parse({ entityId: 'workflow-123' }) + ).toThrow() + expect( + getToolContract('run_workflow')?.args.parse({ + entityId: 'workflow-123', + triggerBlockId: 'trigger-1', + }) + ).toEqual({ entityId: 'workflow-123', + triggerBlockId: 'trigger-1', }) expect( getToolContract('set_workflow_variables')?.args.parse({ @@ -329,8 +341,7 @@ describe('routeExecution', () => { it('preserves workflow edit entity fields when routing workflow tools', async () => { const payload = { - entityDocument: 'flowchart TD\n%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, + entityDocument: 'flowchart TD\n n1["Input
id: input1
type: input_trigger"]', entityId: 'workflow-123', currentWorkflowState: '{"blocks":{}}', } @@ -339,7 +350,7 @@ describe('routeExecution', () => { entityKind: 'workflow', entityId: 'workflow-123', entityDocument: expect.any(String), - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, + documentFormat: WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT, }) expect(editWorkflowExecute).toHaveBeenCalledWith(payload, undefined) diff --git a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow-block.test.ts b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow-block.test.ts index 3131a57a4..5bdd6fb8f 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow-block.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow-block.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it, vi } from 'vitest' +import { editWorkflowBlockServerTool } from '@/lib/copilot/tools/server/workflow/edit-workflow-block' vi.mock('@/lib/workflows/validation', () => ({ validateWorkflowState: (state: any) => ({ @@ -9,6 +10,17 @@ vi.mock('@/lib/workflows/validation', () => ({ }), })) +vi.mock('@/blocks', () => ({ + getBlock: (blockType: string) => + blockType === 'function' + ? { + type: 'function', + name: 'Function', + subBlocks: [{ id: 'code', type: 'code' }], + } + : undefined, +})) + const CURRENT_WORKFLOW_STATE = JSON.stringify({ direction: 'TD', blocks: { @@ -35,10 +47,6 @@ const CURRENT_WORKFLOW_STATE = JSON.stringify({ describe('editWorkflowBlockServerTool', () => { it('patches only the selected block config and preserves the workflow document envelope', async () => { - const { editWorkflowBlockServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow-block' - ) - const result = await editWorkflowBlockServerTool.execute( { entityId: 'wf-1', @@ -60,10 +68,6 @@ describe('editWorkflowBlockServerTool', () => { }) it('rejects non-canonical sub-block ids with structured issues', async () => { - const { editWorkflowBlockServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow-block' - ) - await expect( editWorkflowBlockServerTool.execute( { diff --git a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.test.ts b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.test.ts index b2c25b049..4b695297d 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.test.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.test.ts @@ -1,4 +1,6 @@ import { describe, expect, it, vi } from 'vitest' +import { editWorkflowServerTool } from '@/lib/copilot/tools/server/workflow/edit-workflow' +import { WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/document-format' vi.mock('@/lib/workflows/validation', () => ({ validateWorkflowState: (state: any) => ({ @@ -9,7 +11,8 @@ vi.mock('@/lib/workflows/validation', () => ({ }), })) -const INPUT_TRIGGER_CURRENT_WORKFLOW_STATE = JSON.stringify({ +const BASE_WORKFLOW_STATE = { + direction: 'TD', blocks: { input1: { id: 'input1', @@ -26,305 +29,310 @@ const INPUT_TRIGGER_CURRENT_WORKFLOW_STATE = JSON.stringify({ }, outputs: {}, }, + fn1: { + id: 'fn1', + type: 'function', + name: 'Compute Indicators', + position: { x: 0, y: 240 }, + enabled: true, + subBlocks: { + code: { + id: 'code', + type: 'code', + value: 'return { ok: true }', + }, + }, + outputs: {}, + }, }, edges: [], loops: {}, parallels: {}, -}) +} -function buildInputTriggerWorkflowDocument(subBlocks: Record): string { - return [ - 'flowchart TD', - '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - [ - '%% TG_BLOCK ', - JSON.stringify({ - id: 'input1', - type: 'input_trigger', - name: 'Input Form', - position: { x: 0, y: 0 }, - enabled: true, - subBlocks, - outputs: {}, - }), - ].join(''), - ].join('\n') +function graph(lines: string[]): string { + return lines.join('\n') } describe('editWorkflowServerTool', () => { - it( - 'does not persist canonical side effects while preparing a workflow edit proposal', - { timeout: 10_000 }, - async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' - ) + it('connects existing blocks without rewriting block internals', async () => { + const result = await editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph([ + 'flowchart TD', + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n2["Compute Indicators
id: fn1
type: function"]', + ' n1 --> n2', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), + }, + { userId: 'user-1' } + ) - const result = await editWorkflowServerTool.execute( + expect(result.workflowState.blocks.fn1.name).toBe('Compute Indicators') + expect(result.workflowState.blocks.fn1.subBlocks.code.value).toBe('return { ok: true }') + expect(result.workflowState.edges).toEqual([ + expect.objectContaining({ + id: 'input1-source-fn1-target', + source: 'input1', + target: 'fn1', + }), + ]) + expect(result.documentFormat).toBe(WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT) + expect(result.entityDocument).not.toContain('%% TG_') + expect(result.entityDocument).toContain('Compute Indicators') + }) + + it('rejects existing block label renames instead of ignoring them', async () => { + await expect( + editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: [ + entityDocument: graph([ 'flowchart TD', - '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - '%% TG_BLOCK {"id":"block-1","type":"input_trigger","name":"Edited Trigger","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', - ].join('\n'), - currentWorkflowState: JSON.stringify({ - blocks: { - 'block-1': { - id: 'block-1', - type: 'input_trigger', - name: 'Trigger', - position: { x: 0, y: 0 }, - subBlocks: {}, - outputs: {}, - enabled: true, - }, - }, - edges: [], - loops: {}, - parallels: {}, - }), + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n2["Compute
id: fn1
type: function"]', + ' n1 --> n2', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), }, { userId: 'user-1' } ) - - expect(result.entityKind).toBe('workflow') - expect(result.entityId).toBe('wf-1') - expect(result.workflowState.blocks['block-1'].name).toBe('Edited Trigger') - expect(result.documentFormat).toBe('tg-mermaid-v1') - expect(result.entityDocument).toContain('TG_BLOCK') - } - ) - - it('rejects non-canonical TG_BLOCK metadata aliases', async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' - ) + ).rejects.toThrow('Use edit_workflow_block to rename existing blocks.') await expect( editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: [ + entityDocument: graph([ 'flowchart TD', - '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - '%% TG_BLOCK {"id":"block-1","blockType":"input_trigger","blockName":"Edited Trigger","blockDescription":"ignored","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', - ].join('\n'), - currentWorkflowState: JSON.stringify({ - blocks: { - 'block-1': { - id: 'block-1', - type: 'input_trigger', - name: 'Trigger', - position: { x: 0, y: 0 }, - subBlocks: {}, - outputs: {}, - enabled: true, - }, - }, - edges: [], - loops: {}, - parallels: {}, - }), + ' input1["Input Form"]', + ' fn1["Compute"]', + ' input1 --> fn1', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), }, { userId: 'user-1' } ) - ).rejects.toThrow( - 'Invalid TG_BLOCK payload: expected object with string id and string type. Workflow documents use `type`, not `blockType`.' - ) + ).rejects.toThrow('Use edit_workflow_block to rename existing blocks.') }) - it('rejects external TG_EDGE metadata that targets a parallel end handle', async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' - ) - + it('rejects existing block type changes instead of treating them as replacements', async () => { await expect( editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: [ + entityDocument: graph([ 'flowchart TD', - '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', - 'inputTrigger["Input Form
id: inputTrigger
type: input_trigger
enabled: true"]', - 'subgraph sg_parallel1["Parallel Research
id: parallel1
type: parallel
enabled: true"]', - ' parallel1__parallel_start["Parallel Start"]', - ' parallel1__parallel_end["Parallel End"]', - 'end', - 'inputTrigger --> parallel1', - '%% TG_BLOCK {"id":"inputTrigger","type":"input_trigger","name":"Input Form","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', - '%% TG_BLOCK {"id":"parallel1","type":"parallel","name":"Parallel Research","position":{"x":240,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', - '%% TG_EDGE {"source":"inputTrigger","target":"parallel1","targetHandle":"parallel-end-target"}', - '%% TG_PARALLEL {"id":"parallel1","nodes":[],"count":2,"parallelType":"count"}', - ].join('\n'), - currentWorkflowState: JSON.stringify({ - direction: 'TD', - blocks: { - inputTrigger: { - id: 'inputTrigger', - type: 'input_trigger', - name: 'Input Form', - position: { x: 0, y: 0 }, - subBlocks: {}, - outputs: {}, - enabled: true, - }, - parallel1: { - id: 'parallel1', - type: 'parallel', - name: 'Parallel Research', - position: { x: 240, y: 0 }, - subBlocks: {}, - outputs: {}, - enabled: true, - }, - }, - edges: [], - loops: {}, - parallels: { - parallel1: { - id: 'parallel1', - nodes: [], - count: 2, - parallelType: 'count', - }, - }, - }), + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n2["Compute
id: fn1
type: agent"]', + ' n1 --> n2', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), }, { userId: 'user-1' } ) ).rejects.toThrow( - 'Invalid container edge: parallel1 container input requires targetHandle "target" for incoming outer edges.' + 'Existing block ids are immutable identities in edit_workflow; this tool cannot replace an existing block or change its type.' ) }) - it('re-lays out staged workflow state to match LR Mermaid direction before review', async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' + it('adds new blocks with canonical block defaults from metadata-only labels', async () => { + const result = await editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph([ + 'flowchart TD', + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n2["id: fn2
type: function"]', + ' n1 --> n2', + ]), + currentWorkflowState: JSON.stringify({ + ...BASE_WORKFLOW_STATE, + blocks: { input1: BASE_WORKFLOW_STATE.blocks.input1 }, + }), + }, + { userId: 'user-1' } ) + expect(result.workflowState.blocks.fn2).toMatchObject({ + id: 'fn2', + type: 'function', + name: 'Mock Function', + enabled: true, + }) + expect(result.workflowState.blocks.fn2.subBlocks.code).toMatchObject({ + id: 'code', + type: 'code', + value: '', + }) + expect(result.documentFormat).toBe(WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT) + expect(result.entityDocument).toContain('Mock Function') + expect(result.entityDocument).not.toContain('["id: fn2') + expect(result.preview.blockDiff.added).toEqual(['fn2']) + }) + + it('places new blocks after existing siblings regardless of Mermaid order', async () => { + const result = await editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph([ + 'flowchart TD', + ' n2["id: fn2
type: function"]', + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n3["Compute Indicators
id: fn1
type: function"]', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), + }, + { userId: 'user-1' } + ) + + expect(result.workflowState.blocks.fn2.position).toEqual({ x: 0, y: 360 }) + }) + + it('preserves existing block absolute position when moving into a container', async () => { const result = await editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: [ + entityDocument: graph([ 'flowchart LR', - '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"LR"}', - 'inputTrigger(["Input Trigger"])', - '%% TG_BLOCK {"id":"inputTrigger","type":"input_trigger","name":"Input Trigger","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', - 'agentBlock(["Agent"])', - '%% TG_BLOCK {"id":"agentBlock","type":"agent","name":"Agent","position":{"x":0,"y":280},"subBlocks":{},"outputs":{},"enabled":true}', - 'inputTrigger --> agentBlock', - '%% TG_EDGE {"source":"inputTrigger","target":"agentBlock"}', - ].join('\n'), + ' subgraph sg_loop1["Loop
id: loop1
type: loop"]', + ' n1["Compute Indicators
id: fn1
type: function"]', + ' end', + ]), currentWorkflowState: JSON.stringify({ - direction: 'TD', + ...BASE_WORKFLOW_STATE, blocks: { - inputTrigger: { - id: 'inputTrigger', - type: 'input_trigger', - name: 'Input Trigger', - position: { x: 0, y: 0 }, - subBlocks: {}, - outputs: {}, - enabled: true, + fn1: { + ...BASE_WORKFLOW_STATE.blocks.fn1, + position: { x: 420, y: 260 }, }, - agentBlock: { - id: 'agentBlock', - type: 'agent', - name: 'Agent', - position: { x: 0, y: 280 }, + loop1: { + id: 'loop1', + type: 'loop', + name: 'Loop', + position: { x: 100, y: 100 }, + enabled: true, subBlocks: {}, outputs: {}, - enabled: true, }, }, - edges: [ - { - id: 'inputTrigger-source-agentBlock-target', - source: 'inputTrigger', - target: 'agentBlock', - }, - ], - loops: {}, - parallels: {}, }), }, { userId: 'user-1' } ) - expect(result.workflowState.direction).toBe('LR') - expect(result.workflowState.blocks.agentBlock.position.x).toBeGreaterThan( - result.workflowState.blocks.inputTrigger.position.x - ) - expect(result.entityDocument).toContain('flowchart LR') - expect(result.preview.warnings).toContain( - 'Re-laid out workflow blocks to match Mermaid direction LR.' - ) + expect(result.workflowState.blocks.fn1.data).toMatchObject({ + parentId: 'loop1', + extent: 'parent', + }) + expect(result.workflowState.blocks.fn1.position).toEqual({ x: 320, y: 160 }) }) - it('rejects input-trigger edits that invent inputSchema instead of inputFormat', async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' + it('rejects block-internal fields in graph-only workflow edits', async () => { + await expect( + editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph([ + 'flowchart TD', + ' n1["Input Form
id: input1
type: input_trigger
enabled: false
outputs: {}
data.foo: bar
subBlocks.code: return 1"]', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), + }, + { userId: 'user-1' } + ) + ).rejects.toThrow( + 'Workflow graph Mermaid block "input1" includes block-internal fields (enabled, outputs, data.foo, subBlocks.code).' ) + }) + it('rejects omitted existing blocks without explicit removedBlockIds', async () => { await expect( editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: buildInputTriggerWorkflowDocument({ - inputSchema: { - id: 'inputSchema', - type: 'short_text', - value: JSON.stringify({ - type: 'object', - properties: { - ticker: { type: 'string' }, - trade_date: { type: 'string' }, - }, - }), - }, - ticker: { - id: 'ticker', - type: 'short_text', - value: 'AAPL', - }, - trade_date: { - id: 'trade_date', - type: 'short_text', - value: '2026-04-17', - }, - }), - currentWorkflowState: INPUT_TRIGGER_CURRENT_WORKFLOW_STATE, + entityDocument: graph([ + 'flowchart TD', + ' n1["Input Form
id: input1
type: input_trigger"]', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), }, { userId: 'user-1' } ) ).rejects.toThrow( - 'Block Input Form: non-canonical sub-block "inputSchema" is not part of the input_trigger block config.' + 'Existing block ids omitted from edit_workflow entityDocument without removedBlockIds: fn1' ) }) - it('rejects newly introduced non-canonical sub-block ids for known block configs', async () => { - const { editWorkflowServerTool } = await import( - '@/lib/copilot/tools/server/workflow/edit-workflow' + it('removes omitted blocks only when removedBlockIds declares intent', async () => { + const result = await editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph(['flowchart TD', 'input1["Input Form"]']), + removedBlockIds: ['loop1'], + currentWorkflowState: JSON.stringify({ + ...BASE_WORKFLOW_STATE, + blocks: { + input1: BASE_WORKFLOW_STATE.blocks.input1, + loop1: { + id: 'loop1', + type: 'loop', + name: 'Loop', + position: { x: 100, y: 100 }, + enabled: true, + subBlocks: {}, + outputs: {}, + }, + fn1: { + ...BASE_WORKFLOW_STATE.blocks.fn1, + data: { parentId: 'loop1', extent: 'parent' }, + }, + }, + }), + }, + { userId: 'user-1' } ) + expect(result.workflowState.blocks).toHaveProperty('input1') + expect(result.workflowState.blocks).not.toHaveProperty('loop1') + expect(result.workflowState.blocks).not.toHaveProperty('fn1') + expect(result.workflowState.edges).toEqual([]) + }) + + it('rejects removedBlockIds that still appear in the graph', async () => { await expect( editWorkflowServerTool.execute( { entityId: 'wf-1', - entityDocument: buildInputTriggerWorkflowDocument({ - ticker: { - id: 'ticker', - type: 'short_text', - value: 'AAPL', - }, - }), - currentWorkflowState: INPUT_TRIGGER_CURRENT_WORKFLOW_STATE, + entityDocument: graph([ + 'flowchart TD', + ' n1["Input Form
id: input1
type: input_trigger"]', + ' n2["Compute
id: fn1
type: function"]', + ]), + removedBlockIds: ['fn1'], + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), }, { userId: 'user-1' } ) - ).rejects.toThrow( - 'Block Input Form: non-canonical sub-block "ticker" is not part of the input_trigger block config.' - ) + ).rejects.toThrow('removedBlockIds still appear in edit_workflow entityDocument: fn1') + }) + + it('rejects old TG metadata comments in mutation input', async () => { + await expect( + editWorkflowServerTool.execute( + { + entityId: 'wf-1', + entityDocument: graph([ + 'flowchart TD', + '%% TG_WORKFLOW {"version":"tg-mermaid-v1","direction":"TD"}', + '%% TG_BLOCK {"id":"input1","type":"input_trigger","name":"Input Form","position":{"x":0,"y":0},"subBlocks":{},"outputs":{},"enabled":true}', + ]), + currentWorkflowState: JSON.stringify(BASE_WORKFLOW_STATE), + }, + { userId: 'user-1' } + ) + ).rejects.toThrow('Workflow graph Mermaid must not include TG_* metadata comments') }) }) diff --git a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.ts b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.ts index 43bc3b805..ef63d8d94 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/workflow/edit-workflow.ts @@ -1,25 +1,246 @@ import { requireCopilotEntityId } from '@/lib/copilot/tools/entity-target' import type { BaseServerTool } from '@/lib/copilot/tools/server/base-tool' import { createLogger } from '@/lib/logs/console/logger' -import { - parseTgMermaidToWorkflow, - TG_MERMAID_DOCUMENT_FORMAT, -} from '@/lib/workflows/studio-workflow-mermaid' -import { createWorkflowSnapshot } from '@/lib/yjs/workflow-session' +import { resolveBlockRuntimeState } from '@/lib/workflows/block-outputs' +import { WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/document-format' +import { parseGraphOnlyWorkflowMermaid } from '@/lib/workflows/studio-workflow-mermaid' +import { buildInitialSubBlockStates } from '@/lib/workflows/subblock-values' +import { getAbsoluteBlockPosition } from '@/lib/workflows/workflow-direction' +import { createWorkflowSnapshot, type WorkflowSnapshot } from '@/lib/yjs/workflow-session' +import { getBlock } from '@/blocks' +import type { BlockState, Position } from '@/stores/workflows/workflow/types' +import { generateLoopBlocks, generateParallelBlocks } from '@/stores/workflows/workflow/utils' import { buildWorkflowMutationResult, loadBaseWorkflowState } from './workflow-mutation-utils' interface EditWorkflowParams { entityId: string entityDocument: string - documentFormat?: string + removedBlockIds?: string[] currentWorkflowState: string } +function buildStableEdgeId(edge: { + source: string + target: string + sourceHandle?: string | null + targetHandle?: string | null +}): string { + const sourceHandle = + !edge.sourceHandle || edge.sourceHandle === 'source' || edge.sourceHandle === 'output' + ? 'source' + : edge.sourceHandle + const targetHandle = + !edge.targetHandle || edge.targetHandle === 'target' || edge.targetHandle === 'input' + ? 'target' + : edge.targetHandle + + return `${edge.source}-${sourceHandle}-${edge.target}-${targetHandle}` +} + +function createInitialPositionAllocator( + graphBlocks: Array<{ blockId: string; parentId?: string }>, + baseBlocks: Record +): (parentId?: string) => Position { + const siblingCounts = new Map() + for (const graphBlock of graphBlocks) { + if (!baseBlocks[graphBlock.blockId]) continue + siblingCounts.set(graphBlock.parentId, (siblingCounts.get(graphBlock.parentId) ?? 0) + 1) + } + + return (parentId?: string) => { + const siblingCount = siblingCounts.get(parentId) ?? 0 + siblingCounts.set(parentId, siblingCount + 1) + return parentId ? { x: 120, y: siblingCount * 180 } : { x: 0, y: siblingCount * 180 } + } +} + +function buildDefaultBlock( + blockId: string, + blockType: string, + getInitialPosition: (parentId?: string) => Position, + parentId?: string, + name?: string +): BlockState { + const blockConfig = getBlock(blockType) + const data = parentId ? { parentId, extent: 'parent' as const } : undefined + + if (!blockConfig && blockType !== 'loop' && blockType !== 'parallel') { + throw new Error(`Unknown workflow block type "${blockType}" for new block "${blockId}".`) + } + + if (!blockConfig) { + return { + id: blockId, + type: blockType, + name: name?.trim() || (blockType === 'loop' ? 'Loop' : 'Parallel'), + position: getInitialPosition(parentId), + subBlocks: {}, + outputs: {}, + enabled: true, + ...(data ? { data } : {}), + } + } + + const initialSubBlocks = buildInitialSubBlockStates( + blockConfig.subBlocks + ) as BlockState['subBlocks'] + const runtimeState = resolveBlockRuntimeState({ + blockType, + blockConfig, + subBlocks: initialSubBlocks, + triggerMode: false, + }) + + return { + id: blockId, + type: blockType, + name: name?.trim() || blockConfig.name, + position: getInitialPosition(parentId), + subBlocks: runtimeState.subBlocks as BlockState['subBlocks'], + outputs: runtimeState.outputs, + enabled: true, + ...(data ? { data } : {}), + } +} + +function setParent( + block: BlockState, + parentId: string | undefined, + blocks: Record, + baseBlocks: Record +): BlockState { + const nextPosition = + block.data?.parentId === parentId + ? block.position + : (() => { + const absolutePosition = getAbsoluteBlockPosition(block.id, baseBlocks) + if (!parentId) return absolutePosition + const parentPosition = getAbsoluteBlockPosition(parentId, blocks) + return { + x: absolutePosition.x - parentPosition.x, + y: absolutePosition.y - parentPosition.y, + } + })() + + const nextData = parentId + ? { ...(block.data ?? {}), parentId, extent: 'parent' as const } + : (() => { + const { parentId: _parentId, extent: _extent, ...data } = block.data ?? {} + return data + })() + + if (Object.keys(nextData).length === 0) { + const { data: _data, ...blockWithoutData } = block + return { ...blockWithoutData, position: nextPosition } + } + return { ...block, position: nextPosition, data: nextData } +} + +function applyGraphMermaidToWorkflow( + baseWorkflowState: WorkflowSnapshot, + entityDocument: string, + removedBlockIds: string[] = [] +): WorkflowSnapshot & { direction: 'TD' | 'LR' } { + const graph = parseGraphOnlyWorkflowMermaid(entityDocument, baseWorkflowState.blocks ?? {}) + const blocks: Record = {} + const explicitRemovedBlockIds = new Set(removedBlockIds) + for (let expanded = true; expanded; ) { + expanded = false + for (const [blockId, block] of Object.entries(baseWorkflowState.blocks ?? {})) { + const parentId = block.data?.parentId + if ( + !explicitRemovedBlockIds.has(blockId) && + parentId && + explicitRemovedBlockIds.has(parentId) + ) { + explicitRemovedBlockIds.add(blockId) + expanded = true + } + } + } + const graphBlockIds = new Set(graph.blocks.map((block) => block.blockId)) + const omittedExistingBlockIds = Object.keys(baseWorkflowState.blocks ?? {}).filter( + (blockId) => !graphBlockIds.has(blockId) + ) + const missingRemovalIntents = omittedExistingBlockIds.filter( + (blockId) => !explicitRemovedBlockIds.has(blockId) + ) + + if (missingRemovalIntents.length > 0) { + throw new Error( + `Invalid edited workflow: Existing block ids omitted from edit_workflow entityDocument without removedBlockIds: ${missingRemovalIntents.join(', ')}.` + ) + } + + const stillPresentRemovedBlockIds = [...explicitRemovedBlockIds].filter((blockId) => + graphBlockIds.has(blockId) + ) + if (stillPresentRemovedBlockIds.length > 0) { + throw new Error( + `Invalid edited workflow: removedBlockIds still appear in edit_workflow entityDocument: ${stillPresentRemovedBlockIds.join(', ')}.` + ) + } + + const getInitialPosition = createInitialPositionAllocator( + graph.blocks, + baseWorkflowState.blocks ?? {} + ) + + for (const graphBlock of graph.blocks) { + const existingBlock = baseWorkflowState.blocks?.[graphBlock.blockId] + if (existingBlock) { + if (graphBlock.blockType && graphBlock.blockType !== existingBlock.type) { + throw new Error( + `Invalid edited workflow: Existing block "${graphBlock.blockId}" has type "${existingBlock.type}" but entityDocument declares type "${graphBlock.blockType}". Existing block ids are immutable identities in edit_workflow; this tool cannot replace an existing block or change its type.` + ) + } + if (graphBlock.name && graphBlock.name.trim() !== existingBlock.name) { + throw new Error( + `Invalid edited workflow: Existing block "${graphBlock.blockId}" has name "${existingBlock.name}" but entityDocument declares name "${graphBlock.name}". Use edit_workflow_block to rename existing blocks.` + ) + } + blocks[graphBlock.blockId] = setParent( + existingBlock, + graphBlock.parentId, + blocks, + baseWorkflowState.blocks ?? {} + ) + continue + } + if (!graphBlock.blockType) { + throw new Error(`New workflow block "${graphBlock.blockId}" is missing a type label.`) + } + blocks[graphBlock.blockId] = buildDefaultBlock( + graphBlock.blockId, + graphBlock.blockType, + getInitialPosition, + graphBlock.parentId, + graphBlock.name + ) + } + + const edges = graph.edges.map((edge) => ({ + ...edge, + id: buildStableEdgeId(edge), + type: 'default', + data: {}, + })) + + return createWorkflowSnapshot({ + ...baseWorkflowState, + direction: graph.direction, + blocks, + edges, + loops: generateLoopBlocks(blocks), + parallels: generateParallelBlocks(blocks), + }) as WorkflowSnapshot & { direction: 'TD' | 'LR' } +} + export const editWorkflowServerTool: BaseServerTool = { name: 'edit_workflow', async execute(params: EditWorkflowParams): Promise { const logger = createLogger('EditWorkflowServerTool') - const { entityDocument, documentFormat, currentWorkflowState } = params + const { entityDocument, removedBlockIds, currentWorkflowState } = params const workflowId = requireCopilotEntityId(params, { toolName: 'edit_workflow' }) if (!entityDocument || entityDocument.trim().length === 0) { @@ -28,19 +249,24 @@ export const editWorkflowServerTool: BaseServerTool = { logger.info('Executing edit_workflow', { workflowId, - documentFormat: documentFormat || TG_MERMAID_DOCUMENT_FORMAT, + documentLength: entityDocument.length, }) const baseWorkflowState = await loadBaseWorkflowState(workflowId, currentWorkflowState) - const parsedWorkflowDocument = parseTgMermaidToWorkflow(entityDocument) + const nextWorkflowState = applyGraphMermaidToWorkflow( + baseWorkflowState, + entityDocument, + removedBlockIds + ) const result = buildWorkflowMutationResult({ workflowId, baseWorkflowState, - nextWorkflowState: createWorkflowSnapshot(parsedWorkflowDocument), - requestedDirection: parsedWorkflowDocument.direction, + nextWorkflowState, + requestedDirection: nextWorkflowState.direction, + documentFormat: WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT, }) - logger.info('edit_workflow successfully parsed workflow document', { + logger.info('edit_workflow successfully applied workflow graph', { workflowId, blocksCount: Object.keys(result.workflowState.blocks).length, edgesCount: result.workflowState.edges.length, diff --git a/apps/tradinggoose/lib/copilot/tools/server/workflow/workflow-mutation-utils.ts b/apps/tradinggoose/lib/copilot/tools/server/workflow/workflow-mutation-utils.ts index f5cf9789d..4274a65ac 100644 --- a/apps/tradinggoose/lib/copilot/tools/server/workflow/workflow-mutation-utils.ts +++ b/apps/tradinggoose/lib/copilot/tools/server/workflow/workflow-mutation-utils.ts @@ -1,6 +1,8 @@ import { findIntroducedNonCanonicalSubBlocks } from '@/lib/workflows/block-config-canonicalization' +import { WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT } from '@/lib/workflows/document-format' import { buildWorkflowDocumentPreviewDiff, + serializeWorkflowToGraphMermaid, serializeWorkflowToTgMermaid, TG_MERMAID_DOCUMENT_FORMAT, } from '@/lib/workflows/studio-workflow-mermaid' @@ -32,8 +34,11 @@ export function buildWorkflowMutationResult(params: { baseWorkflowState: WorkflowSnapshot nextWorkflowState: WorkflowSnapshot requestedDirection?: WorkflowDirection + entityDocument?: string + documentFormat?: string }) { const { workflowId, baseWorkflowState, nextWorkflowState, requestedDirection } = params + const documentFormat = params.documentFormat ?? TG_MERMAID_DOCUMENT_FORMAT const nonCanonicalSubBlockErrors = findIntroducedNonCanonicalSubBlocks( nextWorkflowState, baseWorkflowState @@ -63,14 +68,18 @@ export function buildWorkflowMutationResult(params: { finalWorkflowState = createWorkflowSnapshot(normalizedWorkflow.workflowState) const preview = buildWorkflowDocumentPreviewDiff(baseWorkflowState, finalWorkflowState) const warnings = Array.from(new Set([...orientationWarnings, ...preview.warnings, ...validation.warnings])) - const entityDocument = serializeWorkflowToTgMermaid(finalWorkflowState, { direction }) + const entityDocument = + params.entityDocument ?? + (documentFormat === WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT + ? serializeWorkflowToGraphMermaid(finalWorkflowState, { direction }) + : serializeWorkflowToTgMermaid(finalWorkflowState, { direction })) return { success: true, entityKind: 'workflow' as const, entityId: workflowId, entityDocument, - documentFormat: TG_MERMAID_DOCUMENT_FORMAT, + documentFormat, workflowState: finalWorkflowState, preview: { ...preview, diff --git a/apps/tradinggoose/lib/copilot/usage-reservations.ts b/apps/tradinggoose/lib/copilot/usage-reservations.ts index b44152bc4..2027b4d46 100644 --- a/apps/tradinggoose/lib/copilot/usage-reservations.ts +++ b/apps/tradinggoose/lib/copilot/usage-reservations.ts @@ -62,6 +62,8 @@ export type CopilotUsageReleaseResult = { const RESERVATION_KEY_PREFIX = 'copilot:usage-reservation' const DEFAULT_RESERVATION_TTL_SECONDS = 15 * 60 const DEFAULT_LOCK_TTL_SECONDS = 10 +const LOCK_ACQUIRE_ATTEMPTS = 10 +const LOCK_ACQUIRE_RETRY_DELAY_MS = 50 function parsePositiveInt(value: string | undefined, fallback: number): number { if (!value) return fallback @@ -215,10 +217,22 @@ function sumReservedUsd(reservations: CopilotUsageReservation[]): number { return reservations.reduce((total, reservation) => total + reservation.reservedUsd, 0) } +function delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + async function withScopeLock(scope: ReservationScope, action: () => Promise): Promise { const lockKey = getScopeLockKey(scope) const token = crypto.randomUUID() - const acquired = await acquireLock(lockKey, token, LOCK_TTL_SECONDS) + let acquired = false + + for (let attempt = 0; attempt < LOCK_ACQUIRE_ATTEMPTS; attempt++) { + acquired = await acquireLock(lockKey, token, LOCK_TTL_SECONDS) + if (acquired) break + if (attempt < LOCK_ACQUIRE_ATTEMPTS - 1) { + await delay(LOCK_ACQUIRE_RETRY_DELAY_MS) + } + } if (!acquired) { throw new Error(`Could not acquire copilot usage reservation lock for ${scope.scopeType}:${scope.scopeId}`) @@ -259,6 +273,57 @@ async function resolveReservationScope(params: { } } +async function resolveMutationScope(params: { + userId: string + workflowId?: string | null + reservationId?: string | null +}): Promise { + if (params.reservationId) { + const lookup = await readReservationLookup(params.reservationId) + if (lookup) return lookup + } + return resolveReservationScope(params) +} + +async function releaseReservationInScope( + scope: ReservationScope, + reservationId: string +): Promise { + const reservations = pruneExpiredReservations(await readScopeReservations(scope)) + const reservation = reservations.find((entry) => entry.id === reservationId) ?? null + const remainingReservations = reservations.filter((entry) => entry.id !== reservationId) + + await writeScopeReservations(scope, remainingReservations) + await deleteCachedValue(getReservationLookupKey(reservationId)) + + return { + released: reservation !== null, + reservationId, + reservedUsd: reservation?.reservedUsd, + scopeType: scope.scopeType, + scopeId: scope.scopeId, + } +} + +export async function commitCopilotUsageReservation(params: { + userId: string + workflowId?: string | null + reservationId?: string | null + operation: () => Promise +}): Promise { + const scope = await resolveMutationScope(params) + + return withScopeLock(scope, async () => { + try { + return await params.operation() + } finally { + if (params.reservationId) { + await releaseReservationInScope(scope, params.reservationId) + } + } + }) +} + export async function reserveCopilotUsage(params: { userId: string workflowId?: string | null @@ -327,126 +392,16 @@ export async function reserveCopilotUsage(params: { }) } -export async function adjustCopilotUsageReservation(params: { - reservationId: string - userId: string - workflowId?: string | null - requestedUsd: number - reason?: string -}): Promise { - const lookup = await readReservationLookup(params.reservationId) - if (!lookup) { - return { - allowed: false, - status: 404, - currentUsage: 0, - limit: 0, - remaining: 0, - activeReservedUsd: 0, - scopeType: 'user', - scopeId: params.userId, - message: 'Reservation not found', - } - } - - return withScopeLock(lookup, async () => { - const reservations = pruneExpiredReservations(await readScopeReservations(lookup)) - const reservation = reservations.find((entry) => entry.id === params.reservationId) ?? null - - if (!reservation) { - await deleteCachedValue(getReservationLookupKey(params.reservationId)) - return { - allowed: false, - status: 404, - currentUsage: 0, - limit: 0, - remaining: 0, - activeReservedUsd: 0, - scopeType: lookup.scopeType, - scopeId: lookup.scopeId, - message: 'Reservation not found', - } - } - - const usage = await checkServerSideUsageLimits({ - userId: params.userId, - workflowId: params.workflowId ?? reservation.workflowId, - }) - - const otherReservations = reservations.filter((entry) => entry.id !== params.reservationId) - const otherReservedUsd = sumReservedUsd(otherReservations) - const remainingBeforeAdjust = Math.max(0, usage.limit - usage.currentUsage - otherReservedUsd) - - if (usage.isExceeded || remainingBeforeAdjust < params.requestedUsd) { - return { - allowed: false, - status: 402, - reservationId: params.reservationId, - reservedUsd: reservation.reservedUsd, - currentUsage: usage.currentUsage, - limit: usage.limit, - remaining: remainingBeforeAdjust, - activeReservedUsd: otherReservedUsd + reservation.reservedUsd, - scopeType: lookup.scopeType, - scopeId: lookup.scopeId, - message: usage.message, - } - } - - const refreshedReservation: CopilotUsageReservation = { - ...reservation, - userId: params.userId, - workflowId: params.workflowId ?? reservation.workflowId, - reservedUsd: params.requestedUsd, - reason: params.reason ?? reservation.reason, - expiresAt: new Date(Date.now() + RESERVATION_TTL_SECONDS * 1000).toISOString(), - } - - await writeScopeReservations(lookup, [...otherReservations, refreshedReservation]) - await writeReservationLookup(params.reservationId, lookup) - - return { - allowed: true, - status: 200, - reservationId: params.reservationId, - reservedUsd: refreshedReservation.reservedUsd, - currentUsage: usage.currentUsage, - limit: usage.limit, - remaining: Math.max(0, remainingBeforeAdjust - refreshedReservation.reservedUsd), - activeReservedUsd: otherReservedUsd + refreshedReservation.reservedUsd, - scopeType: lookup.scopeType, - scopeId: lookup.scopeId, - expiresAt: refreshedReservation.expiresAt, - message: usage.message, - } - }) -} - export async function releaseCopilotUsageReservation(params: { reservationId: string }): Promise { - const lookup = await readReservationLookup(params.reservationId) - if (!lookup) { + const scope = await readReservationLookup(params.reservationId) + if (!scope) { return { released: false, reservationId: params.reservationId, } } - return withScopeLock(lookup, async () => { - const reservations = pruneExpiredReservations(await readScopeReservations(lookup)) - const reservation = reservations.find((entry) => entry.id === params.reservationId) ?? null - const remainingReservations = reservations.filter((entry) => entry.id !== params.reservationId) - - await writeScopeReservations(lookup, remainingReservations) - await deleteCachedValue(getReservationLookupKey(params.reservationId)) - - return { - released: reservation !== null, - reservationId: params.reservationId, - reservedUsd: reservation?.reservedUsd, - scopeType: lookup.scopeType, - scopeId: lookup.scopeId, - } - }) + return withScopeLock(scope, () => releaseReservationInScope(scope, params.reservationId)) } diff --git a/apps/tradinggoose/lib/email/mailer.test.ts b/apps/tradinggoose/lib/email/mailer.test.ts index 2d7652891..9381675f2 100644 --- a/apps/tradinggoose/lib/email/mailer.test.ts +++ b/apps/tradinggoose/lib/email/mailer.test.ts @@ -5,8 +5,17 @@ const mockBatchSend = vi.fn() const mockContactsCreate = vi.fn() const mockAzureBeginSend = vi.fn() const mockAzurePollUntilDone = vi.fn() -const { mockResolveResendServiceConfig, mockResolveAzureCommunicationEmailServiceConfig } = +const { + mockLoggerInfo, + mockLoggerWarn, + mockLoggerError, + mockResolveResendServiceConfig, + mockResolveAzureCommunicationEmailServiceConfig, +} = vi.hoisted(() => ({ + mockLoggerInfo: vi.fn(), + mockLoggerWarn: vi.fn(), + mockLoggerError: vi.fn(), mockResolveResendServiceConfig: vi.fn(), mockResolveAzureCommunicationEmailServiceConfig: vi.fn(), })) @@ -59,6 +68,14 @@ vi.mock('@/lib/urls/utils', () => ({ getBaseUrl: vi.fn().mockReturnValue('https://test.tradinggoose.ai'), })) +vi.mock('@/lib/logs/console/logger', () => ({ + createLogger: () => ({ + info: mockLoggerInfo, + warn: mockLoggerWarn, + error: mockLoggerError, + }), +})) + import { addVerifiedUserEmailToAudience, type EmailType, @@ -163,7 +180,7 @@ describe('mailer', () => { html: '

Test content

Unsubscribe', headers: { 'List-Unsubscribe': - '', + '', 'List-Unsubscribe-Post': 'List-Unsubscribe=One-Click', }, }) @@ -185,6 +202,39 @@ describe('mailer', () => { expect(mockSend).not.toHaveBeenCalled() }) + it('should log unsent email text when no email service is configured', async () => { + mockResolveResendServiceConfig.mockResolvedValue({ + apiKey: null, + audienceId: null, + }) + mockResolveAzureCommunicationEmailServiceConfig.mockResolvedValue({ + connectionString: null, + }) + + const result = await sendEmail({ + ...testEmailOptions, + text: 'Verification code: 123456', + emailType: 'transactional', + }) + + expect(result).toEqual({ + success: true, + message: 'Email logging successful (no email service configured)', + data: { id: 'mock-email-id' }, + }) + expect(mockSend).not.toHaveBeenCalled() + expect(mockAzureBeginSend).not.toHaveBeenCalled() + expect(mockLoggerInfo).toHaveBeenCalledWith( + 'Email not sent (no email service configured):', + expect.objectContaining({ + to: testEmailOptions.to, + subject: testEmailOptions.subject, + from: 'TradingGoose ', + text: 'Verification code: 123456', + }) + ) + }) + it.concurrent('should handle Resend API errors and fallback to Azure', async () => { // Mock Resend to fail mockSend.mockResolvedValue({ diff --git a/apps/tradinggoose/lib/email/mailer.ts b/apps/tradinggoose/lib/email/mailer.ts index ef1f4df55..43fff8fcd 100644 --- a/apps/tradinggoose/lib/email/mailer.ts +++ b/apps/tradinggoose/lib/email/mailer.ts @@ -184,6 +184,7 @@ export async function sendEmail(options: EmailOptions): Promise to: options.to, subject: options.subject, from: processedData.senderEmail, + text: processedData.text, }) return { success: true, diff --git a/apps/tradinggoose/lib/environment/api.ts b/apps/tradinggoose/lib/environment/api.ts index 1a38cbd31..2b0525f90 100644 --- a/apps/tradinggoose/lib/environment/api.ts +++ b/apps/tradinggoose/lib/environment/api.ts @@ -17,12 +17,14 @@ export interface WorkspaceEnvironmentData { personalRows: WorkspaceEnvironmentRow[] } -export async function fetchPersonalEnvironment(): Promise> { +export async function fetchPersonalEnvironment( + callbackPathname: string +): Promise> { const response = await fetch(API_ENDPOINTS.ENVIRONMENT, { cache: 'no-store' }) if (!response.ok) { if (response.status === 401) { - await handleAuthError('environment-api:personal') + await handleAuthError('environment-api:personal', callbackPathname) } throw new Error(`Failed to load environment variables: ${response.statusText}`) } @@ -37,7 +39,8 @@ export async function fetchPersonalEnvironment(): Promise { const response = await fetch(API_ENDPOINTS.WORKSPACE_ENVIRONMENT(workspaceId), { cache: 'no-store', @@ -45,7 +48,7 @@ export async function fetchWorkspaceEnvironment( if (!response.ok) { if (response.status === 401) { - await handleAuthError('environment-api:workspace') + await handleAuthError('environment-api:workspace', callbackPathname) } throw new Error(`Failed to load workspace environment: ${response.statusText}`) } diff --git a/apps/tradinggoose/lib/indicators/custom/operations.ts b/apps/tradinggoose/lib/indicators/custom/operations.ts index 70e6c8260..a0e811c07 100644 --- a/apps/tradinggoose/lib/indicators/custom/operations.ts +++ b/apps/tradinggoose/lib/indicators/custom/operations.ts @@ -20,7 +20,6 @@ interface UpsertIndicatorsParams { indicators: Array<{ id?: string name: string - color?: string pineCode: string inputMeta?: Record }> @@ -36,20 +35,6 @@ interface ImportIndicatorsParams { requestId?: string } -const resolveIndicatorColor = ( - input: string | null | undefined, - indicatorId: string, - fallback?: string | null -): string => { - if (typeof input === 'string' && input.trim().length > 0) { - return input.trim() - } - if (typeof fallback === 'string' && fallback.trim().length > 0) { - return fallback.trim() - } - return getStableVibrantColor(indicatorId) -} - export async function upsertIndicators({ indicators, workspaceId, @@ -72,13 +57,12 @@ export async function upsertIndicators({ if (existing.length > 0) { const existingColor = existing[0]?.color - const nextColor = resolveIndicatorColor(indicator.color, indicator.id, existingColor) await tx .update(pineIndicators) .set({ name: indicator.name, - color: nextColor, + color: existingColor ?? getStableVibrantColor(indicator.id), pineCode: indicator.pineCode, inputMeta: indicator.inputMeta ?? null, updatedAt: nowTime, @@ -92,14 +76,12 @@ export async function upsertIndicators({ } const indicatorId = indicator.id ?? crypto.randomUUID() - const nextColor = resolveIndicatorColor(indicator.color, indicatorId) - await tx.insert(pineIndicators).values({ id: indicatorId, workspaceId, userId, name: indicator.name, - color: nextColor, + color: getStableVibrantColor(indicatorId), pineCode: indicator.pineCode, inputMeta: indicator.inputMeta ?? null, createdAt: nowTime, @@ -159,7 +141,7 @@ export async function importIndicators({ workspaceId, userId, name: nextName, - color: resolveIndicatorColor(indicator.color, indicatorId), + color: getStableVibrantColor(indicatorId), pineCode: indicator.pineCode, inputMeta: indicator.inputMeta ?? null, createdAt: nowTime, diff --git a/apps/tradinggoose/lib/indicators/generated/copilot-indicator-reference.ts b/apps/tradinggoose/lib/indicators/generated/copilot-indicator-reference.ts index b9c9d7fad..fd2b502bc 100644 --- a/apps/tradinggoose/lib/indicators/generated/copilot-indicator-reference.ts +++ b/apps/tradinggoose/lib/indicators/generated/copilot-indicator-reference.ts @@ -127,13 +127,7 @@ export const INDICATOR_REFERENCE_SECTION_RECORDS = [ detail: 'TradingGoose saves indicators as JSON documents using `tg-indicator-document-v1`. The canonical field set is derived from the live indicator document schema.', support: 'curated', - relatedIds: [ - 'document.format', - 'document.name', - 'document.color', - 'document.pineCode', - 'document.inputMeta', - ], + relatedIds: ['document.format', 'document.name', 'document.pineCode', 'document.inputMeta'], sourceReferences: [ { label: 'Indicator document schema', @@ -141,7 +135,7 @@ export const INDICATOR_REFERENCE_SECTION_RECORDS = [ }, ], queryText: - 'section:document indicator document saved indicator document format and field-level requirements. tradinggoose saves indicators as json documents using `tg-indicator-document-v1`. the canonical field set is derived from the live indicator document schema. document.format document.name document.color document.pinecode document.inputmeta', + 'section:document indicator document saved indicator document format and field-level requirements. tradinggoose saves indicators as json documents using `tg-indicator-document-v1`. the canonical field set is derived from the live indicator document schema. document.format document.name document.pinecode document.inputmeta', }, { id: 'section:runtime', @@ -305,10 +299,10 @@ export const INDICATOR_REFERENCE_ITEM_RECORDS = [ title: 'Document Format', summary: 'Canonical indicator document format id and top-level field set.', detail: - 'TradingGoose indicator editing tools expect `tg-indicator-document-v1` JSON with the live field set `name, color, pineCode, inputMeta`.', + 'TradingGoose indicator editing tools expect `tg-indicator-document-v1` JSON with the live field set `name, pineCode, inputMeta`.', support: 'curated', - signature: 'tg-indicator-document-v1 = { name, color, pineCode, inputMeta }', - relatedIds: ['document.name', 'document.color', 'document.pineCode', 'document.inputMeta'], + signature: 'tg-indicator-document-v1 = { name, pineCode, inputMeta }', + relatedIds: ['document.name', 'document.pineCode', 'document.inputMeta'], sourceReferences: [ { label: 'Indicator document schema', @@ -316,7 +310,7 @@ export const INDICATOR_REFERENCE_ITEM_RECORDS = [ }, ], queryText: - 'document.format section:document document format canonical indicator document format id and top-level field set. tradinggoose indicator editing tools expect `tg-indicator-document-v1` json with the live field set `name, color, pinecode, inputmeta`. tg-indicator-document-v1 = { name, color, pinecode, inputmeta } document.name document.color document.pinecode document.inputmeta', + 'document.format section:document document format canonical indicator document format id and top-level field set. tradinggoose indicator editing tools expect `tg-indicator-document-v1` json with the live field set `name, pinecode, inputmeta`. tg-indicator-document-v1 = { name, pinecode, inputmeta } document.name document.pinecode document.inputmeta', }, { id: 'document.name', @@ -336,24 +330,6 @@ export const INDICATOR_REFERENCE_ITEM_RECORDS = [ queryText: 'document.name section:document document field: name human-readable indicator name in the canonical document. the `name` field is part of the live indicator document schema and is what tradinggoose renames when copilot updates an indicator title.', }, - { - id: 'document.color', - sectionId: 'section:document', - type: 'document_field', - title: 'Document Field: color', - summary: 'Default display color in the canonical document.', - detail: - 'The `color` field is part of the live indicator document schema and stores the default indicator display color.', - support: 'curated', - sourceReferences: [ - { - label: 'Indicator document schema', - path: 'apps/tradinggoose/lib/copilot/entity-documents.ts', - }, - ], - queryText: - 'document.color section:document document field: color default display color in the canonical document. the `color` field is part of the live indicator document schema and stores the default indicator display color.', - }, { id: 'document.pineCode', sectionId: 'section:document', diff --git a/apps/tradinggoose/lib/indicators/import-export.test.ts b/apps/tradinggoose/lib/indicators/import-export.test.ts index 9856cf723..3b8225f21 100644 --- a/apps/tradinggoose/lib/indicators/import-export.test.ts +++ b/apps/tradinggoose/lib/indicators/import-export.test.ts @@ -13,7 +13,6 @@ describe('indicator import/export helpers', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: { Length: { @@ -40,7 +39,6 @@ describe('indicator import/export helpers', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: { Length: { @@ -61,7 +59,6 @@ describe('indicator import/export helpers', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: undefined, }, @@ -80,7 +77,6 @@ describe('indicator import/export helpers', () => { indicators: [ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", }, ], @@ -101,7 +97,6 @@ describe('indicator import/export helpers', () => { indicators: [ { name: ' RSI Export Example ', - color: ' #3972F6 ', pineCode: "indicator('RSI Export Example')", inputMeta: {}, }, @@ -111,7 +106,6 @@ describe('indicator import/export helpers', () => { expect(parsed.indicators).toEqual([ { name: 'RSI Export Example', - color: '#3972F6', pineCode: "indicator('RSI Export Example')", inputMeta: {}, }, @@ -172,7 +166,7 @@ describe('indicator import/export helpers', () => { ).toThrow() }) - it('rejects import entries with extra keys', () => { + it('ignores generated indicator storage fields in transfer records', () => { expect(() => parseImportedIndicatorsFile({ version: '1', @@ -183,12 +177,13 @@ describe('indicator import/export helpers', () => { indicators: [ { id: 'indicator-1', + color: '#3972F6', name: 'RSI Export Example', pineCode: "indicator('RSI Export Example')", }, ], }) - ).toThrow() + ).not.toThrow() }) it('renames duplicate imported indicators with the imported marker', () => { diff --git a/apps/tradinggoose/lib/indicators/import-export.ts b/apps/tradinggoose/lib/indicators/import-export.ts index 0ec0d17d7..464788ee6 100644 --- a/apps/tradinggoose/lib/indicators/import-export.ts +++ b/apps/tradinggoose/lib/indicators/import-export.ts @@ -8,11 +8,6 @@ import type { IndicatorDefinition } from '@/stores/indicators/types' const IMPORTED_INDICATOR_MARKER = '(imported)' const normalizeInlineWhitespace = (value: string) => value.trim().replace(/\s+/g, ' ') -const normalizeOptionalString = (value: string | null | undefined) => { - if (typeof value !== 'string') return undefined - const normalized = value.trim() - return normalized.length > 0 ? normalized : undefined -} export const IndicatorTransferSchema = z .object({ @@ -20,11 +15,9 @@ export const IndicatorTransferSchema = z .string() .transform(normalizeInlineWhitespace) .pipe(z.string().min(1, 'Indicator name is required')), - color: z.string().transform(normalizeInlineWhitespace).optional(), pineCode: z.string(), inputMeta: z.record(z.any()).optional(), }) - .strict() export const IndicatorsTransferListSchema = z .array(IndicatorTransferSchema) @@ -46,11 +39,10 @@ export type IndicatorTransferRecord = z.infer export type IndicatorsImportFile = z.infer function normalizeIndicatorForTransfer( - indicator: Pick + indicator: Pick ): IndicatorTransferRecord { return { name: normalizeInlineWhitespace(indicator.name), - color: normalizeOptionalString(indicator.color), pineCode: indicator.pineCode ?? '', inputMeta: indicator.inputMeta && typeof indicator.inputMeta === 'object' @@ -67,7 +59,7 @@ export function createIndicatorsExportFile({ indicators, exportedFrom, }: { - indicators: Array> + indicators: Array> exportedFrom: string }): IndicatorsImportFile { return createTradingGooseExportFile({ @@ -83,7 +75,7 @@ export function exportIndicatorsAsJson({ indicators, exportedFrom, }: { - indicators: Array> + indicators: Array> exportedFrom: string }): string { return JSON.stringify(createIndicatorsExportFile({ indicators, exportedFrom }), null, 2) diff --git a/apps/tradinggoose/lib/knowledge/service.ts b/apps/tradinggoose/lib/knowledge/service.ts index 118f0f76e..a6f9dd199 100644 --- a/apps/tradinggoose/lib/knowledge/service.ts +++ b/apps/tradinggoose/lib/knowledge/service.ts @@ -5,9 +5,8 @@ import { embedding, knowledgeBase, knowledgeBaseTagDefinitions, - permissions, } from '@tradinggoose/db/schema' -import { and, count, eq, inArray, isNotNull, isNull } from 'drizzle-orm' +import { and, count, eq, inArray, isNull } from 'drizzle-orm' import { checkStorageQuota, incrementStorageUsage } from '@/lib/billing/storage' import { enqueueDocumentProcessingJobs } from '@/lib/knowledge/documents/service' import { @@ -20,7 +19,7 @@ import type { KnowledgeBaseWithCounts, } from '@/lib/knowledge/types' import { createLogger } from '@/lib/logs/console/logger' -import { getUserEntityPermissions } from '@/lib/permissions/utils' +import { checkWorkspaceAccess, getUserEntityPermissions } from '@/lib/permissions/utils' const logger = createLogger('KnowledgeBaseService') @@ -31,6 +30,11 @@ export async function getKnowledgeBases( userId: string, workspaceId: string ): Promise { + const workspaceAccess = await checkWorkspaceAccess(workspaceId, userId) + if (!workspaceAccess.hasAccess) { + return [] + } + const knowledgeBasesWithCounts = await db .select({ id: knowledgeBase.id, @@ -50,21 +54,7 @@ export async function getKnowledgeBases( document, and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt)) ) - .leftJoin( - permissions, - and( - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, knowledgeBase.workspaceId), - eq(permissions.userId, userId) - ) - ) - .where( - and( - isNull(knowledgeBase.deletedAt), - eq(knowledgeBase.workspaceId, workspaceId), - isNotNull(permissions.userId) - ) - ) + .where(and(isNull(knowledgeBase.deletedAt), eq(knowledgeBase.workspaceId, workspaceId))) .groupBy(knowledgeBase.id) .orderBy(knowledgeBase.createdAt) diff --git a/apps/tradinggoose/lib/oauth/connect.ts b/apps/tradinggoose/lib/oauth/connect.ts index 036e20cc9..8b1cf4788 100644 --- a/apps/tradinggoose/lib/oauth/connect.ts +++ b/apps/tradinggoose/lib/oauth/connect.ts @@ -1,6 +1,7 @@ 'use client' import { client } from '@/lib/auth-client' +import { normalizeCallbackUrl } from '@/i18n/utils' interface ConnectOAuthServiceOptions { providerId: string @@ -11,14 +12,19 @@ export async function startOAuthConnectFlow({ providerId, callbackURL, }: ConnectOAuthServiceOptions) { + const canonicalCallbackURL = normalizeCallbackUrl(callbackURL, window.location.origin) + if (!canonicalCallbackURL) { + throw new Error('Expected an internal OAuth callback URL') + } + if (providerId === 'trello') { - window.location.href = `/api/auth/trello/authorize?callbackURL=${encodeURIComponent(callbackURL)}` + window.location.href = `/api/auth/trello/authorize?callbackURL=${encodeURIComponent(canonicalCallbackURL)}` return } await client.oauth2.link({ providerId, - callbackURL, - errorCallbackURL: callbackURL, + callbackURL: canonicalCallbackURL, + errorCallbackURL: canonicalCallbackURL, }) } diff --git a/apps/tradinggoose/lib/permissions/utils.test.ts b/apps/tradinggoose/lib/permissions/utils.test.ts index 8e1232a9f..e21cbdb68 100644 --- a/apps/tradinggoose/lib/permissions/utils.test.ts +++ b/apps/tradinggoose/lib/permissions/utils.test.ts @@ -27,7 +27,6 @@ vi.mock('@tradinggoose/db/schema', () => ({ id: 'user_id', email: 'user_email', name: 'user_name', - image: 'user_image', }, workspace: { id: 'workspace_id', @@ -54,10 +53,7 @@ import { getUserEntityPermissions, getUsersWithPermissions, getWorkspaceById, - getWorkspaceMemberProfiles, - hasAdminPermission, hasWorkspaceAdminAccess, - workspaceExists, } from '@/lib/permissions/utils' const mockDb = db as any @@ -197,16 +193,18 @@ describe('Permission Utils', () => { expect(result).toBeNull() }) - }) - describe('workspace helpers', () => { - it('should report when a workspace exists', async () => { - const chain = createMockChain([{ id: 'workspace123' }]) - mockDb.select.mockReturnValue(chain) + it('should return admin for workspace owners without permission rows', async () => { + mockDb.select.mockReturnValueOnce(createMockChain([{ ownerId: 'owner-1' }])) - await expect(workspaceExists('workspace123')).resolves.toBe(true) + const result = await getUserEntityPermissions('owner-1', 'workspace', 'workspace456') + + expect(result).toBe('admin') + expect(mockDb.select).toHaveBeenCalledTimes(1) }) + }) + describe('workspace helpers', () => { it('should return the workspace row by id', async () => { const workspaceRow = { id: 'workspace123', @@ -334,84 +332,38 @@ describe('Permission Utils', () => { }) }) - describe('getWorkspaceMemberProfiles', () => { - it('should return minimal member profiles for a workspace', async () => { - const members = [ - { userId: 'user-1', name: 'Alice', image: 'alice.png' }, - { userId: 'user-2', name: 'Bob', image: null }, - ] - mockDb.select.mockReturnValue(createMockChain(members)) - - const result = await getWorkspaceMemberProfiles('workspace123') - - expect(result).toEqual(members) - }) - }) - - describe('hasAdminPermission', () => { - it('should return true when user has admin permission for workspace', async () => { - const chain = createMockChain([{ id: 'perm1' }]) - mockDb.select.mockReturnValue(chain) - - const result = await hasAdminPermission('admin-user', 'workspace123') - - expect(result).toBe(true) - }) - - it('should return false when user has no admin permission for workspace', async () => { - const chain = createMockChain([]) - mockDb.select.mockReturnValue(chain) - - const result = await hasAdminPermission('regular-user', 'workspace123') - - expect(result).toBe(false) - }) - - it('should return false when user has write permission but not admin', async () => { - const chain = createMockChain([]) - mockDb.select.mockReturnValue(chain) - - const result = await hasAdminPermission('write-user', 'workspace123') - - expect(result).toBe(false) - }) - - it('should return false when user has read permission but not admin', async () => { - const chain = createMockChain([]) - mockDb.select.mockReturnValue(chain) - - const result = await hasAdminPermission('read-user', 'workspace123') - - expect(result).toBe(false) - }) - - it('should handle non-existent workspace', async () => { - const chain = createMockChain([]) - mockDb.select.mockReturnValue(chain) - - const result = await hasAdminPermission('user123', 'non-existent-workspace') - - expect(result).toBe(false) - }) - - it('should handle empty user ID', async () => { - const chain = createMockChain([]) - mockDb.select.mockReturnValue(chain) + describe('getUsersWithPermissions', () => { + it('should return empty array when the workspace owner is unavailable', async () => { + const ownerChain = createMockChain([]) + const usersChain = createMockChain([]) + mockDb.select.mockReturnValueOnce(ownerChain).mockReturnValueOnce(usersChain) - const result = await hasAdminPermission('', 'workspace123') + const result = await getUsersWithPermissions('workspace123') - expect(result).toBe(false) + expect(result).toEqual([]) }) - }) - describe('getUsersWithPermissions', () => { - it('should return empty array when no users have permissions for workspace', async () => { + it('should include the workspace owner as admin without a permission row', async () => { + const ownerChain = createMockChain([ + { + userId: 'owner-1', + email: 'owner@example.com', + name: 'Owner User', + }, + ]) const usersChain = createMockChain([]) - mockDb.select.mockReturnValue(usersChain) + mockDb.select.mockReturnValueOnce(ownerChain).mockReturnValueOnce(usersChain) const result = await getUsersWithPermissions('workspace123') - expect(result).toEqual([]) + expect(result).toEqual([ + { + userId: 'owner-1', + email: 'owner@example.com', + name: 'Owner User', + permissionType: 'admin', + }, + ]) }) it('should return users with their permissions for workspace', async () => { @@ -424,8 +376,9 @@ describe('Permission Utils', () => { }, ] + const ownerChain = createMockChain([]) const usersChain = createMockChain(mockUsersResults) - mockDb.select.mockReturnValue(usersChain) + mockDb.select.mockReturnValueOnce(ownerChain).mockReturnValueOnce(usersChain) const result = await getUsersWithPermissions('workspace456') @@ -461,15 +414,16 @@ describe('Permission Utils', () => { }, ] + const ownerChain = createMockChain([]) const usersChain = createMockChain(mockUsersResults) - mockDb.select.mockReturnValue(usersChain) + mockDb.select.mockReturnValueOnce(ownerChain).mockReturnValueOnce(usersChain) const result = await getUsersWithPermissions('workspace456') expect(result).toHaveLength(3) - expect(result[0].permissionType).toBe('admin') - expect(result[1].permissionType).toBe('write') - expect(result[2].permissionType).toBe('read') + expect(result.find((row) => row.userId === 'user1')?.permissionType).toBe('admin') + expect(result.find((row) => row.userId === 'user2')?.permissionType).toBe('write') + expect(result.find((row) => row.userId === 'user3')?.permissionType).toBe('read') }) it('should handle users with empty names', async () => { @@ -482,8 +436,9 @@ describe('Permission Utils', () => { }, ] + const ownerChain = createMockChain([]) const usersChain = createMockChain(mockUsersResults) - mockDb.select.mockReturnValue(usersChain) + mockDb.select.mockReturnValueOnce(ownerChain).mockReturnValueOnce(usersChain) const result = await getUsersWithPermissions('workspace123') @@ -508,7 +463,7 @@ describe('Permission Utils', () => { if (callCount === 1) { return createMockChain([{ ownerId: 'other-user' }]) } - return createMockChain([{ id: 'perm1' }]) + return createMockChain([{ permissionType: 'admin' }]) }) const result = await hasWorkspaceAdminAccess('user123', 'workspace456') @@ -532,7 +487,7 @@ describe('Permission Utils', () => { if (callCount === 1) { return createMockChain([{ ownerId: 'other-user' }]) } - return createMockChain([]) + return createMockChain([{ permissionType: 'write' }]) }) const result = await hasWorkspaceAdminAccess('user123', 'workspace456') @@ -547,7 +502,7 @@ describe('Permission Utils', () => { if (callCount === 1) { return createMockChain([{ ownerId: 'other-user' }]) } - return createMockChain([]) + return createMockChain([{ permissionType: 'read' }]) }) const result = await hasWorkspaceAdminAccess('user123', 'workspace456') diff --git a/apps/tradinggoose/lib/permissions/utils.ts b/apps/tradinggoose/lib/permissions/utils.ts index e9ccedeb5..42b12fd50 100644 --- a/apps/tradinggoose/lib/permissions/utils.ts +++ b/apps/tradinggoose/lib/permissions/utils.ts @@ -13,27 +13,11 @@ export interface WorkspaceAccess { workspace: WorkspaceRecord | null } -export interface WorkspaceMemberProfile { - userId: string - name: string - image: string | null -} - async function selectWorkspaceById(workspaceId: string): Promise { const [row] = await db.select().from(workspace).where(eq(workspace.id, workspaceId)).limit(1) return row ?? null } -export async function workspaceExists(workspaceId: string): Promise { - const [row] = await db - .select({ id: workspace.id }) - .from(workspace) - .where(eq(workspace.id, workspaceId)) - .limit(1) - - return !!row -} - export async function getWorkspaceById(workspaceId: string): Promise { return await selectWorkspaceById(workspaceId) } @@ -111,10 +95,14 @@ export async function getUserEntityPermissions( entityId: string ): Promise { if (entityType === 'workspace') { - const activeWorkspace = await workspaceExists(entityId) + const activeWorkspace = await selectWorkspaceById(entityId) if (!activeWorkspace) { return null } + + if (activeWorkspace.ownerId === userId) { + return 'admin' + } } const result = await db @@ -142,30 +130,6 @@ export async function getUserEntityPermissions( return highestPermission.permissionType } -/** - * Check if a user has admin permission for a specific workspace - * - * @param userId - The ID of the user to check - * @param workspaceId - The ID of the workspace to check - * @returns Promise - True if the user has admin permission for the workspace, false otherwise - */ -export async function hasAdminPermission(userId: string, workspaceId: string): Promise { - const result = await db - .select({ id: permissions.id }) - .from(permissions) - .where( - and( - eq(permissions.userId, userId), - eq(permissions.entityType, 'workspace'), - eq(permissions.entityId, workspaceId), - eq(permissions.permissionType, 'admin') - ) - ) - .limit(1) - - return result.length > 0 -} - /** * Retrieves a list of users with their associated permissions for a given workspace. * @@ -180,6 +144,17 @@ export async function getUsersWithPermissions(workspaceId: string): Promise< permissionType: PermissionType }> > { + const [owner] = await db + .select({ + userId: user.id, + email: user.email, + name: user.name, + }) + .from(workspace) + .innerJoin(user, eq(workspace.ownerId, user.id)) + .where(eq(workspace.id, workspaceId)) + .limit(1) + const usersWithPermissions = await db .select({ userId: user.id, @@ -193,12 +168,35 @@ export async function getUsersWithPermissions(workspaceId: string): Promise< .where(and(eq(permissions.entityType, 'workspace'), eq(permissions.entityId, workspaceId))) .orderBy(user.email) - return usersWithPermissions.map((row) => ({ - userId: row.userId, - email: row.email, - name: row.name, - permissionType: row.permissionType, - })) + const usersById = new Map< + string, + { + userId: string + email: string + name: string + permissionType: PermissionType + } + >() + + if (owner) { + usersById.set(owner.userId, { + ...owner, + permissionType: 'admin', + }) + } + + for (const row of usersWithPermissions) { + if (!usersById.has(row.userId)) { + usersById.set(row.userId, { + userId: row.userId, + email: row.email, + name: row.name, + permissionType: row.permissionType, + }) + } + } + + return [...usersById.values()].sort((a, b) => a.email.localeCompare(b.email)) } /** @@ -212,17 +210,7 @@ export async function hasWorkspaceAdminAccess( userId: string, workspaceId: string ): Promise { - const ws = await selectWorkspaceById(workspaceId) - - if (!ws) { - return false - } - - if (ws.ownerId === userId) { - return true - } - - return await hasAdminPermission(userId, workspaceId) + return (await getUserEntityPermissions(userId, 'workspace', workspaceId)) === 'admin' } /** @@ -279,20 +267,3 @@ export async function getManageableWorkspaces(userId: string): Promise< return combined } - -export async function getWorkspaceMemberProfiles( - workspaceId: string -): Promise { - const rows = await db - .select({ - userId: user.id, - name: user.name, - image: user.image, - }) - .from(permissions) - .innerJoin(user, eq(permissions.userId, user.id)) - .innerJoin(workspace, eq(permissions.entityId, workspace.id)) - .where(and(eq(permissions.entityType, 'workspace'), eq(permissions.entityId, workspaceId))) - - return rows -} diff --git a/apps/tradinggoose/lib/subscription/upgrade.ts b/apps/tradinggoose/lib/subscription/upgrade.ts index 65171ac18..4cff00d47 100644 --- a/apps/tradinggoose/lib/subscription/upgrade.ts +++ b/apps/tradinggoose/lib/subscription/upgrade.ts @@ -40,7 +40,7 @@ export function useSubscriptionUpgrade() { throw new Error('User not authenticated') } - let currentPersonalSubscriptionId: string | undefined + let currentPersonalStripeSubscriptionId: string | undefined let allSubscriptions: any[] = [] try { const listResult = await client.subscription.list() @@ -51,9 +51,10 @@ export function useSubscriptionUpgrade() { sub.status as (typeof ENTITLED_SUBSCRIPTION_STATUSES)[number] ) && sub.referenceId === userId ) - currentPersonalSubscriptionId = activePersonalSubscription?.id + currentPersonalStripeSubscriptionId = + activePersonalSubscription?.stripeSubscriptionId || undefined } catch (_e) { - currentPersonalSubscriptionId = undefined + currentPersonalStripeSubscriptionId = undefined } let referenceId = userId @@ -127,18 +128,19 @@ export function useSubscriptionUpgrade() { ...(targetTier.ownerType === 'organization' && { seats: initialSeats }), } as const - const finalParams = currentPersonalSubscriptionId - ? { ...upgradeParams, subscriptionId: currentPersonalSubscriptionId } - : upgradeParams + const finalParams = + targetTier.ownerType === 'user' && currentPersonalStripeSubscriptionId + ? { ...upgradeParams, subscriptionId: currentPersonalStripeSubscriptionId } + : upgradeParams logger.info( - currentPersonalSubscriptionId + targetTier.ownerType === 'user' && currentPersonalStripeSubscriptionId ? 'Upgrading existing subscription' : 'Creating new subscription', { billingTierId: targetTier.billingTierId, billingTier: targetTier.displayName, - subscriptionId: currentPersonalSubscriptionId, + stripeSubscriptionId: currentPersonalStripeSubscriptionId, usageScope: targetTier.usageScope, seatMode: targetTier.seatMode, referenceId, diff --git a/apps/tradinggoose/lib/urls/utils.test.ts b/apps/tradinggoose/lib/urls/utils.test.ts new file mode 100644 index 000000000..bb6abc2c5 --- /dev/null +++ b/apps/tradinggoose/lib/urls/utils.test.ts @@ -0,0 +1,45 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockEnv } = vi.hoisted(() => ({ + mockEnv: {} as Record, +})) + +vi.mock('@/lib/env', () => ({ + getEnv: (key: string) => mockEnv[key], +})) + +import { getBaseDomain, getBaseUrl, getEmailDomain } from './utils' + +describe('url helpers', () => { + beforeEach(() => { + for (const key of Object.keys(mockEnv)) { + delete mockEnv[key] + } + mockEnv.NEXT_PUBLIC_APP_URL = 'https://www.tradinggoose.ai' + }) + + it('uses NEXT_PUBLIC_APP_URL as the base URL', () => { + expect(getBaseUrl()).toBe('https://www.tradinggoose.ai') + }) + + it('treats preview and production as configured app URLs', () => { + mockEnv.NEXT_PUBLIC_APP_URL = 'https://preview.tradinggoose.ai' + + expect(getBaseUrl()).toBe('https://preview.tradinggoose.ai') + }) + + it('derives the base domain only from NEXT_PUBLIC_APP_URL', () => { + mockEnv.NEXT_PUBLIC_APP_URL = 'https://www.tradinggoose.ai' + + expect(getBaseDomain()).toBe('www.tradinggoose.ai') + expect(getEmailDomain()).toBe('tradinggoose.ai') + }) + + it('rejects missing and invalid configured app urls outside email preview', () => { + mockEnv.NEXT_PUBLIC_APP_URL = undefined + expect(() => getBaseDomain()).toThrow('NEXT_PUBLIC_APP_URL is required') + + mockEnv.NEXT_PUBLIC_APP_URL = 'app.tradinggoose.ai' + expect(() => getBaseDomain()).toThrow('Configured base URL must be a valid URL') + }) +}) diff --git a/apps/tradinggoose/lib/urls/utils.ts b/apps/tradinggoose/lib/urls/utils.ts index 883f965ca..d78da6c2b 100644 --- a/apps/tradinggoose/lib/urls/utils.ts +++ b/apps/tradinggoose/lib/urls/utils.ts @@ -1,65 +1,30 @@ import { getEnv } from '@/lib/env' -import { isProd } from '@/lib/environment' -/** - * Returns the base URL of the application from NEXT_PUBLIC_APP_URL - * This ensures webhooks, callbacks, and other integrations always use the correct public URL - * @returns The base URL string (e.g., 'http://localhost:3000' or 'https://example.com') - * @throws Error if NEXT_PUBLIC_APP_URL is not configured - */ export function getBaseUrl(): string { - const baseUrl = getEnv('NEXT_PUBLIC_APP_URL') - const isPreviewDev = - getEnv('NEXT_PUBLIC_IS_PREVIEW_DEVELOPMENT') === 'true' || - getEnv('NEXT_PUBLIC_IS_PREVIEW_DEVELOPMENT') === '1' - const isReactEmailPreview = - !!process.env.EMAILS_DIR_ABSOLUTE_PATH || !!process.env.PREVIEW_SERVER_LOCATION + const value = getEnv('NEXT_PUBLIC_APP_URL')?.trim() - if (!baseUrl || !baseUrl.trim()) { - const fallback = process.env.EMAILS_PREVIEW_BASE_URL || 'http://localhost:3000' - if (isProd && !isPreviewDev && !isReactEmailPreview) { - // Avoid hard crash in production builds but surface a warning for misconfiguration. - // eslint-disable-next-line no-console - console.warn('NEXT_PUBLIC_APP_URL missing – falling back to', fallback) - } - return fallback + if (!value) { + throw new Error('NEXT_PUBLIC_APP_URL is required') } - if (baseUrl.startsWith('http://') || baseUrl.startsWith('https://')) { - return baseUrl + let url: URL + try { + url = new URL(value) + } catch { + throw new Error('Configured base URL must be a valid URL') + } + + if (url.protocol !== 'http:' && url.protocol !== 'https:') { + throw new Error('Configured base URL must use http or https') } - const protocol = isProd ? 'https://' : 'http://' - return `${protocol}${baseUrl}` + return url.origin } -/** - * Returns just the domain and port part of the application URL - * @returns The domain with port if applicable (e.g., 'localhost:3000' or 'tradinggoose.ai') - */ export function getBaseDomain(): string { - try { - const url = new URL(getBaseUrl()) - return url.host // host includes port if specified - } catch (_e) { - const fallbackUrl = getEnv('NEXT_PUBLIC_APP_URL') || 'http://localhost:3000' - try { - return new URL(fallbackUrl).host - } catch { - return isProd ? 'tradinggoose.ai' : 'localhost:3000' - } - } + return new URL(getBaseUrl()).host } -/** - * Returns the domain for email addresses, stripping www subdomain for Resend compatibility - * @returns The email domain (e.g., 'tradinggoose.ai' instead of 'www.tradinggoose.ai') - */ export function getEmailDomain(): string { - try { - const baseDomain = getBaseDomain() - return baseDomain.startsWith('www.') ? baseDomain.substring(4) : baseDomain - } catch (_e) { - return isProd ? 'tradinggoose.ai' : 'localhost:3000' - } + return getBaseDomain().replace(/^www\./, '') } diff --git a/apps/tradinggoose/lib/watchlists/types.ts b/apps/tradinggoose/lib/watchlists/types.ts index 8af94a28c..50e1891ad 100644 --- a/apps/tradinggoose/lib/watchlists/types.ts +++ b/apps/tradinggoose/lib/watchlists/types.ts @@ -1,3 +1,4 @@ +import type { TradingGooseExportEnvelope } from '@/lib/import-export/trading-goose' import type { ListingIdentity } from '@/lib/listing/identity' export type WatchlistSettings = { @@ -38,17 +39,8 @@ export type WatchlistTransferRecord = { items: WatchlistImportFileItem[] } -export type WatchlistImportFile = { - version: '1' - fileType: 'tradingGooseExport' - exportedAt: string - exportedFrom: string - resourceTypes: string[] +export type WatchlistImportFile = TradingGooseExportEnvelope & { watchlists: [WatchlistTransferRecord] - skills?: unknown[] - workflows?: unknown[] - customTools?: unknown[] - indicators?: unknown[] } export type WatchlistRecord = { diff --git a/apps/tradinggoose/lib/webhooks/processor.ts b/apps/tradinggoose/lib/webhooks/processor.ts index 926f8bcd2..48d455d2a 100644 --- a/apps/tradinggoose/lib/webhooks/processor.ts +++ b/apps/tradinggoose/lib/webhooks/processor.ts @@ -378,6 +378,10 @@ export async function queueWebhookExecution( } const headers = Object.fromEntries(request.headers.entries()) + if (typeof foundWebhook.blockId !== 'string' || foundWebhook.blockId.length === 0) { + logger.warn(`[${options.requestId}] Webhook ${foundWebhook.id} is missing trigger block`) + return NextResponse.json({ message: 'Webhook trigger block not found' }, { status: 410 }) + } // For Microsoft Teams Graph notifications, extract unique identifiers for idempotency if ( @@ -409,7 +413,7 @@ export async function queueWebhookExecution( const pendingExecutionId = `webhook_execution:${IdempotencyService.createWebhookIdempotencyKey( foundWebhook.id, - headers, + headers )}` const handle = await enqueuePendingExecution({ @@ -429,7 +433,7 @@ export async function queueWebhookExecution( logger.info( `[${options.requestId}] Queued ${options.testMode ? 'TEST ' : ''}webhook execution ${ handle.pendingExecutionId - } for ${foundWebhook.provider} webhook`, + } for ${foundWebhook.provider} webhook` ) } catch (error: any) { if (error instanceof TriggerExecutionUnavailableError) { diff --git a/apps/tradinggoose/lib/workflows/document-format.ts b/apps/tradinggoose/lib/workflows/document-format.ts index 1fef5d54b..3d51364d5 100644 --- a/apps/tradinggoose/lib/workflows/document-format.ts +++ b/apps/tradinggoose/lib/workflows/document-format.ts @@ -1 +1,2 @@ export const TG_MERMAID_DOCUMENT_FORMAT = 'tg-mermaid-v1' as const +export const WORKFLOW_GRAPH_MERMAID_DOCUMENT_FORMAT = 'tg-workflow-graph-mermaid-v1' as const diff --git a/apps/tradinggoose/lib/workflows/execution-runner.test.ts b/apps/tradinggoose/lib/workflows/execution-runner.test.ts index 3d4e45704..6610a0f81 100644 --- a/apps/tradinggoose/lib/workflows/execution-runner.test.ts +++ b/apps/tradinggoose/lib/workflows/execution-runner.test.ts @@ -71,7 +71,7 @@ vi.mock('@/lib/workflows/db-helpers', () => ({ vi.mock('@/lib/workflows/triggers', () => ({ TriggerUtils: { - findStartBlock: vi.fn(), + findTriggerBlock: vi.fn(), }, })) @@ -157,7 +157,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'webhook', workflowInput: { symbol: 'AAPL' }, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'block', blockId: 'trigger', }, @@ -219,7 +219,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'block', blockId: 'trigger', }, @@ -247,7 +247,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'block', blockId: 'trigger', }, @@ -271,14 +271,14 @@ describe('runPreparedWorkflowExecution', () => { expect(result.dispatchFailureReason).toBe('usage_limit_exceeded') }) - it('reports missing start blocks as dispatch failures', async () => { + it('reports missing trigger blocks as dispatch failures', async () => { const result = await runPreparedWorkflowExecution({ blueprint, actorUserId: 'user-1', triggerType: 'webhook', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'block', blockId: 'missing', }, @@ -286,7 +286,7 @@ describe('runPreparedWorkflowExecution', () => { expect(mocks.execute).not.toHaveBeenCalled() expect(result.result.success).toBe(false) - expect(result.dispatchFailureReason).toBe('missing_start_block') + expect(result.dispatchFailureReason).toBe('missing_trigger_block') }) it('does not rewrite successful executions as failed when terminal success logging fails', async () => { @@ -299,7 +299,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'block', blockId: 'trigger', }, @@ -310,8 +310,8 @@ describe('runPreparedWorkflowExecution', () => { expect(mocks.completeWithError).not.toHaveBeenCalled() }) - it('resolves queued child API starts through the child input-trigger path', async () => { - vi.mocked(TriggerUtils.findStartBlock).mockReturnValue({ + it('resolves queued child API triggers through the child input-trigger path', async () => { + vi.mocked(TriggerUtils.findTriggerBlock).mockReturnValue({ blockId: 'trigger', block: { type: 'input_trigger' }, }) @@ -322,7 +322,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: { symbol: 'AAPL' }, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'trigger', triggerType: 'api', }, @@ -331,7 +331,7 @@ describe('runPreparedWorkflowExecution', () => { }, }) - expect(TriggerUtils.findStartBlock).toHaveBeenCalledWith( + expect(TriggerUtils.findTriggerBlock).toHaveBeenCalledWith( blueprint.workflowData.blocks, 'api', true @@ -349,7 +349,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'trigger', triggerType: 'manual', }, @@ -372,7 +372,7 @@ describe('runPreparedWorkflowExecution', () => { triggerType: 'manual', workflowInput: {}, executionId: 'execution-1', - start: { + triggerTarget: { kind: 'trigger', triggerType: 'manual', }, diff --git a/apps/tradinggoose/lib/workflows/execution-runner.ts b/apps/tradinggoose/lib/workflows/execution-runner.ts index c31f074b7..4ecdd3b6a 100644 --- a/apps/tradinggoose/lib/workflows/execution-runner.ts +++ b/apps/tradinggoose/lib/workflows/execution-runner.ts @@ -35,14 +35,14 @@ type ResolvedWorkflowExecutionContext = { variables: unknown } -export type WorkflowStart = +export type WorkflowTriggerTarget = | { kind: 'trigger' triggerType: 'api' | 'chat' | 'manual' } | { kind: 'block' - blockId?: string + blockId: string } export type WorkflowExecutionBlueprint = { @@ -58,7 +58,7 @@ export type WorkflowExecutionBlueprint = { } export type WorkflowRunnerExecutionResult = ExecutionResult -export type WorkflowDispatchFailureReason = 'usage_limit_exceeded' | 'missing_start_block' +export type WorkflowDispatchFailureReason = 'usage_limit_exceeded' | 'missing_trigger_block' export type WorkflowRunnerResult = { executionId: string @@ -78,7 +78,7 @@ export class WorkflowUsageLimitError extends Error { } } -class WorkflowStartBlockError extends Error {} +class WorkflowTriggerBlockError extends Error {} async function resolveRequiredWorkflowExecutionContext( workflowId: string, @@ -191,70 +191,66 @@ function buildProcessedBlockStates( return processedBlockStates } -function resolveStartBlockId(params: { +function resolveTriggerBlockId(params: { mergedStates: Record serializedWorkflow: { connections: Array<{ source: string }> } - start: WorkflowStart + target: WorkflowTriggerTarget isChildExecution: boolean }) { - if (params.start.kind === 'trigger') { - const startBlock = TriggerUtils.findStartBlock( + if (params.target.kind === 'trigger') { + const triggerBlock = TriggerUtils.findTriggerBlock( params.mergedStates, - params.start.triggerType, + params.target.triggerType, params.isChildExecution ) - if (!startBlock) { + if (!triggerBlock) { const triggerName = - params.start.triggerType === 'api' && params.isChildExecution + params.target.triggerType === 'api' && params.isChildExecution ? 'Input' - : params.start.triggerType === 'api' + : params.target.triggerType === 'api' ? 'API' - : params.start.triggerType === 'chat' + : params.target.triggerType === 'chat' ? 'Chat' : 'Manual' - throw new WorkflowStartBlockError( + throw new WorkflowTriggerBlockError( `No ${triggerName} trigger block found. Add a ${triggerName} Trigger block to this workflow.` ) } const outgoingConnections = params.serializedWorkflow.connections.filter( - (connection) => connection.source === startBlock.blockId + (connection) => connection.source === triggerBlock.blockId ) if (outgoingConnections.length === 0) { - throw new WorkflowStartBlockError( + throw new WorkflowTriggerBlockError( 'Trigger block must be connected to other blocks to execute' ) } - return startBlock.blockId + return triggerBlock.blockId } - if ( - params.start.kind === 'block' && - params.start.blockId && - !params.mergedStates[params.start.blockId] - ) { - throw new WorkflowStartBlockError( - `Workflow does not contain trigger block ${params.start.blockId}` + if (params.target.kind === 'block' && !params.mergedStates[params.target.blockId]) { + throw new WorkflowTriggerBlockError( + `Workflow does not contain trigger block ${params.target.blockId}` ) } - if (params.start.kind === 'block' && params.start.blockId) { - const blockId = params.start.blockId + if (params.target.kind === 'block') { + const blockId = params.target.blockId const outgoingConnections = params.serializedWorkflow.connections.filter( (connection) => connection.source === blockId ) if (outgoingConnections.length === 0) { - throw new WorkflowStartBlockError( + throw new WorkflowTriggerBlockError( `Trigger block ${blockId} must be connected to other blocks to execute` ) } } - return params.start.blockId + return params.target.blockId } export async function loadWorkflowExecutionBlueprint(params: { @@ -295,7 +291,7 @@ export async function runPreparedWorkflowExecution(params: { actorUserId: string triggerType: TriggerType workflowInput: unknown - start: WorkflowStart + triggerTarget: WorkflowTriggerTarget requestId?: string executionId?: string triggerData?: Record @@ -388,14 +384,14 @@ export async function runPreparedWorkflowExecution(params: { contextExtensions, }) - const startBlockId = resolveStartBlockId({ + const triggerBlockId = resolveTriggerBlockId({ mergedStates, serializedWorkflow, - start: params.start, + target: params.triggerTarget, isChildExecution: contextExtensions.isChildExecution === true, }) - result = await executor.execute(params.blueprint.workflowId, startBlockId) + result = await executor.execute(params.blueprint.workflowId, triggerBlockId) if (result.success) { await updateWorkflowRunCounts(params.blueprint.workflowId).catch((error) => @@ -407,8 +403,8 @@ export async function runPreparedWorkflowExecution(params: { const dispatchFailureReason = error instanceof WorkflowUsageLimitError ? 'usage_limit_exceeded' - : error instanceof WorkflowStartBlockError - ? 'missing_start_block' + : error instanceof WorkflowTriggerBlockError + ? 'missing_trigger_block' : undefined result = (error?.executionResult as ExecutionResult | undefined) || { success: false, @@ -469,7 +465,7 @@ export async function runWorkflowExecution(params: { actorUserId: string triggerType: TriggerType workflowInput: unknown - start: WorkflowStart + triggerTarget: WorkflowTriggerTarget executionTarget?: WorkflowExecutionTarget workflowContext?: WorkflowContextHint workflowData?: WorkflowExecutionBlueprint['workflowData'] @@ -502,7 +498,7 @@ export async function runWorkflowExecution(params: { actorUserId: params.actorUserId, triggerType: params.triggerType, workflowInput: params.workflowInput, - start: params.start, + triggerTarget: params.triggerTarget, requestId: params.requestId, executionId: params.executionId, triggerData: params.triggerData, diff --git a/apps/tradinggoose/lib/workflows/import-export.test.ts b/apps/tradinggoose/lib/workflows/import-export.test.ts index 832281559..27014726c 100644 --- a/apps/tradinggoose/lib/workflows/import-export.test.ts +++ b/apps/tradinggoose/lib/workflows/import-export.test.ts @@ -73,7 +73,6 @@ describe('workflow import/export helpers', () => { workflow: { name: ' Primary Workflow ', description: ' Workflow used for trading ', - color: ' #3972F6 ', state: createWorkflowState(), }, }) @@ -89,7 +88,6 @@ describe('workflow import/export helpers', () => { { name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', state: { blocks: { block_1: { @@ -115,7 +113,6 @@ describe('workflow import/export helpers', () => { workflow: { name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', state: createWorkflowStateWithSkills(), }, skills: [ @@ -185,7 +182,6 @@ describe('workflow import/export helpers', () => { { name: ' Primary Workflow ', description: ' Workflow used for trading ', - color: ' #3972F6 ', state: createWorkflowState(), }, ], @@ -198,7 +194,6 @@ describe('workflow import/export helpers', () => { expect(parsed.data).toMatchObject({ name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', skills: [ { name: 'Ignore me', @@ -220,23 +215,21 @@ describe('workflow import/export helpers', () => { }) }) - it('rejects invalid workflow envelopes', () => { - const parsed = parseImportedWorkflowFile({ + it('ignores generated workflow presentation color in transfer records', () => { + expect(parseImportedWorkflowFile({ version: '1', - fileType: 'wrongFileType', + fileType: 'tradingGooseExport', exportedAt: '2026-04-08T15:30:00.000Z', exportedFrom: 'workflowEditor', - resourceTypes: ['skills'], + resourceTypes: ['workflows'], workflows: [ { name: 'Primary Workflow', + color: '#3972F6', state: createWorkflowState(), }, ], - }) - - expect(parsed.data).toBeNull() - expect(parsed.errors[0]).toContain('Unsupported JSON format') + }).errors).toEqual([]) }) it('renames duplicate imported workflows with the imported marker', () => { diff --git a/apps/tradinggoose/lib/workflows/import-export.ts b/apps/tradinggoose/lib/workflows/import-export.ts index a6eb3b7af..5fd04f285 100644 --- a/apps/tradinggoose/lib/workflows/import-export.ts +++ b/apps/tradinggoose/lib/workflows/import-export.ts @@ -27,7 +27,6 @@ const formatZodIssue = (issue: z.ZodIssue) => { export interface WorkflowTransferRecord { name: string description: string - color: string state: ExportWorkflowState['state'] skills: SkillTransferRecord[] } @@ -37,7 +36,6 @@ type WorkflowSkillSource = Pick { { name: 'Primary Workflow', description: 'Workflow imported from the unified schema', - color: '#3972F6', state: { blocks: { block_1: { @@ -45,13 +44,11 @@ describe('workflow import orchestration', () => { name: string description: string workspaceId: string - color?: string }) => { callOrder.push('createWorkflow') expect(params).toMatchObject({ name: 'Primary Workflow (imported) 1', description: 'Workflow imported from the unified schema', - color: '#3972F6', workspaceId: 'workspace-1', }) return 'workflow-1' @@ -115,7 +112,6 @@ describe('workflow import orchestration', () => { { name: 'Primary Workflow', description: 'Workflow imported from the unified schema', - color: '#3972F6', state: { blocks: { block_1: { @@ -176,12 +172,10 @@ describe('workflow import orchestration', () => { name: string description: string workspaceId: string - color?: string }) => { expect(params).toMatchObject({ name: 'Primary Workflow (imported) 1', description: 'Workflow imported from the unified schema', - color: '#3972F6', workspaceId: 'workspace-1', }) return 'workflow-1' diff --git a/apps/tradinggoose/lib/workflows/import.ts b/apps/tradinggoose/lib/workflows/import.ts index 9d133bc74..ec8eae3ef 100644 --- a/apps/tradinggoose/lib/workflows/import.ts +++ b/apps/tradinggoose/lib/workflows/import.ts @@ -15,7 +15,6 @@ type CreateWorkflowParams = { name: string description: string workspaceId: string - color?: string } type ImportWorkflowFromJsonContentParams = { @@ -121,7 +120,6 @@ export async function importWorkflowFromJsonContent({ const workflowId = await createWorkflow({ name: resolvedName, description: workflowData.description, - color: workflowData.color.length > 0 ? workflowData.color : undefined, workspaceId, }) diff --git a/apps/tradinggoose/lib/workflows/queued-execution-client.ts b/apps/tradinggoose/lib/workflows/queued-execution-client.ts index b4ea6d338..7c67c920d 100644 --- a/apps/tradinggoose/lib/workflows/queued-execution-client.ts +++ b/apps/tradinggoose/lib/workflows/queued-execution-client.ts @@ -2,16 +2,17 @@ import type { WorkflowExecutionEvent } from '@/lib/workflows/execution-events' import { isExecutionResult } from '@/lib/workflows/execution-result' import type { WorkflowExecutionBlueprint } from '@/lib/workflows/execution-runner' import type { ExecutionResult } from '@/executor/types' +import type { QueuedWorkflowTriggerType } from '@/services/queue' type QueuedWorkflowExecutionRequest = { workflowId: string executionId?: string input?: unknown - triggerType: 'api' | 'manual' | 'chat' + triggerType: QueuedWorkflowTriggerType executionTarget: 'deployed' | 'live' workflowData?: WorkflowExecutionBlueprint['workflowData'] workflowVariables?: Record - startBlockId?: string + triggerBlockId?: string selectedOutputs?: string[] stream?: boolean signal?: AbortSignal @@ -91,7 +92,7 @@ export async function queueWorkflowExecution( executionTarget: request.executionTarget, workflowData: request.workflowData, workflowVariables: request.workflowVariables, - startBlockId: request.startBlockId, + triggerBlockId: request.triggerBlockId, selectedOutputs: request.selectedOutputs, stream: request.stream === true, }), diff --git a/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.test.ts b/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.test.ts index 65abb7676..f1db52c16 100644 --- a/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.test.ts +++ b/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.test.ts @@ -2,7 +2,9 @@ import { describe, expect, it } from 'vitest' import { applyAutoLayout } from '@/lib/workflows/autolayout' import { buildWorkflowDocumentPreviewDiff, + parseGraphOnlyWorkflowMermaid, parseTgMermaidToWorkflow, + serializeWorkflowToGraphMermaid, serializeWorkflowToTgMermaid, TG_MERMAID_DOCUMENT_FORMAT, } from '@/lib/workflows/studio-workflow-mermaid' @@ -395,6 +397,77 @@ n3 --> n4 ]) }) + it('parses ordinary graph-only Mermaid aliases without flattening containers', () => { + const parsed = parseGraphOnlyWorkflowMermaid( + [ + 'flowchart TD', + 'sink["Send Alert"]', + 'subgraph loop_parent["For Each Symbol"]', + ' loop_child["Generate Signal"]', + 'end', + 'sink --> loop_parent', + ].join('\n'), + workflowState.blocks + ) + + expect(parsed.blocks.find((block) => block.blockId === 'loop_child')?.parentId).toBe( + 'loop_parent' + ) + expect(parsed.edges).toContainEqual({ + source: 'sink', + target: 'loop_parent', + targetHandle: 'target', + }) + }) + + it('serializes empty graph-only containers with boundary nodes', () => { + const document = serializeWorkflowToGraphMermaid({ + direction: 'TD', + blocks: { + loop1: { + id: 'loop1', + type: 'loop', + name: 'Loop', + position: { x: 0, y: 0 }, + enabled: true, + subBlocks: {}, + outputs: {}, + }, + sink: { + id: 'sink', + type: 'telegram', + name: 'Sink', + position: { x: 320, y: 0 }, + enabled: true, + subBlocks: {}, + outputs: {}, + }, + }, + edges: [{ id: 'e1', source: 'loop1', target: 'sink', sourceHandle: 'loop-end-source' }], + loops: {}, + parallels: {}, + }) + + expect(document).toContain('n1__loop_start["Loop Start"]') + expect(document).toContain('n1__loop_end["Loop End"]') + expect(document).toContain('n1__loop_end --> n2') + expect(() => parseGraphOnlyWorkflowMermaid(document, {})).not.toThrow() + }) + + it('rejects shorthand graph-only condition edge handles', () => { + expect(() => + parseGraphOnlyWorkflowMermaid( + [ + 'flowchart TD', + 'gate["Market Hours?
id: gate
type: condition"]', + 'sink["Send Alert
id: sink
type: telegram"]', + 'gate -- "if -> target" --> sink', + ].join('\n'), + workflowState.blocks + ) + ).toThrow('must use canonical sourceHandle "condition-gate-"') + }) + it('rejects visible external edges into container internal endpoint nodes', () => { for (const [endpoint, message] of [ ['n2__parallel_end', 'end node only accepts edges from blocks inside that container'], diff --git a/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.ts b/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.ts index 51b5512dd..81d2f7a29 100644 --- a/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.ts +++ b/apps/tradinggoose/lib/workflows/studio-workflow-mermaid.ts @@ -27,7 +27,7 @@ type ConditionEntry = { type MermaidLabelOverlay = { id: string - name: string + name?: string type?: string enabled?: boolean advancedMode?: boolean @@ -35,6 +35,7 @@ type MermaidLabelOverlay = { outputs?: Record dataEntries: Record subBlockEntries: Record + internalFields: string[] } type ConditionBranchOverlay = { @@ -61,6 +62,12 @@ type ParsedVisibleWorkflowEdges = { inferredParentIds: Map } +export type GraphOnlyWorkflowMermaid = { + direction: WorkflowDirection + blocks: Array<{ blockId: string; blockType?: string; name?: string; parentId?: string }> + edges: Array> +} + const COMMENT_PREFIX = '%% ' export const TG_WORKFLOW_PREFIX = `${COMMENT_PREFIX}TG_WORKFLOW ` export const TG_BLOCK_PREFIX = `${COMMENT_PREFIX}TG_BLOCK ` @@ -68,6 +75,18 @@ export const TG_EDGE_PREFIX = `${COMMENT_PREFIX}TG_EDGE ` const TG_LOOP_PREFIX = `${COMMENT_PREFIX}TG_LOOP ` const TG_PARALLEL_PREFIX = `${COMMENT_PREFIX}TG_PARALLEL ` const CONDITION_INPUT_KEY = 'conditions' +const HIDDEN_VISIBLE_EDGE_HANDLES = new Set([ + 'source', + 'target', + 'input', + 'output', + 'loop-start-source', + 'loop-end-source', + 'parallel-start-source', + 'parallel-end-source', + 'loop-end-target', + 'parallel-end-target', +]) function toDocumentJson(value: unknown): string { return stableStringifyJsonValue(value) @@ -101,7 +120,7 @@ function resolveBlockIdFromVisibleNodeId( } function parseRectNodeLine(line: string): { nodeId: string; label: string } | null { - const rectMatch = line.match(/^([A-Za-z0-9_]+)(?:\(\["(.*)"\]\)|\["(.*)"\])$/) + const rectMatch = line.match(/^([A-Za-z0-9_-]+)(?:\(\["(.*)"\]\)|\["(.*)"\])$/) const label = rectMatch?.[2] ?? rectMatch?.[3] if (!rectMatch?.[1] || !label) { @@ -325,6 +344,10 @@ function buildBlockLabelLines(blockId: string, block: BlockState): string[] { return lines } +function buildGraphOnlyBlockLabelLines(blockId: string, block: BlockState): string[] { + return [block.name || block.type, `id: ${blockId}`, `type: ${block.type}`] +} + function renderRectNode(nodeId: string, labelLines: string[], indent: string): string { return `${indent}${nodeId}["${escapeMermaidLabel(labelLines.join('\n'))}"]` } @@ -378,9 +401,20 @@ function emitBlockGraphLines(params: { aliases: Map childrenByParent: Map lines: string[] + labelLinesForBlock?: (blockId: string, block: BlockState) => string[] + includeConditionBranches?: boolean indent?: string }): void { - const { blockId, blocks, aliases, childrenByParent, lines, indent = ' ' } = params + const { + blockId, + blocks, + aliases, + childrenByParent, + lines, + labelLinesForBlock = buildBlockLabelLines, + includeConditionBranches = true, + indent = ' ', + } = params const block = blocks[blockId] const alias = aliases.get(blockId) @@ -388,10 +422,10 @@ function emitBlockGraphLines(params: { return } - const labelLines = buildBlockLabelLines(blockId, block) + const labelLines = labelLinesForBlock(blockId, block) const children = childrenByParent.get(blockId) ?? [] - if (block.type === 'condition') { + if (block.type === 'condition' && includeConditionBranches) { const conditionEntries = parseConditionEntries(block.subBlocks?.[CONDITION_INPUT_KEY]?.value) lines.push(`${indent}subgraph sg_${alias}["${escapeMermaidLabel(labelLines.join('\n'))}"]`) @@ -409,7 +443,12 @@ function emitBlockGraphLines(params: { return } - if (children.length === 0 || (block.type !== 'loop' && block.type !== 'parallel')) { + if (block.type === 'condition') { + lines.push(renderDiamondNode(alias, labelLines, indent)) + return + } + + if (block.type !== 'loop' && block.type !== 'parallel') { lines.push(renderRectNode(alias, labelLines, indent)) return } @@ -429,6 +468,8 @@ function emitBlockGraphLines(params: { aliases, childrenByParent, lines, + labelLinesForBlock, + includeConditionBranches, indent: `${indent} `, }) } @@ -552,20 +593,10 @@ function resolveVisibleEdgeLabel(edge: Edge, blocks: Record) } } - const hiddenHandles = new Set([ - 'source', - 'target', - 'input', - 'output', - 'loop-start-source', - 'loop-end-source', - 'parallel-start-source', - 'parallel-end-source', - 'loop-end-target', - 'parallel-end-target', - ]) - - if (hiddenHandles.has(sourceHandle) && hiddenHandles.has(targetHandle)) { + if ( + HIDDEN_VISIBLE_EDGE_HANDLES.has(sourceHandle) && + HIDDEN_VISIBLE_EDGE_HANDLES.has(targetHandle) + ) { return null } @@ -593,6 +624,60 @@ function emitEdgeGraphLine( return ` ${sourceNodeId} -- "${escapeMermaidLabel(label)}" --> ${targetNodeId}` } +function resolveGraphOnlySourceNodeId( + edge: Edge, + blocks: Record, + aliases: Map +): string | null { + const sourceAlias = aliases.get(edge.source) + const sourceBlock = blocks[edge.source] + + if (!sourceAlias || !sourceBlock) return sourceAlias ?? null + if (sourceBlock.type === 'loop') { + if (edge.sourceHandle === 'loop-start-source') { + return createContainerNodeId(sourceAlias, 'loop', 'start') + } + if (edge.sourceHandle === 'loop-end-source') { + return createContainerNodeId(sourceAlias, 'loop', 'end') + } + } + if (sourceBlock.type === 'parallel') { + if (edge.sourceHandle === 'parallel-start-source') { + return createContainerNodeId(sourceAlias, 'parallel', 'start') + } + if (edge.sourceHandle === 'parallel-end-source') { + return createContainerNodeId(sourceAlias, 'parallel', 'end') + } + } + return sourceAlias +} + +function resolveGraphOnlyEdgeLabel(edge: Edge): string | null { + const sourceHandle = edge.sourceHandle || 'source' + const targetHandle = edge.targetHandle || 'target' + + return HIDDEN_VISIBLE_EDGE_HANDLES.has(sourceHandle) && + HIDDEN_VISIBLE_EDGE_HANDLES.has(targetHandle) + ? null + : `${sourceHandle} -> ${targetHandle}` +} + +function emitGraphOnlyEdgeGraphLine( + edge: Edge, + blocks: Record, + aliases: Map +): string | null { + const sourceNodeId = resolveGraphOnlySourceNodeId(edge, blocks, aliases) + const targetNodeId = resolveVisibleTargetNodeId(edge, blocks, aliases, aliases) + + if (!sourceNodeId || !targetNodeId) return null + + const label = resolveGraphOnlyEdgeLabel(edge) + return label + ? ` ${sourceNodeId} -- "${escapeMermaidLabel(label)}" --> ${targetNodeId}` + : ` ${sourceNodeId} --> ${targetNodeId}` +} + function parseCommentPayload(line: string, prefix: string): T | null { if (!line.startsWith(prefix)) { return null @@ -628,16 +713,17 @@ function parseOverlayFromLabel(label: string): MermaidLabelOverlay | null { const overlay: MermaidLabelOverlay = { id: '', - name: lines[0], dataEntries: {}, subBlockEntries: {}, + internalFields: [], } const conditionEntries: ConditionEntry[] = [] - for (const line of lines.slice(1)) { + for (const line of lines) { const separatorIndex = line.indexOf(':') if (separatorIndex === -1) { + overlay.name ??= line continue } @@ -653,18 +739,22 @@ function parseOverlayFromLabel(label: string): MermaidLabelOverlay | null { continue } if (rawKey === 'enabled') { + overlay.internalFields.push(rawKey) overlay.enabled = Boolean(parseLabelValue(rawValue)) continue } if (rawKey === 'advancedMode') { + overlay.internalFields.push(rawKey) overlay.advancedMode = Boolean(parseLabelValue(rawValue)) continue } if (rawKey === 'triggerMode') { + overlay.internalFields.push(rawKey) overlay.triggerMode = Boolean(parseLabelValue(rawValue)) continue } if (rawKey === 'outputs') { + overlay.internalFields.push(rawKey) const parsed = parseLabelValue(rawValue) if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { overlay.outputs = parsed as Record @@ -672,10 +762,12 @@ function parseOverlayFromLabel(label: string): MermaidLabelOverlay | null { continue } if (rawKey.startsWith('data.')) { + overlay.internalFields.push(rawKey) overlay.dataEntries[rawKey.slice('data.'.length)] = parseLabelValue(rawValue) continue } if (rawKey.startsWith('subBlocks.')) { + overlay.internalFields.push(rawKey) overlay.subBlockEntries[rawKey.slice('subBlocks.'.length)] = parseLabelValue(rawValue) continue } @@ -685,8 +777,12 @@ function parseOverlayFromLabel(label: string): MermaidLabelOverlay | null { rawKey === 'else-if' || rawKey.startsWith('else-if-') ) { + overlay.internalFields.push(`subBlocks.${CONDITION_INPUT_KEY}`) conditionEntries.push({ key: rawKey, value: rawValue }) + continue } + + overlay.name ??= line } if (overlay.id.length === 0) { @@ -798,7 +894,7 @@ function parseMermaidLabelOverlays( continue } - const diamondMatch = trimmed.match(/^[A-Za-z0-9_]+\{"(.*)"\}$/) + const diamondMatch = trimmed.match(/^[A-Za-z0-9_-]+\{"(.*)"\}$/) if (diamondMatch?.[1]) { const overlay = parseOverlayFromLabel(diamondMatch[1]) if (overlay) { @@ -810,6 +906,50 @@ function parseMermaidLabelOverlays( return { blocks, conditionBranches } } +function readGraphOnlyInternalFields(overlay: MermaidLabelOverlay): string[] { + return [...new Set(overlay.internalFields)] +} + +function readGraphOnlyDirectExistingBlockNames( + document: string, + existingBlockIds: Set +): Map { + const names = new Map() + + for (const rawLine of document.split(/\r?\n/)) { + const trimmed = rawLine.trim() + const node = + parseRectNodeLine(trimmed) ?? + (() => { + const diamondMatch = trimmed.match(/^([A-Za-z0-9_-]+)\{"(.*)"\}$/) + return diamondMatch?.[1] && diamondMatch[2] + ? { nodeId: diamondMatch[1], label: diamondMatch[2] } + : null + })() ?? + (() => { + const subgraphMatch = trimmed.match(/^subgraph\s+([A-Za-z0-9_-]+)\["(.*)"\]$/) + return subgraphMatch?.[1] && subgraphMatch[2] + ? { nodeId: subgraphMatch[1], label: subgraphMatch[2] } + : null + })() + + if (!node || !existingBlockIds.has(node.nodeId) || parseOverlayFromLabel(node.label)) { + continue + } + + const name = unescapeMermaidLabel(node.label) + .split('\n') + .map((line) => line.trim()) + .find((line) => line.length > 0) + + if (name) { + names.set(node.nodeId, name) + } + } + + return names +} + function parseVisibleEdgeLabel( rawLabel: string ): { sourceHandle: string; targetHandle: string } | null { @@ -902,6 +1042,22 @@ function parseVisibleWorkflowEdges( return chain } + const registerBlockRef = ( + nodeId: string, + blockId: string, + blockType: string | undefined, + parentId: string | null + ) => { + nodeRefs.set(nodeId, { kind: 'block', blockId, blockType }) + visibleBlockIds.add(blockId) + if (parentId && parentId !== blockId) { + inferredParentIds.set(blockId, parentId) + } + if (!preferredBlockNodeIds.has(blockId)) { + preferredBlockNodeIds.set(blockId, nodeId) + } + } + for (const rawLine of document.split(/\r?\n/)) { const trimmed = rawLine.trim() @@ -910,27 +1066,34 @@ function parseVisibleWorkflowEdges( continue } - const subgraphMatch = trimmed.match(/^subgraph\s+(sg_[A-Za-z0-9_]+)\["(.*)"\]$/) + const subgraphMatch = trimmed.match(/^subgraph\s+([A-Za-z0-9_-]+)\["(.*)"\]$/) if (subgraphMatch?.[1] && subgraphMatch[2]) { const currentContainerId = getActiveContainerId() const overlay = parseOverlayFromLabel(subgraphMatch[2]) if (overlay) { - const nodeId = subgraphMatch[1].slice(3) - aliasToBlockId.set(nodeId, overlay.id) - nodeRefs.set(nodeId, { kind: 'block', blockId: overlay.id, blockType: overlay.type }) - visibleBlockIds.add(overlay.id) - if (currentContainerId && currentContainerId !== overlay.id) { - inferredParentIds.set(overlay.id, currentContainerId) - } - if (!preferredBlockNodeIds.has(overlay.id)) { - preferredBlockNodeIds.set(overlay.id, nodeId) + const nodeId = subgraphMatch[1] + const edgeNodeId = nodeId.startsWith('sg_') ? nodeId.slice('sg_'.length) : nodeId + for (const visibleNodeId of [edgeNodeId, nodeId]) { + aliasToBlockId.set(visibleNodeId, overlay.id) + registerBlockRef(visibleNodeId, overlay.id, overlay.type, currentContainerId) } subgraphStack.push({ blockId: overlay.id, isContainer: overlay.type === 'loop' || overlay.type === 'parallel', }) } else { - subgraphStack.push({ blockId: null, isContainer: false }) + const directBlockId = resolveBlockIdFromVisibleNodeId( + subgraphMatch[1], + knownBlockIdSet, + aliasToBlockId + ) + const directBlockType = directBlockId ? blocks[directBlockId]?.type : undefined + if (directBlockId && isContainerBlockType(directBlockType)) { + registerBlockRef(subgraphMatch[1], directBlockId, directBlockType, currentContainerId) + subgraphStack.push({ blockId: directBlockId, isContainer: true }) + } else { + subgraphStack.push({ blockId: null, isContainer: false }) + } } continue } @@ -956,36 +1119,18 @@ function parseVisibleWorkflowEdges( const directBlockId = resolveBlockIdFromVisibleNodeId(nodeId, knownBlockIdSet, aliasToBlockId) if (directBlockId) { - nodeRefs.set(nodeId, { - kind: 'block', - blockId: directBlockId, - blockType: blocks[directBlockId]?.type, - }) - visibleBlockIds.add(directBlockId) - if (currentContainerId && currentContainerId !== directBlockId) { - inferredParentIds.set(directBlockId, currentContainerId) - } - if (!preferredBlockNodeIds.has(directBlockId)) { - preferredBlockNodeIds.set(directBlockId, nodeId) - } + registerBlockRef(nodeId, directBlockId, blocks[directBlockId]?.type, currentContainerId) continue } const overlay = parseOverlayFromLabel(rectNode.label) if (overlay) { aliasToBlockId.set(nodeId, overlay.id) - nodeRefs.set(nodeId, { kind: 'block', blockId: overlay.id, blockType: overlay.type }) - visibleBlockIds.add(overlay.id) - if (currentContainerId && currentContainerId !== overlay.id) { - inferredParentIds.set(overlay.id, currentContainerId) - } - if (!preferredBlockNodeIds.has(overlay.id)) { - preferredBlockNodeIds.set(overlay.id, nodeId) - } + registerBlockRef(nodeId, overlay.id, overlay.type, currentContainerId) continue } - const containerMatch = nodeId.match(/^([A-Za-z0-9_]+)__(loop|parallel)_(start|end)$/) + const containerMatch = nodeId.match(/^([A-Za-z0-9_-]+)__(loop|parallel)_(start|end)$/) if (containerMatch?.[1] && containerMatch[2] && containerMatch[3]) { const blockId = resolveBlockIdFromVisibleNodeId( containerMatch[1], @@ -1004,24 +1149,13 @@ function parseVisibleWorkflowEdges( continue } - const diamondMatch = trimmed.match(/^([A-Za-z0-9_]+)\{"(.*)"\}$/) + const diamondMatch = trimmed.match(/^([A-Za-z0-9_-]+)\{"(.*)"\}$/) if (diamondMatch?.[1] && diamondMatch[2]) { const currentContainerId = getActiveContainerId() const overlay = parseOverlayFromLabel(diamondMatch[2]) if (overlay) { aliasToBlockId.set(diamondMatch[1], overlay.id) - nodeRefs.set(diamondMatch[1], { - kind: 'block', - blockId: overlay.id, - blockType: overlay.type, - }) - visibleBlockIds.add(overlay.id) - if (currentContainerId && currentContainerId !== overlay.id) { - inferredParentIds.set(overlay.id, currentContainerId) - } - if (!preferredBlockNodeIds.has(overlay.id)) { - preferredBlockNodeIds.set(overlay.id, diamondMatch[1]) - } + registerBlockRef(diamondMatch[1], overlay.id, overlay.type, currentContainerId) } } } @@ -1031,7 +1165,7 @@ function parseVisibleWorkflowEdges( for (const rawLine of document.split(/\r?\n/)) { const trimmed = rawLine.trim() const edgeMatch = trimmed.match( - /^([A-Za-z0-9_]+)\s*(?:--\s*"((?:\\"|[^"])*)"\s*)?-->\s*([A-Za-z0-9_]+)$/ + /^([A-Za-z0-9_-]+)\s*(?:--\s*"((?:\\"|[^"])*)"\s*)?-->\s*([A-Za-z0-9_-]+)$/ ) if (!edgeMatch?.[1] || !edgeMatch[3]) { continue @@ -1040,7 +1174,9 @@ function parseVisibleWorkflowEdges( const sourceRef = nodeRefs.get(edgeMatch[1]) const targetRef = nodeRefs.get(edgeMatch[3]) if (!sourceRef || !targetRef) { - continue + throw new Error( + `Workflow graph Mermaid edge "${edgeMatch[1]} --> ${edgeMatch[3]}" references unknown node id.` + ) } if ( @@ -1055,11 +1191,11 @@ function parseVisibleWorkflowEdges( const sourceAncestors = getVisibleAncestorChain(sourceRef.blockId) const visibleEndpointViolation = targetRef.kind === 'container-start' - ? `Invalid visible container edge: ${targetRef.blockId} start node is source-only. Use the ${targetRef.blockId} container block alias in the visible line and targetHandle "target" in TG_EDGE metadata for incoming edges.` + ? `Invalid container edge: ${targetRef.blockId} start node is source-only. Use the ${targetRef.blockId} container block alias in the visible line and targetHandle "target" in TG_EDGE metadata for incoming edges.` : targetRef.kind === 'container-end' && !sourceAncestors.includes(targetRef.blockId) - ? `Invalid visible container edge: ${targetRef.blockId} end node only accepts edges from blocks inside that container. Use the ${targetRef.blockId} container block alias in the visible line and targetHandle "target" in TG_EDGE metadata for incoming outer edges.` + ? `Invalid container edge: ${targetRef.blockId} end node only accepts edges from blocks inside that container. Use the ${targetRef.blockId} container block alias in the visible line and targetHandle "target" in TG_EDGE metadata for incoming outer edges.` : sourceRef.kind === 'container-start' && !targetAncestors.includes(sourceRef.blockId) - ? `Invalid visible container edge: ${sourceRef.blockId} start node only connects to blocks inside that container. Use the ${sourceRef.blockId} container block alias for outer workflow edges.` + ? `Invalid container edge: ${sourceRef.blockId} start node only connects to blocks inside that container. Use the ${sourceRef.blockId} container block alias for outer workflow edges.` : null if (visibleEndpointViolation) throw new Error(visibleEndpointViolation) @@ -1082,6 +1218,17 @@ function parseVisibleWorkflowEdges( ? `${targetRef.blockType}-end-target` : 'target') + const sourceBlock = blocks[sourceRef.blockId] + const conditionHandlePrefix = `condition-${sourceRef.blockId}-` + if ( + sourceBlock?.type === 'condition' && + !sourceHandle.startsWith(conditionHandlePrefix) + ) { + throw new Error( + `Workflow graph Mermaid condition edge from "${sourceRef.blockId}" must use canonical sourceHandle "${conditionHandlePrefix}". Use edit_workflow_block to define condition branches before wiring them.` + ) + } + visibleEdges.push({ source: sourceRef.blockId, target: targetRef.blockId, @@ -1383,6 +1530,121 @@ function applyVisibleParenting( return nextBlocks } +export function parseGraphOnlyWorkflowMermaid( + document: string, + existingBlocks: Record +): GraphOnlyWorkflowMermaid { + const directionMatch = document.trimStart().match(/^flowchart\s+(TD|LR)\b/) + if (!directionMatch?.[1]) { + throw new Error('Workflow graph Mermaid must start with `flowchart TD` or `flowchart LR`.') + } + + for (const line of document.split(/\r?\n/)) { + const trimmed = line.trim() + if ( + trimmed.startsWith(TG_WORKFLOW_PREFIX) || + trimmed.startsWith(TG_BLOCK_PREFIX) || + trimmed.startsWith(TG_EDGE_PREFIX) || + trimmed.startsWith(TG_LOOP_PREFIX) || + trimmed.startsWith(TG_PARALLEL_PREFIX) + ) { + throw new Error( + 'Workflow graph Mermaid must not include TG_* metadata comments. Send only visible Mermaid nodes, subgraphs, and edges.' + ) + } + } + + const existingBlockIds = new Set(Object.keys(existingBlocks)) + const blockOverlays = parseMermaidLabelOverlays(document, Object.keys(existingBlocks)) + for (const [blockId, name] of readGraphOnlyDirectExistingBlockNames(document, existingBlockIds)) { + if (!blockOverlays.blocks.has(blockId)) { + blockOverlays.blocks.set(blockId, { + id: blockId, + name, + dataEntries: {}, + subBlockEntries: {}, + internalFields: [], + }) + } + } + const graphBlocks: Record = { ...existingBlocks } + + for (const [blockId, overlay] of blockOverlays.blocks) { + const internalFields = readGraphOnlyInternalFields(overlay) + if (internalFields.length > 0) { + throw new Error( + `Workflow graph Mermaid block "${blockId}" includes block-internal fields (${internalFields.join(', ')}). Use edit_workflow_block to change block configuration; edit_workflow only accepts visible graph labels: name, id, and type.` + ) + } + + if (!graphBlocks[blockId]) { + graphBlocks[blockId] = { + id: blockId, + type: overlay.type ?? 'unknown', + name: overlay.name ?? '', + position: { x: 0, y: 0 }, + subBlocks: {}, + outputs: {}, + enabled: true, + } + } + } + + if (parseMermaidLabelOverlays(document, Object.keys(graphBlocks)).conditionBranches.size > 0) { + throw new Error( + 'Workflow graph Mermaid must not include condition branch labels. Use edit_workflow_block to change condition branch definitions.' + ) + } + + const visibleGraph = parseVisibleWorkflowEdges(document, graphBlocks) + const blocksWithVisibleParenting = applyVisibleParenting( + graphBlocks, + visibleGraph.visibleBlockIds, + visibleGraph.inferredParentIds + ) + const edges = normalizeLogicalWorkflowEdges(visibleGraph.edges, blocksWithVisibleParenting).map( + ({ source, target, sourceHandle, targetHandle }) => ({ + source, + target, + ...(sourceHandle ? { sourceHandle } : {}), + ...(targetHandle ? { targetHandle } : {}), + }) + ) + for (const edge of edges) { + const conditionKey = extractConditionDisplayKey(edge.source, edge.sourceHandle) + if (!conditionKey) continue + + const sourceBlock = blocksWithVisibleParenting[edge.source] + const existingConditionKeys = new Set( + parseConditionEntries(sourceBlock?.subBlocks?.[CONDITION_INPUT_KEY]?.value).map( + (entry) => entry.key + ) + ) + if (sourceBlock?.type !== 'condition' || !existingConditionKeys.has(conditionKey)) { + throw new Error( + `Workflow graph Mermaid references unknown condition branch "${conditionKey}" on block "${edge.source}". Use edit_workflow_block to define condition branches before wiring them.` + ) + } + } + + return { + direction: directionMatch[1] as WorkflowDirection, + blocks: [...visibleGraph.visibleBlockIds].map((blockId) => { + const block = blocksWithVisibleParenting[blockId] + const overlay = blockOverlays.blocks.get(blockId) + return { + blockId, + ...((overlay?.type ?? block?.type) && (overlay?.type ?? block?.type) !== 'unknown' + ? { blockType: overlay?.type ?? block?.type } + : {}), + ...(overlay?.name ? { name: overlay.name } : {}), + ...(block?.data?.parentId ? { parentId: block.data.parentId } : {}), + } + }), + edges, + } +} + function syncContainerNodeMembership( blocks: Record, loops: Record, @@ -1705,6 +1967,44 @@ export function serializeWorkflowToTgMermaid( return lines.join('\n') } +export function serializeWorkflowToGraphMermaid( + workflowState: WorkflowSnapshot, + options: { direction?: WorkflowDirection } = {} +): string { + const direction = + options.direction ?? + workflowState.direction ?? + inferMermaidDirectionFromWorkflowState(workflowState) + const blocks = workflowState.blocks ?? {} + const blockIds = Object.keys(blocks).sort((left, right) => left.localeCompare(right)) + const aliases = buildAliasMap(blockIds) + const childrenByParent = getChildrenByParent(blocks) + const rootBlockIds = blockIds.filter((blockId) => { + const parentId = blocks[blockId]?.data?.parentId + return !parentId || !blocks[parentId] + }) + const lines = [`flowchart ${direction}`] + + for (const blockId of rootBlockIds) { + emitBlockGraphLines({ + blockId, + blocks, + aliases, + childrenByParent, + lines, + labelLinesForBlock: buildGraphOnlyBlockLabelLines, + includeConditionBranches: false, + }) + } + + for (const edge of workflowState.edges ?? []) { + const line = emitGraphOnlyEdgeGraphLine(edge, blocks, aliases) + if (line) lines.push(line) + } + + return lines.join('\n') +} + export function parseTgMermaidToWorkflow( document: string ): WorkflowSnapshot & { direction: WorkflowDirection } { diff --git a/apps/tradinggoose/lib/workflows/subblock-values.ts b/apps/tradinggoose/lib/workflows/subblock-values.ts index b3d1279dc..33dd7a8ac 100644 --- a/apps/tradinggoose/lib/workflows/subblock-values.ts +++ b/apps/tradinggoose/lib/workflows/subblock-values.ts @@ -89,6 +89,32 @@ export function resolveInitialSubBlockValue( return '' } +export function buildInitialSubBlockStates( + subBlockConfigs: SubBlockConfig[], + initialValues?: Record +): Record { + const subBlocks: Record = {} + const resolvedSubBlockParams: Record = {} + + for (const subBlock of subBlockConfigs) { + const resolvedInitialValue = resolveInitialSubBlockValue( + subBlock, + resolvedSubBlockParams, + initialValues?.[subBlock.id] + ) + + subBlocks[subBlock.id] = { + id: subBlock.id, + type: subBlock.type, + value: resolvedInitialValue, + } + + resolvedSubBlockParams[subBlock.id] = resolvedInitialValue + } + + return subBlocks +} + export function resolveDisplayedSubBlockValue( subBlock: Pick, value: unknown diff --git a/apps/tradinggoose/lib/workflows/triggers.test.ts b/apps/tradinggoose/lib/workflows/triggers.test.ts new file mode 100644 index 000000000..a87e51122 --- /dev/null +++ b/apps/tradinggoose/lib/workflows/triggers.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it, vi } from 'vitest' + +vi.unmock('@/blocks/registry') + +import { listWorkflowRunTriggers, resolveWorkflowRunTrigger } from './triggers' + +const block = (type: string, extra: Record = {}) => ({ + type, + enabled: true, + subBlocks: {}, + ...extra, +}) +const edge = (source: string, target = 'agent') => ({ source, target }) + +describe('workflow run trigger resolution', () => { + it('lists one Run option per resolved trigger identity', () => { + const edges = ['github', 'whatsapp', 'calendly', 'chat'].map((source) => edge(source)) + const runTriggers = listWorkflowRunTriggers( + { + github: block('github', { name: 'Production GitHub', triggerMode: true }), + whatsapp: block('whatsapp', { triggerMode: true }), + calendly: block('calendly', { + triggerMode: true, + subBlocks: { selectedTriggerId: { value: 'calendly_invitee_created' } }, + }), + disconnectedGithub: block('github', { triggerMode: true }), + chat: block('chat_trigger'), + }, + edges + ) + + expect(runTriggers.map(({ id, name }) => [id, name])).toEqual([ + ['github:github_webhook', 'Production GitHub'], + ['whatsapp:whatsapp_webhook', 'WhatsApp Webhook'], + ['calendly:calendly_invitee_created', 'Calendly Invitee Created'], + ]) + expect(runTriggers.every((trigger) => trigger.icon && trigger.color)).toBe(true) + }) + + it('generates editor test input while preserving explicit copilot input', () => { + const editorRun = resolveWorkflowRunTrigger( + { indicator: block('indicator_trigger') }, + [edge('indicator')], + { surface: 'editor', triggerBlockId: 'indicator' } + ) + + expect(editorRun.input).toMatchObject({ + listing: { listing_id: 'AAPL', base_id: '', quote_id: '', listing_type: 'default' }, + signal: 'mock_signal', + }) + expect( + (editorRun.blocks.indicator.subBlocks as Record).selectedTriggerId + ).toEqual({ + value: 'indicator_trigger', + }) + + const explicitInput = { listing: { listing_id: 'MSFT' }, signal: 'buy' } + expect( + resolveWorkflowRunTrigger({ indicator: block('indicator_trigger') }, [edge('indicator')], { + surface: 'copilot', + triggerBlockId: 'indicator', + workflowInput: explicitInput, + }).input + ).toBe(explicitInput) + }) + + it('surfaces selected trigger configuration errors', () => { + expect(() => + resolveWorkflowRunTrigger( + { calendly: block('calendly', { name: 'Calendly Lead Capture', triggerMode: true }) }, + [edge('calendly')], + { surface: 'editor', triggerBlockId: 'calendly' } + ) + ).toThrow('Calendly Lead Capture requires a selected trigger type') + }) +}) diff --git a/apps/tradinggoose/lib/workflows/triggers.ts b/apps/tradinggoose/lib/workflows/triggers.ts index bf8bcf9e1..b93139959 100644 --- a/apps/tradinggoose/lib/workflows/triggers.ts +++ b/apps/tradinggoose/lib/workflows/triggers.ts @@ -1,8 +1,12 @@ +import { sanitizeSolidIconColor } from '@/lib/ui/icon-colors' +import { readBlockOutputs } from '@/lib/workflows/block-outputs' import { getBlock } from '@/blocks' +import type { QueuedWorkflowTriggerType } from '@/services/queue' +import { getTrigger } from '@/triggers' +import { resolveTriggerExecutionIdentity } from '@/triggers/resolution' +import type { TriggerConfig } from '@/triggers/types' +import { generateMockPayloadFromOutputsDefinition } from './triggers/trigger-utils' -/** - * Unified trigger type definitions - */ export const TRIGGER_TYPES = { INPUT: 'input_trigger', MANUAL: 'manual_trigger', @@ -12,84 +16,13 @@ export const TRIGGER_TYPES = { SCHEDULE: 'schedule', } as const -export type TriggerType = (typeof TRIGGER_TYPES)[keyof typeof TRIGGER_TYPES] - -/** - * Mapping from reference alias (used in inline refs like , , etc.) - * to concrete trigger block type identifiers used across the system. - */ -export const TRIGGER_REFERENCE_ALIAS_MAP = { - start: TRIGGER_TYPES.INPUT, - api: TRIGGER_TYPES.API, - chat: TRIGGER_TYPES.CHAT, - manual: TRIGGER_TYPES.INPUT, -} as const - -export type TriggerReferenceAlias = keyof typeof TRIGGER_REFERENCE_ALIAS_MAP - -/** - * Trigger classification and utilities - */ export class TriggerUtils { - /** - * Check if a block is any kind of trigger - */ static isTriggerBlock(block: { type: string; triggerMode?: boolean }): boolean { const blockConfig = getBlock(block.type) - return ( - // New trigger blocks (explicit category) - blockConfig?.category === 'triggers' || - // Blocks with trigger mode enabled - block.triggerMode === true - ) - } - - /** - * Check if a block is a specific trigger type - */ - static isTriggerType(block: { type: string }, triggerType: TriggerType): boolean { - return block.type === triggerType - } - - /** - * Check if a type string is any trigger type - */ - static isAnyTriggerType(type: string): boolean { - return Object.values(TRIGGER_TYPES).includes(type as TriggerType) - } - - /** - * Check if a block is a chat trigger - */ - static isChatTrigger(block: { type: string; subBlocks?: any }): boolean { - return block.type === TRIGGER_TYPES.CHAT + return blockConfig?.category === 'triggers' || block.triggerMode === true } - /** - * Check if a block is a manual trigger - */ - static isManualTrigger(block: { type: string; subBlocks?: any }): boolean { - return block.type === TRIGGER_TYPES.INPUT || block.type === TRIGGER_TYPES.MANUAL - } - - /** - * Check if a block is an API trigger - * @param block - Block to check - * @param isChildWorkflow - Whether this is being called from a child workflow context - */ - static isApiTrigger(block: { type: string; subBlocks?: any }, isChildWorkflow = false): boolean { - if (isChildWorkflow) { - // Child workflows (workflow-in-workflow) only work with input_trigger - return block.type === TRIGGER_TYPES.INPUT - } - // Direct API calls only work with api_trigger - return block.type === TRIGGER_TYPES.API - } - - /** - * Get the default name for a trigger type - */ static getDefaultTriggerName(triggerType: string): string | null { const block = getBlock(triggerType) if ( @@ -107,120 +40,39 @@ export class TriggerUtils { return null } - /** - * Find trigger blocks of a specific type in a workflow - */ - static findTriggersByType( - blocks: T[] | Record, - triggerType: 'chat' | 'manual' | 'api', - isChildWorkflow = false - ): T[] { - const blockArray = Array.isArray(blocks) ? blocks : Object.values(blocks) - - switch (triggerType) { - case 'chat': - return blockArray.filter((block) => TriggerUtils.isChatTrigger(block)) - case 'manual': - return blockArray.filter((block) => TriggerUtils.isManualTrigger(block)) - case 'api': - return blockArray.filter((block) => TriggerUtils.isApiTrigger(block, isChildWorkflow)) - default: - return [] - } - } - - /** - * Find the appropriate start block for a given execution context - */ - static findStartBlock( + static findTriggerBlock( blocks: Record, executionType: 'chat' | 'manual' | 'api', isChildWorkflow = false ): { blockId: string; block: T } | null { - const entries = Object.entries(blocks) - - // Look for new trigger blocks first - const triggers = TriggerUtils.findTriggersByType(blocks, executionType, isChildWorkflow) - if (triggers.length > 0) { - const blockId = entries.find(([, b]) => b === triggers[0])?.[0] - if (blockId) { - return { blockId, block: triggers[0] } + const entry = Object.entries(blocks).find(([, block]) => { + if (executionType === 'chat') return block.type === TRIGGER_TYPES.CHAT + if (executionType === 'manual') { + return block.type === TRIGGER_TYPES.INPUT || block.type === TRIGGER_TYPES.MANUAL } - } + return isChildWorkflow ? block.type === TRIGGER_TYPES.INPUT : block.type === TRIGGER_TYPES.API + }) - return null + return entry ? { blockId: entry[0], block: entry[1] } : null } - /** - * Check if multiple triggers of a restricted type exist - */ - static hasMultipleTriggers( - blocks: T[] | Record, - triggerType: TriggerType - ): boolean { - const blockArray = Array.isArray(blocks) ? blocks : Object.values(blocks) - const count = blockArray.filter((block) => block.type === triggerType).length - return count > 1 - } - - /** - * Check if a trigger type requires single instance constraint - */ - static requiresSingleInstance(triggerType: string): boolean { - // Each trigger type can only have one instance of itself - // Manual and Input Form can coexist - // API, Chat triggers must be unique - // Schedules and webhooks can have multiple instances - return ( - triggerType === TRIGGER_TYPES.API || - triggerType === TRIGGER_TYPES.INPUT || - triggerType === TRIGGER_TYPES.MANUAL || - triggerType === TRIGGER_TYPES.CHAT - ) - } - - /** - * Check if adding a trigger would violate single instance constraint - */ static wouldViolateSingleInstance( blocks: T[] | Record, triggerType: string ): boolean { - const blockArray = Array.isArray(blocks) ? blocks : Object.values(blocks) - - // Only one Input trigger allowed - if (triggerType === TRIGGER_TYPES.INPUT) { - return blockArray.some((block) => block.type === TRIGGER_TYPES.INPUT) - } - - // Only one Manual trigger allowed - if (triggerType === TRIGGER_TYPES.MANUAL) { - return blockArray.some((block) => block.type === TRIGGER_TYPES.MANUAL) - } - - // Only one API trigger allowed - if (triggerType === TRIGGER_TYPES.API) { - return blockArray.some((block) => block.type === TRIGGER_TYPES.API) - } - - // Chat trigger must be unique - if (triggerType === TRIGGER_TYPES.CHAT) { - return blockArray.some((block) => block.type === TRIGGER_TYPES.CHAT) - } - - // Centralized rule: only API, Input, Chat are single-instance - if (!TriggerUtils.requiresSingleInstance(triggerType)) { + if ( + triggerType !== TRIGGER_TYPES.API && + triggerType !== TRIGGER_TYPES.INPUT && + triggerType !== TRIGGER_TYPES.MANUAL && + triggerType !== TRIGGER_TYPES.CHAT + ) { return false } + const blockArray = Array.isArray(blocks) ? blocks : Object.values(blocks) return blockArray.some((block) => block.type === triggerType) } - /** - * Evaluate whether adding a trigger of the given type is allowed and, if not, why. - * Returns null if allowed; otherwise returns an object describing the violation. - * This avoids duplicating UI logic across toolbar/drop handlers. - */ static getTriggerAdditionIssue( blocks: T[] | Record, triggerType: string @@ -229,24 +81,153 @@ export class TriggerUtils { return null } - // Otherwise treat as duplicate of a single-instance trigger const triggerName = TriggerUtils.getDefaultTriggerName(triggerType) || 'trigger' return { issue: 'duplicate', triggerName } } +} + +export type WorkflowRunTriggerBlock = { + type: string + name?: string + enabled?: boolean + triggerMode?: boolean + subBlocks?: Record +} + +type WorkflowRunSurface = 'editor' | 'copilot' +type WorkflowRunExecutionTriggerType = Extract + +export type WorkflowRunTriggerOption = { + id: string + blockId: string + name: string + triggerSource: string + icon?: TriggerConfig['icon'] + color: string +} + +function isWorkflowRunTriggerEntry( + blockId: string, + block: T | undefined, + edges: Array<{ source: string; target: string }> +): block is T { + return Boolean( + block?.type && + block.enabled !== false && + TriggerUtils.isTriggerBlock(block) && + edges.some((edge) => edge.source === blockId) + ) +} + +export function listWorkflowRunTriggers( + blocks: Record, + edges: Array<{ source: string; target: string }> +): WorkflowRunTriggerOption[] { + return Object.entries(blocks).flatMap(([blockId, block]) => { + if (!isWorkflowRunTriggerEntry(blockId, block, edges)) { + return [] + } + + try { + const identity = resolveTriggerExecutionIdentity(block) + if (identity.triggerType === 'chat') return [] + const trigger = getTrigger(identity.triggerSource)! + + return [ + { + id: `${blockId}:${identity.triggerSource}`, + blockId, + name: block.name || trigger.name, + triggerSource: identity.triggerSource, + icon: trigger.icon, + color: sanitizeSolidIconColor(getBlock(block.type)?.bgColor) ?? '#6B7280', + }, + ] + } catch { + return [] + } + }) +} + +function materializeTriggerSource( + block: T, + triggerSource: string +): T { + const selectedTriggerId = block.subBlocks?.selectedTriggerId + const nextSelectedTriggerId = + selectedTriggerId && typeof selectedTriggerId === 'object' && !Array.isArray(selectedTriggerId) + ? { ...selectedTriggerId, value: triggerSource } + : { value: triggerSource } + + return { + ...block, + triggerMode: true, + subBlocks: { + ...(block.subBlocks ?? {}), + selectedTriggerId: nextSelectedTriggerId, + }, + } +} + +function buildWorkflowRunTriggerInput( + block: WorkflowRunTriggerBlock, + workflowInput: unknown, + options: { preserveProvidedInput: boolean } +): unknown { + if (options.preserveProvidedInput && workflowInput !== undefined) { + return workflowInput + } - /** - * Get trigger validation message - */ - static getTriggerValidationMessage( - triggerType: 'chat' | 'manual' | 'api', - issue: 'missing' | 'multiple' - ): string { - const triggerName = triggerType.charAt(0).toUpperCase() + triggerType.slice(1) - - if (issue === 'missing') { - return `${triggerName} execution requires a ${triggerName} Trigger block` + const inputFormat = block.subBlocks?.inputFormat?.value + if (Array.isArray(inputFormat)) { + const testInput: Record = {} + for (const field of inputFormat) { + const name = field && typeof field === 'object' ? (field as { name?: unknown }).name : null + if (typeof name === 'string' && name.length > 0) { + testInput[name] = (field as { value?: unknown }).value + } } + return Object.keys(testInput).length > 0 ? testInput : (workflowInput ?? {}) + } + + const outputs = readBlockOutputs(block.type, block.subBlocks, true) + return Object.keys(outputs).length > 0 + ? generateMockPayloadFromOutputsDefinition(outputs) + : (workflowInput ?? {}) +} + +export function resolveWorkflowRunTrigger( + blocks: Record, + edges: Array<{ source: string; target: string }>, + options: { + surface: WorkflowRunSurface + workflowInput?: unknown + triggerBlockId: string + } +): { + blockId: string + blocks: Record + input: unknown + triggerType: WorkflowRunExecutionTriggerType +} { + const selectedBlock = blocks[options.triggerBlockId] + if (!isWorkflowRunTriggerEntry(options.triggerBlockId, selectedBlock, edges)) { + throw new Error(`Trigger block ${options.triggerBlockId} is not available for Run`) + } + if (options.surface === 'editor' && selectedBlock.type === TRIGGER_TYPES.CHAT) { + throw new Error(`Trigger block ${options.triggerBlockId} is not available for Run`) + } - return `Multiple ${triggerName} Trigger blocks found. Keep only one.` + const identity = resolveTriggerExecutionIdentity(selectedBlock) + const block = materializeTriggerSource(selectedBlock, identity.triggerSource) + const isChatRun = identity.triggerType === 'chat' + + return { + blockId: options.triggerBlockId, + blocks: { ...blocks, [options.triggerBlockId]: block }, + input: buildWorkflowRunTriggerInput(block, options.workflowInput, { + preserveProvidedInput: options.surface === 'copilot' || isChatRun, + }), + triggerType: isChatRun ? 'chat' : 'manual', } } diff --git a/apps/tradinggoose/lib/workflows/triggers/trigger-utils.ts b/apps/tradinggoose/lib/workflows/triggers/trigger-utils.ts index 884781abe..29f9913f4 100644 --- a/apps/tradinggoose/lib/workflows/triggers/trigger-utils.ts +++ b/apps/tradinggoose/lib/workflows/triggers/trigger-utils.ts @@ -1,3 +1,5 @@ +import { LISTING_IDENTITY_VALUE_TYPE, type ListingIdentity } from '@/lib/listing/identity' + /** * Generates mock data based on the output type definition */ @@ -26,6 +28,13 @@ function generateMockValue(type: string, _description?: string, fieldName?: stri name: 'Sample Object', status: 'active', } + case LISTING_IDENTITY_VALUE_TYPE: + return { + listing_id: 'AAPL', + base_id: '', + quote_id: '', + listing_type: 'default', + } satisfies ListingIdentity default: return null } diff --git a/apps/tradinggoose/lib/workflows/utils.ts b/apps/tradinggoose/lib/workflows/utils.ts index f61f6c1b1..1cb6f6005 100644 --- a/apps/tradinggoose/lib/workflows/utils.ts +++ b/apps/tradinggoose/lib/workflows/utils.ts @@ -569,9 +569,9 @@ export async function validateWorkflowPermissions( } } - const { workflow, workspacePermission, isOwner } = accessContext + const { workflow, workspacePermission, isOwner, isWorkspaceOwner } = accessContext - if (isOwner) { + if (isOwner || isWorkspaceOwner) { return { error: null, session, diff --git a/apps/tradinggoose/lib/workflows/workflow-direction.ts b/apps/tradinggoose/lib/workflows/workflow-direction.ts index 77b3858de..834d92072 100644 --- a/apps/tradinggoose/lib/workflows/workflow-direction.ts +++ b/apps/tradinggoose/lib/workflows/workflow-direction.ts @@ -4,7 +4,7 @@ import type { BlockState, WorkflowDirection } from '@/stores/workflows/workflow/ type WorkflowGraphState = Pick -function getAbsoluteBlockPosition( +export function getAbsoluteBlockPosition( blockId: string, blocks: Record, visiting = new Set() @@ -70,7 +70,9 @@ export function inferMermaidDirectionFromWorkflowState( return horizontalDistance > verticalDistance ? 'LR' : 'TD' } - const positions = Object.keys(blocks).map((blockId) => getPosition(blockId)).filter(Boolean) as Array<{ + const positions = Object.keys(blocks) + .map((blockId) => getPosition(blockId)) + .filter(Boolean) as Array<{ x: number y: number }> diff --git a/apps/tradinggoose/lib/workspaces/service.ts b/apps/tradinggoose/lib/workspaces/service.ts new file mode 100644 index 000000000..aef7c5a75 --- /dev/null +++ b/apps/tradinggoose/lib/workspaces/service.ts @@ -0,0 +1,165 @@ +import { db } from '@tradinggoose/db' +import { permissions, workflow, workspace } from '@tradinggoose/db/schema' +import { and, desc, eq, isNull } from 'drizzle-orm' +import { buildWorkspaceAccessScope } from '@/lib/permissions/utils' +import { saveWorkflowToNormalizedTables } from '@/lib/workflows/db-helpers' +import { buildDefaultWorkflowArtifacts } from '@/lib/workflows/defaults' +import { toWorkspaceApiRecord } from '@/lib/workspaces/billing-owner' +import { tryApplyWorkflowState } from '@/lib/yjs/server/apply-workflow-state' +import { createWorkflowSnapshot } from '@/lib/yjs/workflow-session' + +type WorkspaceRecord = typeof workspace.$inferSelect + +export async function getUserWorkspaces({ + userId, + userName, + autoCreate = true, +}: { + userId: string + userName?: string | null + autoCreate?: boolean +}) { + const workspaceAccess = buildWorkspaceAccessScope(userId, workspace.id) + const userWorkspaces = await db + .select({ + workspace: workspace, + permissionType: permissions.permissionType, + }) + .from(workspace) + .leftJoin(permissions, workspaceAccess.permissionJoin) + .where(workspaceAccess.accessFilter) + .orderBy(desc(workspace.createdAt)) + + if (userWorkspaces.length === 0) { + if (!autoCreate) { + return [] + } + + const defaultWorkspace = await createDefaultWorkspace(userId, userName) + await migrateExistingWorkflows(userId, defaultWorkspace.id) + return [defaultWorkspace] + } + + if (autoCreate) { + await ensureWorkflowsHaveWorkspace(userId, userWorkspaces[0].workspace.id) + } + + return userWorkspaces.map(({ workspace: workspaceDetails, permissionType }) => { + const resolvedPermissionType = workspaceDetails.ownerId === userId ? 'admin' : permissionType + if (!resolvedPermissionType) { + throw new Error(`Expected workspace permission for ${workspaceDetails.id}`) + } + + return { + ...toWorkspaceApiRecord(workspaceDetails), + role: resolvedPermissionType === 'admin' ? 'owner' : 'member', + permissions: resolvedPermissionType, + } + }) +} + +export async function createWorkspace(userId: string, name: string) { + const workspaceId = crypto.randomUUID() + const workflowId = crypto.randomUUID() + const now = new Date() + const workspaceDetails = { + id: workspaceId, + name, + ownerId: userId, + billingOwnerType: 'user', + billingOwnerUserId: userId, + billingOwnerOrganizationId: null, + allowPersonalApiKeys: true, + createdAt: now, + updatedAt: now, + } satisfies WorkspaceRecord + + await db.transaction(async (tx) => { + await tx.insert(workspace).values(workspaceDetails) + + await tx.insert(workflow).values({ + id: workflowId, + userId, + workspaceId, + folderId: null, + name: 'default-agent', + description: 'Your first workflow - start building here!', + color: '#3972F6', + lastSynced: now, + createdAt: now, + updatedAt: now, + isDeployed: false, + collaborators: [], + runCount: 0, + variables: {}, + isPublished: false, + marketplaceData: null, + }) + }) + + const { workflowState } = buildDefaultWorkflowArtifacts() + const lastSaved = now.toISOString() + + try { + const saveResult = await saveWorkflowToNormalizedTables(workflowId, workflowState) + if (!saveResult.success) { + throw new Error(saveResult.error || 'Failed to persist default workflow state') + } + + const seedResult = await tryApplyWorkflowState( + workflowId, + createWorkflowSnapshot({ + blocks: saveResult.normalizedState?.blocks ?? workflowState.blocks, + edges: saveResult.normalizedState?.edges ?? workflowState.edges, + loops: saveResult.normalizedState?.loops ?? workflowState.loops, + parallels: saveResult.normalizedState?.parallels ?? workflowState.parallels, + lastSaved, + isDeployed: false, + }), + undefined, + 'default-agent' + ) + if (!seedResult.success) { + throw seedResult.error instanceof Error + ? seedResult.error + : new Error('Failed to seed default workflow state') + } + } catch (error) { + await db.transaction(async (tx) => { + await tx.delete(workflow).where(eq(workflow.id, workflowId)) + await tx.delete(workspace).where(eq(workspace.id, workspaceId)) + }) + throw error + } + + return { + ...toWorkspaceApiRecord(workspaceDetails), + role: 'owner', + permissions: 'admin', + } +} + +async function createDefaultWorkspace(userId: string, userName?: string | null) { + const firstName = userName?.split(' ')[0] || null + return createWorkspace(userId, firstName ? `${firstName}'s Workspace` : 'My Workspace') +} + +async function migrateExistingWorkflows(userId: string, workspaceId: string) { + await db + .update(workflow) + .set({ + workspaceId, + updatedAt: new Date(), + }) + .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) +} + +async function ensureWorkflowsHaveWorkspace(userId: string, defaultWorkspaceId: string) { + await db + .update(workflow) + .set({ + workspaceId: defaultWorkspaceId, + updatedAt: new Date(), + }) + .where(and(eq(workflow.userId, userId), isNull(workflow.workspaceId))) +} diff --git a/apps/tradinggoose/lib/yjs/use-workflow-doc.ts b/apps/tradinggoose/lib/yjs/use-workflow-doc.ts index 22f70573f..a81ca5f8c 100644 --- a/apps/tradinggoose/lib/yjs/use-workflow-doc.ts +++ b/apps/tradinggoose/lib/yjs/use-workflow-doc.ts @@ -15,7 +15,7 @@ import type { Edge } from '@xyflow/react' import type * as Y from 'yjs' import { escapeRegExp } from '@/lib/utils' import { readBlockOutputs, resolveBlockRuntimeState } from '@/lib/workflows/block-outputs' -import { resolveInitialSubBlockValue } from '@/lib/workflows/subblock-values' +import { buildInitialSubBlockStates } from '@/lib/workflows/subblock-values' import { YJS_ORIGINS, type YjsOrigin } from '@/lib/yjs/transaction-origins' import { useYjsSubscription } from '@/lib/yjs/use-yjs-subscription' import { rewriteWorkflowContentReferences } from '@/lib/yjs/workflow-reference-rewrite' @@ -839,25 +839,13 @@ export function useWorkflowMutations() { const blockConfig = getBlock(type) let subBlocks: Record = {} const outputs: Record = {} - const resolvedSubBlockParams: Record = {} if (blockConfig) { const initValues = blockProperties?.initialSubBlockValues - blockConfig.subBlocks.forEach((subBlock) => { - const resolvedInitialValue = resolveInitialSubBlockValue( - subBlock, - resolvedSubBlockParams, - initValues?.[subBlock.id] - ) - - subBlocks[subBlock.id] = { - id: subBlock.id, - type: subBlock.type, - value: resolvedInitialValue as any, - } - - resolvedSubBlockParams[subBlock.id] = resolvedInitialValue - }) + subBlocks = buildInitialSubBlockStates( + blockConfig.subBlocks, + initValues + ) as Record const runtimeState = resolveBlockRuntimeState({ blockType: type, diff --git a/apps/tradinggoose/lib/yjs/workflow-session-host.tsx b/apps/tradinggoose/lib/yjs/workflow-session-host.tsx index 76dd606ea..322df4e10 100644 --- a/apps/tradinggoose/lib/yjs/workflow-session-host.tsx +++ b/apps/tradinggoose/lib/yjs/workflow-session-host.tsx @@ -1,34 +1,24 @@ 'use client' -import React, { - createContext, - useCallback, - useContext, - useEffect, - useState, - type ReactNode, -} from 'react' -import * as Y from 'yjs' +import { createContext, type ReactNode, useCallback, useContext, useEffect, useState } from 'react' +import type * as Y from 'yjs' +import { YJS_ORIGINS } from '@/lib/yjs/transaction-origins' +import { readWorkflowSnapshotCloned, type WorkflowSnapshot } from '@/lib/yjs/workflow-session' import { - EMPTY_SHARED_WORKFLOW_SESSION_STATE, acquireSharedWorkflowSession, + EMPTY_SHARED_WORKFLOW_SESSION_STATE, getSharedWorkflowSessionState, redoSharedWorkflowSession, + type SharedWorkflowSessionState, setSharedWorkflowSessionUser, subscribeToSharedWorkflowSession, undoSharedWorkflowSession, - type SharedWorkflowSessionState, } from '@/lib/yjs/workflow-shared-session' -import { - readWorkflowSnapshotCloned, - type WorkflowSnapshot, -} from '@/lib/yjs/workflow-session' -import { YJS_ORIGINS } from '@/lib/yjs/transaction-origins' export interface WorkflowSessionContextValue { workflowId: string doc: Y.Doc | null - awareness: any | null + awareness: SharedWorkflowSessionState['awareness'] isSynced: boolean isLoading: boolean error: string | null @@ -74,7 +64,9 @@ export function WorkflowSessionProvider({ children, }: WorkflowSessionProviderProps) { const [state, setState] = useState(() => - workflowId ? getSharedWorkflowSessionState(workflowId) : { ...EMPTY_SHARED_WORKFLOW_SESSION_STATE } + workflowId + ? getSharedWorkflowSessionState(workflowId) + : { ...EMPTY_SHARED_WORKFLOW_SESSION_STATE } ) const { doc, awareness, isSynced, isLoading, error, canUndo, canRedo } = state @@ -143,9 +135,5 @@ export function WorkflowSessionProvider({ redo, } - return ( - - {children} - - ) + return {children} } diff --git a/apps/tradinggoose/lib/yjs/workflow-shared-session.test.ts b/apps/tradinggoose/lib/yjs/workflow-shared-session.test.ts index 736470b05..d0aca6f7d 100644 --- a/apps/tradinggoose/lib/yjs/workflow-shared-session.test.ts +++ b/apps/tradinggoose/lib/yjs/workflow-shared-session.test.ts @@ -1,5 +1,5 @@ -import * as Y from 'yjs' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import * as Y from 'yjs' import { YJS_ORIGINS } from '@/lib/yjs/transaction-origins' const mockBootstrapYjsProvider = vi.fn() @@ -39,6 +39,26 @@ function createMockProvider() { } } +function createBootstrapResult(doc: Y.Doc, provider: ReturnType) { + return { + doc, + provider, + descriptor: { + workspaceId: 'workspace-1', + entityKind: 'workflow', + entityId: 'workflow-1', + draftSessionId: null, + reviewSessionId: null, + yjsSessionId: 'workflow-1', + }, + runtime: { + docState: 'active', + replaySafe: true, + reseededFromCanonical: false, + }, + } +} + async function waitForCondition(assertion: () => void, timeoutMs = 1000) { const start = Date.now() @@ -68,12 +88,52 @@ describe('workflow shared session lifecycle', () => { mockWaitForYjsWriteSync.mockResolvedValue(undefined) mockRegisterWorkflowSession.mockReset() mockUnregisterWorkflowSession.mockReset() - delete globalThis.__workflowYjsSessionEntries + globalThis.__workflowYjsSessionEntries = undefined }) afterEach(() => { vi.useRealTimers() - delete globalThis.__workflowYjsSessionEntries + globalThis.__workflowYjsSessionEntries = undefined + }) + + it('does not publish a readable doc before bootstrap completes', async () => { + const doc = new Y.Doc() + const provider = createMockProvider() + let finishBootstrap!: () => void + const bootstrapReady = new Promise((resolve) => { + finishBootstrap = resolve + }) + + mockBootstrapYjsProvider.mockImplementation(async () => { + await bootstrapReady + return createBootstrapResult(doc, provider) + }) + + const { acquireSharedWorkflowSession, getSharedWorkflowSessionState } = await import( + './workflow-shared-session' + ) + + const release = acquireSharedWorkflowSession({ + workflowId: 'workflow-1', + workspaceId: 'workspace-1', + }) + + await waitForCondition(() => { + expect(mockBootstrapYjsProvider).toHaveBeenCalledTimes(1) + }) + expect(getSharedWorkflowSessionState('workflow-1')).toMatchObject({ + doc: null, + isLoading: true, + }) + + finishBootstrap() + + await waitForCondition(() => { + expect(getSharedWorkflowSessionState('workflow-1').doc).toBe(doc) + expect(getSharedWorkflowSessionState('workflow-1').isLoading).toBe(false) + }) + + release() }) it('reuses one bootstrapped workflow session across multiple acquisitions', async () => { @@ -82,23 +142,7 @@ describe('workflow shared session lifecycle', () => { const destroyDoc = vi.spyOn(doc, 'destroy') const provider = createMockProvider() - mockBootstrapYjsProvider.mockResolvedValue({ - doc, - provider, - descriptor: { - workspaceId: 'workspace-1', - entityKind: 'workflow', - entityId: 'workflow-1', - draftSessionId: null, - reviewSessionId: null, - yjsSessionId: 'workflow-1', - }, - runtime: { - docState: 'active', - replaySafe: true, - reseededFromCanonical: false, - }, - }) + mockBootstrapYjsProvider.mockResolvedValue(createBootstrapResult(doc, provider)) const { acquireSharedWorkflowSession, @@ -168,23 +212,7 @@ describe('workflow shared session lifecycle', () => { const destroyDoc = vi.spyOn(doc, 'destroy') const provider = createMockProvider() - mockBootstrapYjsProvider.mockResolvedValue({ - doc, - provider, - descriptor: { - workspaceId: 'workspace-1', - entityKind: 'workflow', - entityId: 'workflow-1', - draftSessionId: null, - reviewSessionId: null, - yjsSessionId: 'workflow-1', - }, - runtime: { - docState: 'active', - replaySafe: true, - reseededFromCanonical: false, - }, - }) + mockBootstrapYjsProvider.mockResolvedValue(createBootstrapResult(doc, provider)) const { acquireSharedWorkflowSession, getSharedWorkflowSessionState } = await import( './workflow-shared-session' @@ -228,23 +256,7 @@ describe('workflow shared session lifecycle', () => { const doc = new Y.Doc() const provider = createMockProvider() - mockBootstrapYjsProvider.mockResolvedValue({ - doc, - provider, - descriptor: { - workspaceId: 'workspace-1', - entityKind: 'workflow', - entityId: 'workflow-1', - draftSessionId: null, - reviewSessionId: null, - yjsSessionId: 'workflow-1', - }, - runtime: { - docState: 'active', - replaySafe: true, - reseededFromCanonical: false, - }, - }) + mockBootstrapYjsProvider.mockResolvedValue(createBootstrapResult(doc, provider)) const { acquireSharedWorkflowSession, diff --git a/apps/tradinggoose/lib/yjs/workflow-shared-session.ts b/apps/tradinggoose/lib/yjs/workflow-shared-session.ts index 3eff46c8b..40b125542 100644 --- a/apps/tradinggoose/lib/yjs/workflow-shared-session.ts +++ b/apps/tradinggoose/lib/yjs/workflow-shared-session.ts @@ -1,7 +1,7 @@ 'use client' -import * as Y from 'yjs' import type { WebsocketProvider } from 'y-websocket' +import * as Y from 'yjs' import type { ReviewTargetDescriptor } from '@/lib/copilot/review-sessions/types' import { deriveUserColor } from '@/lib/utils' import { @@ -9,23 +9,23 @@ import { waitForYjsWriteSync, type YjsProviderBootstrapResult, } from '@/lib/yjs/provider' +import { createYjsUndoTrackedOrigins } from '@/lib/yjs/transaction-origins' import { getMetadataMap, getVariablesMap, readWorkflowMap, readWorkflowTextFieldsMap, } from '@/lib/yjs/workflow-session' -import { createYjsUndoTrackedOrigins } from '@/lib/yjs/transaction-origins' import { + type RegisteredWorkflowSession, registerWorkflowSession, unregisterWorkflowSession, - type RegisteredWorkflowSession, } from '@/lib/yjs/workflow-session-registry' export interface SharedWorkflowSessionState { doc: Y.Doc | null provider: WebsocketProvider | null - awareness: any | null + awareness: WebsocketProvider['awareness'] | null canUndo: boolean canRedo: boolean isSynced: boolean @@ -177,7 +177,11 @@ async function initializeSharedSession(entry: SharedWorkflowSessionEntry): Promi } const undoManager = new Y.UndoManager( - [readWorkflowMap(result.doc), readWorkflowTextFieldsMap(result.doc), getVariablesMap(result.doc)], + [ + readWorkflowMap(result.doc), + readWorkflowTextFieldsMap(result.doc), + getVariablesMap(result.doc), + ], { trackedOrigins: createYjsUndoTrackedOrigins(), } diff --git a/apps/tradinggoose/next.config.test.ts b/apps/tradinggoose/next.config.test.ts index 989ea082f..69172a6e8 100644 --- a/apps/tradinggoose/next.config.test.ts +++ b/apps/tradinggoose/next.config.test.ts @@ -22,13 +22,12 @@ function buildSourceMatcher(source: string) { switch (source) { case '/api/:path((?!workflows/[^/]+/execute$).*)': - return (path: string) => /^\/api\/.+$/.test(path) && !/^\/api\/workflows\/[^/]+\/execute$/.test(path) + return (path: string) => + /^\/api\/.+$/.test(path) && !/^\/api\/workflows\/[^/]+\/execute$/.test(path) case '/api/workflows/:id/execute': return (path: string) => /^\/api\/workflows\/[^/]+\/execute$/.test(path) - case '/:app(w|workspace|chat)/:path*': - return (path: string) => /^\/(?:w|workspace|chat)(?:\/.*)?$/.test(path) - case '/:locale(es|zh)/:app(w|workspace|chat)/:path*': - return (path: string) => /^\/(?:es|zh)\/(?:w|workspace|chat)(?:\/.*)?$/.test(path) + case '/:locale(en|es|zh)/:app(workspace|chat)/:path*': + return (path: string) => /^\/(?:en|es|zh)\/(?:workspace|chat)(?:\/.*)?$/.test(path) case '/api/tools/drive/:path*': return (path: string) => /^\/api\/tools\/drive(?:\/.*)?$/.test(path) case '/_next/:path*': @@ -47,7 +46,9 @@ function matchesSource(source: string, path: string) { function getHeaderValues(rules: HeaderRules, path: string, key: string) { return rules .filter((rule) => matchesSource(rule.source, path)) - .flatMap((rule) => rule.headers.filter((header) => header.key === key).map((header) => header.value)) + .flatMap((rule) => + rule.headers.filter((header) => header.key === key).map((header) => header.value) + ) } function expectHeaderValue(rules: HeaderRules, path: string, key: string, value: string) { @@ -67,11 +68,10 @@ describe('next.config headers routing', () => { } }) - it('applies permissive app headers to localized, unlocalized, and internal app resources', async () => { + it('applies permissive app headers to localized app and internal resource routes', async () => { const rules = await getHeaderRules() const appPaths = [ - '/workspace/ws-1/dashboard', - '/chat/test-chat', + '/en/workspace/ws-1/dashboard', '/es/workspace/ws-1/dashboard', '/zh/chat/test-chat', ] @@ -99,7 +99,7 @@ describe('next.config headers routing', () => { it('keeps strict cross-origin and public-page CSP headers on representative public routes', async () => { const rules = await getHeaderRules() - const publicPaths = ['/', '/privacy', '/es/privacy', '/blog/hello-world'] + const publicPaths = ['/en/privacy', '/es/privacy', '/en/blog/hello-world'] const infrastructurePaths = ['/ingest/e'] for (const path of publicPaths) { diff --git a/apps/tradinggoose/next.config.ts b/apps/tradinggoose/next.config.ts index 8afac9b62..60fc9208f 100644 --- a/apps/tradinggoose/next.config.ts +++ b/apps/tradinggoose/next.config.ts @@ -11,10 +11,10 @@ const MONACO_TRACE_FILES = MONACO_TRACE_ROOTS.flatMap((root) => [ `${root}/.bun/monaco-editor@*/node_modules/monaco-editor/esm/**/*.js`, `${root}/.bun/monaco-editor@*/node_modules/monaco-editor/esm/**/*.js.map`, ]) -const PUBLIC_LOCALE_ROUTE_PREFIX = '(?:es|zh)' +const PUBLIC_LOCALE_ROUTE_PREFIX = '(?:en|es|zh)' const API_ROUTE_LOOKAHEAD = 'api(?:/.*)?$' const INGEST_ROUTE_LOOKAHEAD = 'ingest(?:/.*)?$' -const LOCALIZED_APP_ROUTE_SOURCE = `(?:${PUBLIC_LOCALE_ROUTE_PREFIX}/)?(?:w|workspace|chat)(?:/.*)?` +const LOCALIZED_APP_ROUTE_SOURCE = `${PUBLIC_LOCALE_ROUTE_PREFIX}/(?:workspace|chat)(?:/.*)?` const LOCALIZED_APP_ROUTE_LOOKAHEAD = `${LOCALIZED_APP_ROUTE_SOURCE}$` const API_ROUTE_PARAM_EXCLUDING_WORKFLOW_EXECUTION = ':path((?!workflows/[^/]+/execute$).*)' @@ -82,6 +82,7 @@ const nextConfig: NextConfig = { '/monaco-editor/esm/**/*': MONACO_TRACE_FILES, }, turbopack: { + root: new URL('../..', import.meta.url).pathname, resolveExtensions: ['.tsx', '.ts', '.jsx', '.js', '.mjs', '.json'], }, serverExternalPackages: [ @@ -209,14 +210,9 @@ const nextConfig: NextConfig = { }, ], }, - { - // For main app routes - use permissive policies - source: '/:app(w|workspace|chat)/:path*', - headers: permissiveRouteHeaders, - }, { // Localized public app routes use the same permissive policies - source: '/:locale(es|zh)/:app(w|workspace|chat)/:path*', + source: '/:locale(en|es|zh)/:app(workspace|chat)/:path*', headers: permissiveRouteHeaders, }, { diff --git a/apps/tradinggoose/package.json b/apps/tradinggoose/package.json index 525066cd5..83c047aa8 100644 --- a/apps/tradinggoose/package.json +++ b/apps/tradinggoose/package.json @@ -43,13 +43,10 @@ "@hookform/resolvers": "^4.1.3", "@monaco-editor/react": "4.7.0", "@opentelemetry/api": "^1.9.0", - "@opentelemetry/exporter-jaeger": "2.1.0", "@opentelemetry/exporter-trace-otlp-http": "^0.200.0", "@opentelemetry/resources": "^2.0.0", - "@opentelemetry/sdk-node": "^0.200.0", "@opentelemetry/sdk-trace-base": "2.0.0", "@opentelemetry/sdk-trace-node": "2.0.0", - "@opentelemetry/semantic-conventions": "^1.32.0", "@radix-ui/react-alert-dialog": "^1.1.5", "@radix-ui/react-avatar": "1.1.10", "@radix-ui/react-checkbox": "^1.1.3", diff --git a/apps/tradinggoose/proxy.test.ts b/apps/tradinggoose/proxy.test.ts index 95b5bb547..a303bfaf0 100644 --- a/apps/tradinggoose/proxy.test.ts +++ b/apps/tradinggoose/proxy.test.ts @@ -1,12 +1,6 @@ import { NextRequest } from 'next/server' import { beforeEach, describe, expect, it, vi } from 'vitest' -const mockGetSessionCookie = vi.fn() - -vi.mock('better-auth/cookies', () => ({ - getSessionCookie: (...args: unknown[]) => mockGetSessionCookie(...args), -})) - vi.mock('./lib/logs/console/logger', () => ({ createLogger: () => ({ warn: vi.fn(), @@ -19,17 +13,19 @@ vi.mock('./lib/security/csp', () => ({ vi.mock('next-intl/middleware', async () => { const { NextResponse } = await vi.importActual('next/server') + const locales = ['en', 'es', 'zh'] as const return { default: () => (request: { nextUrl: URL; url: string }) => { const url = new URL(request.url) + const firstSegment = url.pathname.split('/').filter(Boolean)[0] - if (url.pathname === '/en' || url.pathname.startsWith('/en/')) { - url.pathname = url.pathname === '/en' ? '/' : url.pathname.slice('/en'.length) - return NextResponse.redirect(url) + if (firstSegment && locales.includes(firstSegment as (typeof locales)[number])) { + return NextResponse.next() } - return NextResponse.next() + url.pathname = url.pathname === '/' ? '/en' : `/en${url.pathname}` + return NextResponse.redirect(url) }, } }) @@ -41,75 +37,164 @@ describe('proxy auth routing', () => { process.env.NEXT_PUBLIC_APP_URL = 'https://www.tradinggoose.ai' }) - it('uses the request host for localhost auth redirects instead of hosted-mode rewrites', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - + it('uses the request host for protected route locale redirects', async () => { const { proxy } = await import('./proxy') const response = await proxy( - new NextRequest('http://localhost:3000/workspace/ws-1/dashboard?layoutId=layout-1') + new NextRequest('http://localhost:3000/workspace/ws-1/dashboard?layoutId=layout-1', { + headers: { + 'user-agent': 'vitest', + }, + }) ) expect(response.status).toBe(307) expect(response.headers.get('location')).toBe( - 'http://localhost:3000/login?callbackUrl=%2Fworkspace%2Fws-1%2Fdashboard%3FlayoutId%3Dlayout-1' + 'http://localhost:3000/en/workspace/ws-1/dashboard?layoutId=layout-1' ) expect(response.headers.get('x-middleware-rewrite')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('en') + }) + + it.each([ + ['root', 'http://localhost:3000/', 'http://localhost:3000/en'], + ['privacy', 'http://localhost:3000/privacy', 'http://localhost:3000/en/privacy'], + ['login', 'http://localhost:3000/login', 'http://localhost:3000/en/login'], + ])( + 'redirects unprefixed %s routes to the default locale when no preference is present', + async (_, url, location) => { + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest(url, { + headers: { + 'user-agent': 'vitest', + accept: 'text/html', + }, + }) + ) + + expect(response.status).toBe(307) + expect(response.headers.get('location')).toBe(location) + expect(response.headers.get('x-middleware-rewrite')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('en') + } + ) + + it.each([ + ['root', 'http://localhost:3000/?source=nav', 'http://localhost:3000/zh?source=nav'], + ['privacy', 'http://localhost:3000/privacy', 'http://localhost:3000/zh/privacy'], + ['login', 'http://localhost:3000/login', 'http://localhost:3000/zh/login'], + ])('redirects anonymous unprefixed %s routes to the locale cookie', async (_, url, location) => { + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest(url, { + headers: { + cookie: 'NEXT_LOCALE=zh', + 'user-agent': 'vitest', + }, + }) + ) + + expect(response.status).toBe(307) + expect(response.headers.get('location')).toBe(location) + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('zh') + }) + + it('localizes unprefixed protected routes using Accept-Language', async () => { + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest('http://localhost:3000/workspace/ws-1/dashboard', { + headers: { + 'accept-language': 'es-ES,es;q=0.9,en;q=0.8', + 'user-agent': 'vitest', + }, + }) + ) + + expect(response.status).toBe(307) + expect(response.headers.get('location')).toBe( + 'http://localhost:3000/es/workspace/ws-1/dashboard' + ) + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('es') }) - it('redirects hosted protected routes to login when no session is present', async () => { + it('localizes hosted protected routes before the app auth boundary handles access', async () => { const { proxy } = await import('./proxy') const response = await proxy( - new NextRequest('https://www.tradinggoose.ai/workspace/ws-1/dashboard') + new NextRequest('https://www.tradinggoose.ai/workspace/ws-1/dashboard', { + headers: { + 'user-agent': 'vitest', + }, + }) ) expect(response.status).toBe(307) expect(response.headers.get('location')).toBe( - 'https://www.tradinggoose.ai/login?callbackUrl=%2Fworkspace%2Fws-1%2Fdashboard' + 'https://www.tradinggoose.ai/en/workspace/ws-1/dashboard' ) expect(response.headers.get('x-middleware-rewrite')).toBeNull() }) - it('allows the login route through when reauth is explicitly requested', async () => { - mockGetSessionCookie.mockReturnValue('stale-cookie') + it('lets the default-locale reauth login route reach its page boundary', async () => { + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' const { proxy } = await import('./proxy') const response = await proxy( - new NextRequest('http://localhost:3000/login?reauth=1&callbackUrl=%2Fworkspace%2Fws-1') + new NextRequest('http://localhost:3000/en/login?reauth=1&callbackUrl=%2Fworkspace%2Fws-1', { + headers: { + 'user-agent': 'vitest', + }, + }) ) expect(response.status).toBe(200) expect(response.headers.get('location')).toBeNull() - expect(response.cookies.get('better-auth.session_token')?.maxAge).toBe(0) + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('en') }) - it('preserves locale on the login route while keeping callback targets canonical', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - + it('keeps localized protected routes at the app auth boundary with a canonical callback header', async () => { const { proxy } = await import('./proxy') const response = await proxy( - new NextRequest('http://localhost:3000/es/workspace/ws-1/dashboard?layoutId=layout-1') + new NextRequest('http://localhost:3000/es/workspace/ws-1/dashboard?layoutId=layout-1', { + headers: { + 'user-agent': 'vitest', + }, + }) ) - expect(response.status).toBe(307) - expect(response.headers.get('location')).toBe( - 'http://localhost:3000/es/login?callbackUrl=%2Fworkspace%2Fws-1%2Fdashboard%3FlayoutId%3Dlayout-1' + expect(response.status).toBe(200) + expect(response.headers.get('location')).toBeNull() + expect(response.headers.get('x-middleware-request-x-tradinggoose-callback-path')).toBe( + '/workspace/ws-1/dashboard?layoutId=layout-1' ) expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('es') }) - it('redirects authenticated localized auth routes to the localized workspace root', async () => { - mockGetSessionCookie.mockReturnValue('session-cookie') + it.each([ + '/es/login', + '/es/signup', + '/es/waitlist', + '/es/reset-password', + '/es/verify', + '/es/sso', + '/es/error', + ])('lets auth route %s reach its page boundary', async (pathname) => { + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' const { proxy } = await import('./proxy') - const response = await proxy(new NextRequest('http://localhost:3000/es/login')) + const response = await proxy( + new NextRequest(`http://localhost:3000${pathname}`, { + headers: { + 'user-agent': 'vitest', + }, + }) + ) - expect(response.status).toBe(307) - expect(response.headers.get('location')).toBe('http://localhost:3000/es/workspace') + expect(response.status).toBe(200) + expect(response.headers.get('location')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('es') }) - it('normalizes default-locale prefixed routes before rendering', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - + it('keeps default-locale prefixed routes canonical', async () => { const { proxy } = await import('./proxy') const response = await proxy( new NextRequest('http://localhost:3000/en/login', { @@ -119,13 +204,13 @@ describe('proxy auth routing', () => { }) ) - expect(response.status).toBe(307) - expect(response.headers.get('location')).toBe('http://localhost:3000/login') + expect(response.status).toBe(200) + expect(response.headers.get('location')).toBeNull() + expect(response.headers.get('x-middleware-rewrite')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('en') }) - it('lets next-intl handle localized landing routes without stripping the locale', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - + it('lets next-intl handle localized landing routes canonically', async () => { const { proxy } = await import('./proxy') const response = await proxy( new NextRequest('http://localhost:3000/es', { @@ -142,30 +227,74 @@ describe('proxy auth routing', () => { }) it.each([ - ['root', 'http://localhost:3000/?source=nav', 'http://localhost:3000/zh?source=nav'], - ['workspace', 'http://localhost:3000/workspace', 'http://localhost:3000/zh/workspace'], - ])( - 'redirects canonical %s requests to the locale remembered by NEXT_LOCALE', - async (_, url, location) => { - mockGetSessionCookie.mockReturnValue('session-cookie') + ['root', 'http://localhost:3000/?source=nav', 'http://localhost:3000/es?source=nav'], + ['workspace', 'http://localhost:3000/workspace', 'http://localhost:3000/es/workspace'], + ])('redirects locale-cookie %s requests to the request locale', async (_, url, location) => { + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' - const { proxy } = await import('./proxy') - const response = await proxy( - new NextRequest(url, { - headers: { - cookie: 'NEXT_LOCALE=zh', - 'user-agent': 'vitest', - }, - }) - ) + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest(url, { + headers: { + cookie: 'NEXT_LOCALE=es', + 'user-agent': 'vitest', + }, + }) + ) - expect(response.status).toBe(307) - expect(response.headers.get('location')).toBe(location) - } - ) + expect(response.status).toBe(307) + expect(response.headers.get('location')).toBe(location) + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('es') + }) + + it('keeps locale-cookie prefixed requests canonical to the URL locale', async () => { + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' + + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest('http://localhost:3000/en/workspace', { + headers: { + cookie: 'NEXT_LOCALE=zh', + 'user-agent': 'vitest', + }, + }) + ) + + expect(response.status).toBe(200) + expect(response.headers.get('location')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')?.value).toBe('en') + }) + + it('rewrites POST protected requests with the canonical callback header', async () => { + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' + + const { proxy } = await import('./proxy') + const response = await proxy( + new NextRequest('http://localhost:3000/workspace/ws-1/dashboard?layoutId=layout-1', { + method: 'POST', + headers: { + cookie: 'NEXT_LOCALE=es', + 'user-agent': 'vitest', + }, + }) + ) + + expect(response.status).toBe(200) + expect(response.headers.get('location')).toBeNull() + expect(response.headers.get('x-middleware-rewrite')).toBe( + 'http://localhost:3000/es/workspace/ws-1/dashboard?layoutId=layout-1' + ) + expect(response.headers.get('x-middleware-request-x-tradinggoose-callback-path')).toBe( + '/workspace/ws-1/dashboard?layoutId=layout-1' + ) + expect(response.headers.get('x-middleware-override-headers')?.split(',')).toContain( + 'x-tradinggoose-callback-path' + ) + expect(response.cookies.get('NEXT_LOCALE')).toBeUndefined() + }) it('does not rewrite localized API-shaped paths to canonical API routes', async () => { - mockGetSessionCookie.mockReturnValue('session-cookie') + process.env.NEXT_PUBLIC_APP_URL = 'http://localhost:3000' const { proxy } = await import('./proxy') const response = await proxy( @@ -186,8 +315,6 @@ describe('proxy auth routing', () => { }) it('exempts canonical webhook trigger API requests from suspicious user-agent filtering', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - const { proxy } = await import('./proxy') const response = await proxy( new NextRequest('http://localhost:3000/api/webhooks/trigger/webhook-1', { @@ -199,11 +326,10 @@ describe('proxy auth routing', () => { expect(response.status).toBe(200) expect(response.headers.get('x-middleware-rewrite')).toBeNull() + expect(response.cookies.get('NEXT_LOCALE')).toBeUndefined() }) it('does not exempt localized API-shaped webhook paths from suspicious user-agent filtering', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - const { proxy } = await import('./proxy') const response = await proxy( new NextRequest('http://localhost:3000/es/api/webhooks/trigger/webhook-1', { @@ -217,12 +343,10 @@ describe('proxy auth routing', () => { expect(response.headers.get('x-middleware-rewrite')).toBeNull() }) - it('rewrites localized markdown requests with the normalized content path', async () => { - mockGetSessionCookie.mockReturnValue(undefined) - + it('rewrites default-locale markdown requests with the normalized content path', async () => { const { proxy } = await import('./proxy') const response = await proxy( - new NextRequest('http://localhost:3000/es/terms', { + new NextRequest('http://localhost:3000/en/terms', { headers: { accept: 'text/markdown', 'user-agent': 'vitest', @@ -232,7 +356,7 @@ describe('proxy auth routing', () => { expect(response.status).toBe(200) expect(response.headers.get('x-middleware-rewrite')).toBe( - 'http://localhost:3000/api/markdown?path=%2Fterms&locale=es' + 'http://localhost:3000/api/markdown?path=%2Fterms&locale=en' ) }) }) diff --git a/apps/tradinggoose/proxy.ts b/apps/tradinggoose/proxy.ts index 87f7bc18c..671ecb383 100644 --- a/apps/tradinggoose/proxy.ts +++ b/apps/tradinggoose/proxy.ts @@ -1,4 +1,3 @@ -import { getSessionCookie } from 'better-auth/cookies' import { type NextRequest, NextResponse } from 'next/server' import createMiddleware from 'next-intl/middleware' import { appendHomepageDiscoveryLinks } from '@/lib/discovery/link-headers' @@ -14,6 +13,8 @@ import { CANONICAL_CALLBACK_PATH_HEADER, defaultLocale, isLocaleCode, + LOCALE_COOKIE, + LOCALE_COOKIE_MAX_AGE, type LocaleCode, localizeUrl, stripLocaleFromPathname, @@ -23,29 +24,6 @@ import { generateRuntimeCSP } from './lib/security/csp' const logger = createLogger('Proxy') const handleI18nRouting = createMiddleware(routing) -const LOCALE_COOKIE = 'NEXT_LOCALE' -const LOCALE_COOKIE_MAX_AGE = 60 * 60 * 24 * 365 - -const AUTH_ROUTES = new Set(['/login', '/signup']) -const AUTH_COOKIE_KEYS = [ - 'better-auth.session_token', - 'better-auth.session_data', - 'better-auth.dont_remember', - '__Secure-better-auth.session_token', - '__Secure-better-auth.session_data', - '__Secure-better-auth.dont_remember', -] - -function clearAuthCookies(response: NextResponse) { - AUTH_COOKIE_KEYS.forEach((name) => { - response.cookies.set({ - name, - value: '', - maxAge: 0, - path: '/', - }) - }) -} const SUSPICIOUS_UA_PATTERNS = [ /^\s*$/, @@ -61,32 +39,23 @@ interface LocaleRoute { hasLocalePrefix: boolean } -function resolveLocaleRoute(pathname: string): LocaleRoute { +type AcceptLanguageCandidate = { + locale: LocaleCode + quality: number + index: number +} + +function resolveLocaleRoute(pathname: string, localeOverride?: LocaleCode): LocaleRoute { const firstSegment = pathname.split('/').filter(Boolean)[0] const { locale, pathname: normalizedPathname } = stripLocaleFromPathname(pathname) + const hasLocalePrefix = Boolean(firstSegment && isLocaleCode(firstSegment)) return { - locale, + locale: hasLocalePrefix ? locale : (localeOverride ?? locale), pathname: normalizedPathname, - hasLocalePrefix: Boolean(firstSegment && isLocaleCode(firstSegment)), + hasLocalePrefix, } } -function buildNormalizedUrl(request: NextRequest, pathname: string) { - const normalizedUrl = new URL(pathname, request.url) - normalizedUrl.search = request.nextUrl.search - return normalizedUrl -} - -function resolveRequestLocale( - request: NextRequest, - route = resolveLocaleRoute(request.nextUrl.pathname) -) { - const preferredLocale = request.cookies.get(LOCALE_COOKIE)?.value - return route.hasLocalePrefix || !preferredLocale || !isLocaleCode(preferredLocale) - ? route.locale - : preferredLocale -} - function isCanonicalRouteHandlerPath(pathname: string) { return ( pathname === '/api' || @@ -105,15 +74,52 @@ function isCanonicalRouteHandlerPath(pathname: string) { ) } -function buildLoginRedirect(request: NextRequest, callback?: string) { - const locale = resolveRequestLocale(request) - const loginUrl = new URL(localizeUrl(request.nextUrl.origin, locale, '/login')) +function getLocaleCookie(request: NextRequest): LocaleCode | null { + const locale = request.cookies.get(LOCALE_COOKIE)?.value + return locale && isLocaleCode(locale) ? locale : null +} - if (callback) { - loginUrl.searchParams.set('callbackUrl', callback) +function getAcceptLanguageLocale(header: string | null): LocaleCode | null { + if (!header) { + return null } - return withLocaleCookie(NextResponse.redirect(loginUrl), locale) + const candidates: AcceptLanguageCandidate[] = [] + + header.split(',').forEach((entry, index) => { + const [rawLanguageRange, ...rawParams] = entry + .split(';') + .map((part) => part.trim()) + .filter(Boolean) + + if (!rawLanguageRange || rawLanguageRange === '*') { + return + } + + const locale = rawLanguageRange.toLowerCase().split('-', 1)[0] + if (!isLocaleCode(locale)) { + return + } + + const qualityParam = rawParams.find((param) => param.toLowerCase().startsWith('q=')) + const quality = qualityParam ? Number.parseFloat(qualityParam.slice(2)) : 1 + if (!Number.isFinite(quality) || quality <= 0) { + return + } + + candidates.push({ locale, quality, index }) + }) + + candidates.sort((a, b) => b.quality - a.quality || a.index - b.index) + return candidates[0]?.locale ?? null +} + +function resolveRequestLocale(request: NextRequest): LocaleCode { + return ( + getLocaleCookie(request) ?? + getAcceptLanguageLocale(request.headers.get('accept-language')) ?? + defaultLocale + ) } function isProtectedAppPath(pathname: string): boolean { @@ -127,14 +133,10 @@ function isProtectedAppPath(pathname: string): boolean { ) } -function isAuthRoute(pathname: string): boolean { - const { pathname: normalizedPathname } = resolveLocaleRoute(pathname) - return AUTH_ROUTES.has(normalizedPathname) -} - -function getCanonicalCallbackPath(pathname: string, search: string) { - const { pathname: normalizedPathname } = resolveLocaleRoute(pathname) - return `${normalizedPathname}${search}` +function buildProtectedRequestHeaders(request: NextRequest, route: LocaleRoute) { + const requestHeaders = new Headers(request.headers) + requestHeaders.set(CANONICAL_CALLBACK_PATH_HEADER, `${route.pathname}${request.nextUrl.search}`) + return requestHeaders } function isMarkdownRequestPath(pathname: string) { @@ -166,10 +168,6 @@ function rewriteMarkdownRequest(request: NextRequest): NextResponse | null { const route = resolveLocaleRoute(request.nextUrl.pathname) const { locale, pathname: normalizedPathname } = route - if (route.hasLocalePrefix && locale === defaultLocale) { - return NextResponse.redirect(buildNormalizedUrl(request, normalizedPathname)) - } - const rewriteUrl = new URL(MARKDOWN_RENDER_ROUTE, request.url) rewriteUrl.searchParams.set('path', normalizedPathname) rewriteUrl.searchParams.set('locale', locale) @@ -193,23 +191,42 @@ function withLocaleCookie(response: NextResponse, locale: LocaleCode) { return response } -function redirectToCookieLocale(request: NextRequest, route: LocaleRoute): NextResponse | null { - const preferredLocale = request.cookies.get(LOCALE_COOKIE)?.value +function resolveCanonicalLocaleRoute(request: NextRequest, route: LocaleRoute): LocaleRoute { + if (isCanonicalRouteHandlerPath(request.nextUrl.pathname)) { + return route + } + + if (route.hasLocalePrefix) { + return route + } - if ( - route.hasLocalePrefix || - isCanonicalRouteHandlerPath(request.nextUrl.pathname) || - (request.method !== 'GET' && request.method !== 'HEAD') || - !preferredLocale || - !isLocaleCode(preferredLocale) || - preferredLocale === defaultLocale - ) { + return { ...route, locale: resolveRequestLocale(request) } +} + +function routeToCanonicalLocale( + request: NextRequest, + route: LocaleRoute, + requestHeaders?: Headers +): NextResponse | null { + if (isCanonicalRouteHandlerPath(request.nextUrl.pathname)) { return null } - const redirectUrl = new URL(localizeUrl(request.nextUrl.origin, preferredLocale, route.pathname)) - redirectUrl.search = request.nextUrl.search - return NextResponse.redirect(redirectUrl) + const requestRoute = resolveLocaleRoute(request.nextUrl.pathname) + if (route.hasLocalePrefix && requestRoute.locale === route.locale) { + return null + } + + const targetUrl = new URL(localizeUrl(request.nextUrl.origin, route.locale, route.pathname)) + targetUrl.search = request.nextUrl.search + + if (request.method === 'GET' || request.method === 'HEAD') { + return withLocaleCookie(NextResponse.redirect(targetUrl), route.locale) + } + + return requestHeaders + ? NextResponse.rewrite(targetUrl, { request: { headers: requestHeaders } }) + : NextResponse.rewrite(targetUrl) } function handleSecurityFiltering(request: NextRequest): NextResponse | null { @@ -246,39 +263,21 @@ function handleSecurityFiltering(request: NextRequest): NextResponse | null { export async function proxy(request: NextRequest) { const url = request.nextUrl - const route = resolveLocaleRoute(url.pathname) + const initialRoute = resolveLocaleRoute(url.pathname) + const route = resolveCanonicalLocaleRoute(request, initialRoute) const { locale, pathname: normalizedPathname } = route - const hasActiveSession = Boolean(getSessionCookie(request)) const isProtectedPath = isProtectedAppPath(url.pathname) - const reauth = url.searchParams.get('reauth') === '1' - if (isProtectedPath && !hasActiveSession) { - const callbackTarget = getCanonicalCallbackPath(url.pathname, url.search) - return buildLoginRedirect(request, callbackTarget) - } - - if (isAuthRoute(url.pathname)) { - if (reauth) { - const response = handleI18nRouting(request) - clearAuthCookies(response) - return route.hasLocalePrefix ? withLocaleCookie(response, locale) : response - } - - if (hasActiveSession) { - const requestLocale = resolveRequestLocale(request, route) - return withLocaleCookie( - NextResponse.redirect(new URL(localizeUrl(url.origin, requestLocale, '/workspace'))), - requestLocale - ) - } - } + const protectedRequestHeaders = isProtectedPath + ? buildProtectedRequestHeaders(request, route) + : undefined const securityBlock = handleSecurityFiltering(request) if (securityBlock) return securityBlock - const localeRedirect = redirectToCookieLocale(request, route) - if (localeRedirect) return localeRedirect + const localeResponse = routeToCanonicalLocale(request, route, protectedRequestHeaders) + if (localeResponse) return localeResponse const markdownRewrite = rewriteMarkdownRequest(request) if (markdownRewrite) return markdownRewrite @@ -291,15 +290,12 @@ export async function proxy(request: NextRequest) { return response } - if (isProtectedPath) { - const requestHeaders = new Headers(request.headers) - requestHeaders.set( - CANONICAL_CALLBACK_PATH_HEADER, - getCanonicalCallbackPath(url.pathname, url.search) + if (protectedRequestHeaders) { + NextResponse.next({ request: { headers: protectedRequestHeaders } }).headers.forEach( + (value, key) => { + response.headers.set(key, value) + } ) - NextResponse.next({ request: { headers: requestHeaders } }).headers.forEach((value, key) => { - response.headers.set(key, value) - }) } response.headers.set('Vary', appendVaryHeader(appendVaryHeader(null, 'User-Agent'), 'Accept')) @@ -316,7 +312,7 @@ export async function proxy(request: NextRequest) { appendHomepageDiscoveryLinks(response.headers, locale) } - return route.hasLocalePrefix ? withLocaleCookie(response, locale) : response + return isCanonicalRouteHandlerPath(url.pathname) ? response : withLocaleCookie(response, locale) } export const config = { diff --git a/apps/tradinggoose/services/queue/index.ts b/apps/tradinggoose/services/queue/index.ts index 13de63515..45be40759 100644 --- a/apps/tradinggoose/services/queue/index.ts +++ b/apps/tradinggoose/services/queue/index.ts @@ -1,3 +1,3 @@ export { ExecutionLimiter } from '@/services/queue/ExecutionLimiter' -export type { TriggerType } from '@/services/queue/types' +export type { QueuedWorkflowTriggerType, TriggerType } from '@/services/queue/types' export { RateLimitError } from '@/services/queue/types' diff --git a/apps/tradinggoose/services/queue/types.ts b/apps/tradinggoose/services/queue/types.ts index d73167d6b..6561d13c9 100644 --- a/apps/tradinggoose/services/queue/types.ts +++ b/apps/tradinggoose/services/queue/types.ts @@ -1,5 +1,6 @@ // Trigger types for rate limiting export type TriggerType = 'api' | 'webhook' | 'schedule' | 'manual' | 'chat' | 'api-endpoint' +export type QueuedWorkflowTriggerType = Exclude // Rate limit counter types - which counter to increment in the database export type RateLimitCounterType = 'sync' | 'async' | 'api-endpoint' diff --git a/apps/tradinggoose/stores/copilot/store.test.ts b/apps/tradinggoose/stores/copilot/store.test.ts index c3f2d42f2..fc859bec6 100644 --- a/apps/tradinggoose/stores/copilot/store.test.ts +++ b/apps/tradinggoose/stores/copilot/store.test.ts @@ -859,7 +859,7 @@ describe('copilot streaming regressions', () => { expect(store.getState().isAwaitingContinuation).toBe(false) }) - it('treats awaiting_tools as a pause and skips terminal billing fetch', async () => { + it('treats awaiting_tools as a pause and skips terminal context usage refresh', async () => { const channelId = 'copilot-awaiting-tools-pause' const store = getCopilotStore(channelId) const fetchMock = vi.fn(async (input: RequestInfo | URL) => { @@ -1345,8 +1345,7 @@ describe('copilot streaming regressions', () => { name: 'edit_workflow', arguments: { entityId: 'wf-limited-edit', - entityDocument: 'workflow: {}', - documentFormat: 'tg-mermaid-v1', + entityDocument: 'flowchart TD', }, }, }, @@ -2807,7 +2806,7 @@ describe('copilot context usage', () => { store.setState({ currentChat: { reviewSessionId: 'review-context-usage-generic', - workspaceId: null, + workspaceId: 'workspace-context-usage', entityKind: 'copilot', entityId: null, draftSessionId: null, @@ -2839,6 +2838,7 @@ describe('copilot context usage', () => { conversationId: 'conversation-context-usage-generic', model: 'claude-sonnet-4.6', provider: 'anthropic', + workspaceId: 'workspace-context-usage', }) expect(store.getState().contextUsage).toEqual({ usage: 1234, diff --git a/apps/tradinggoose/stores/copilot/store.ts b/apps/tradinggoose/stores/copilot/store.ts index d97835d98..ae8c0d4e0 100644 --- a/apps/tradinggoose/stores/copilot/store.ts +++ b/apps/tradinggoose/stores/copilot/store.ts @@ -4,7 +4,7 @@ import { createContext, createElement, type ReactNode, useContext, useMemo } fro import type { StoreApi } from 'zustand' import { devtools } from 'zustand/middleware' import { createWithEqualityFn as create, useStoreWithEqualityFn } from 'zustand/traditional' -import { shouldAutoExecuteTool } from '@/lib/copilot/access-policy' +import { shouldRequireToolApproval } from '@/lib/copilot/access-policy' import { type CopilotChat, sendStreamingMessage } from '@/lib/copilot/api' import { mergeCopilotContexts } from '@/lib/copilot/chat-contexts' import { DEFAULT_COPILOT_RUNTIME_MODEL } from '@/lib/copilot/runtime-models' @@ -70,6 +70,7 @@ import { ensureClientToolInstance, handleCopilotServerToolSuccess, isCopilotTool, + isGatedTool, isServerManagedCopilotTool, prepareCopilotToolArgs, resolveToolDisplay, @@ -275,10 +276,6 @@ function autoExecutePendingToolsForAccessLevel( accessLevel: CopilotStore['accessLevel'], get: () => CopilotStore ) { - if (!shouldAutoExecuteTool(accessLevel)) { - return - } - const { toolCallsById } = get() const copilotToolIds: string[] = [] @@ -287,7 +284,10 @@ function autoExecutePendingToolsForAccessLevel( continue } - if (isCopilotTool(toolCall.name)) { + if ( + isCopilotTool(toolCall.name) && + !shouldRequireToolApproval(accessLevel, isGatedTool(toolCall.name)) + ) { copilotToolIds.push(id) } } @@ -1172,13 +1172,7 @@ const createCopilotStoreInstance = (storeChannelId = DEFAULT_COPILOT_CHANNEL_ID) // Fetch context usage after response completes if (!context.awaitingTools) { logger.info('[Context Usage] Stream completed, fetching usage') - const billingOptions = assistantMessageId - ? { - bill: true, - assistantMessageId, - } - : undefined - await get().fetchContextUsage(billingOptions) + await get().fetchContextUsage() } } finally { abortSignal?.removeEventListener('abort', cancelReader) @@ -1270,9 +1264,8 @@ const createCopilotStoreInstance = (storeChannelId = DEFAULT_COPILOT_CHANNEL_ID) setAgentPrefetch: (prefetch) => set({ agentPrefetch: prefetch }), // Fetch context usage from copilot API - fetchContextUsage: async (options?: { bill?: boolean; assistantMessageId?: string }) => { + fetchContextUsage: async () => { try { - const { bill = false, assistantMessageId } = options ?? {} const { currentChat, selectedModel } = get() const selectedProvider = resolveCopilotRuntimeProvider(selectedModel) logger.info('[Context Usage] Starting fetch', { @@ -1280,8 +1273,6 @@ const createCopilotStoreInstance = (storeChannelId = DEFAULT_COPILOT_CHANNEL_ID) conversationId: currentChat?.conversationId, model: selectedModel, provider: selectedProvider, - bill, - assistantMessageId, }) if (!currentChat) { @@ -1303,15 +1294,8 @@ const createCopilotStoreInstance = (storeChannelId = DEFAULT_COPILOT_CHANNEL_ID) conversationId: currentChat.conversationId, model: selectedModel, provider: selectedProvider, + ...(currentChat.workspaceId ? { workspaceId: currentChat.workspaceId } : {}), } - // Generic Copilot context usage is conversation/user scoped. Workflow contexts are - // prompt context for the chat, not billing scope selectors for this widget. - if (bill && assistantMessageId) { - requestPayload.bill = true - requestPayload.assistantMessageId = assistantMessageId - requestPayload.billingModel = selectedModel - } - logger.info('[Context Usage] Calling API', requestPayload) // Call the backend API route which proxies to copilot @@ -1510,7 +1494,7 @@ const createCopilotStoreInstance = (storeChannelId = DEFAULT_COPILOT_CHANNEL_ID) syncClientToolInstanceState(id, instance) if ( stateBeforeUserAction !== ClientToolCallState.review && - shouldAutoExecuteTool(get().accessLevel) && + !shouldRequireToolApproval(get().accessLevel, true) && get().toolCallsById[id]?.state === ClientToolCallState.review && typeof instance.handleUserAction === 'function' ) { diff --git a/apps/tradinggoose/stores/copilot/types.ts b/apps/tradinggoose/stores/copilot/types.ts index c0306d6ca..d7f07387f 100644 --- a/apps/tradinggoose/stores/copilot/types.ts +++ b/apps/tradinggoose/stores/copilot/types.ts @@ -162,7 +162,7 @@ export interface CopilotActions { setAccessLevel: (accessLevel: CopilotAccessLevel) => void setSelectedModel: (model: CopilotStore['selectedModel']) => Promise setAgentPrefetch: (prefetch: boolean) => void - fetchContextUsage: (options?: { bill?: boolean; assistantMessageId?: string }) => Promise + fetchContextUsage: () => Promise loadChats: (options?: { workspaceId?: string | null }) => Promise selectChat: (chat: CopilotChat) => Promise diff --git a/apps/tradinggoose/stores/index.ts b/apps/tradinggoose/stores/index.ts index 1708f0438..124a01d7d 100644 --- a/apps/tradinggoose/stores/index.ts +++ b/apps/tradinggoose/stores/index.ts @@ -1,8 +1,7 @@ 'use client' -import { useEffect } from 'react' import { createLogger } from '@/lib/logs/console/logger' -import { stripLocaleFromPathname } from '@/i18n/utils' +import { resetWorkspacePermissionsStore } from '@/hooks/use-workspace-permissions' import { useConsoleStore } from '@/stores/console/store' import { getCopilotStore, useCopilotStore } from '@/stores/copilot/store' import { useCustomToolsStore } from '@/stores/custom-tools/store' @@ -14,113 +13,6 @@ import { useSubscriptionStore } from '@/stores/subscription/store' import { useWorkflowRegistry } from '@/stores/workflows/registry/store' const logger = createLogger('Stores') -const BEFORE_UNLOAD_AUTH_PATHS = new Set(['/login', '/signup', '/reset-password', '/verify']) - -// Track initialization state -let isInitializing = false -let appFullyInitialized = false -let dataInitialized = false // Flag for actual data loading completion - -const AUTH_COOKIE_KEYS = [ - 'better-auth.session_token', - 'better-auth.session_data', - 'better-auth.dont_remember', - '__Secure-better-auth.session_token', - '__Secure-better-auth.session_data', - '__Secure-better-auth.dont_remember', -] - -function hasAuthCookie(): boolean { - if (typeof document === 'undefined') return false - return AUTH_COOKIE_KEYS.some((key) => document.cookie.includes(`${key}=`)) -} - -/** - * Initialize the application state and sync system - * localStorage persistence has been removed - relies on DB and Zustand stores only - */ -async function initializeApplication(): Promise { - if (typeof window === 'undefined' || isInitializing) return - - // Skip initialization entirely when no auth cookie is present to avoid - // unauthenticated 401 loops while the app is redirecting to /login. - if (!hasAuthCookie()) { - logger.info('Auth cookie missing, skipping app initialization') - appFullyInitialized = false - dataInitialized = false - return - } - - isInitializing = true - appFullyInitialized = false - - // Track initialization start time - const initStartTime = Date.now() - - try { - // Load environment variables directly from DB - await useEnvironmentStore.getState().loadEnvironmentVariables() - - // Mark data as initialized only after sync managers have loaded data from DB - dataInitialized = true - - // Log initialization timing information - const initDuration = Date.now() - initStartTime - logger.info(`Application initialization completed in ${initDuration}ms`) - - // Mark application as fully initialized - appFullyInitialized = true - } catch (error) { - logger.error('Error during application initialization:', { error }) - // Still mark as initialized to prevent being stuck in initializing state - appFullyInitialized = true - // But don't mark data as initialized on error - dataInitialized = false - } finally { - isInitializing = false - } -} - -/** - * Checks if application is fully initialized - */ -export function isAppInitialized(): boolean { - return appFullyInitialized -} - -/** - * Checks if data has been loaded from the database - * This should be checked before any sync operations - */ -export function isDataInitialized(): boolean { - return dataInitialized -} - -/** - * Handle application cleanup before unload - */ -function handleBeforeUnload(event: BeforeUnloadEvent): void { - // Check if we're on an authentication page and skip confirmation if we are - if (typeof window !== 'undefined') { - const path = stripLocaleFromPathname(window.location.pathname).pathname - // Skip confirmation for auth-related pages - if (BEFORE_UNLOAD_AUTH_PATHS.has(path)) { - return - } - } - - // Standard beforeunload pattern - event.preventDefault() - event.returnValue = '' -} - -/** - * Clean up sync system - */ -function cleanupApplication(): void { - window.removeEventListener('beforeunload', handleBeforeUnload) - // Note: No sync managers to dispose - Socket.IO handles cleanup -} /** * Clear all user data when signing out @@ -140,74 +32,12 @@ export async function clearUserData(): Promise { const keysToRemove = Object.keys(localStorage).filter((key) => !keysToKeep.includes(key)) keysToRemove.forEach((key) => localStorage.removeItem(key)) - // Reset application initialization state - appFullyInitialized = false - dataInitialized = false - logger.info('User data cleared successfully') } catch (error) { logger.error('Error clearing user data:', { error }) } } -/** - * Hook to manage application lifecycle - */ -export function useAppInitialization() { - useEffect(() => { - // Use Promise to handle async initialization - initializeApplication() - - return () => { - cleanupApplication() - } - }, []) -} - -/** - * Hook to reinitialize the application after successful login - * Use this in the login success handler or post-login page - */ -export function useLoginInitialization() { - useEffect(() => { - reinitializeAfterLogin() - }, []) -} - -/** - * Reinitialize the application after login - * This ensures we load fresh data from the database for the new user - */ -export async function reinitializeAfterLogin(): Promise { - if (typeof window === 'undefined') return - - try { - // Reset application initialization state - appFullyInitialized = false - dataInitialized = false - - // Note: No sync managers to dispose - Socket.IO handles cleanup - - // Clean existing state to avoid stale data - resetAllStores() - - // Reset initialization flags to force a fresh load - isInitializing = false - - // Reinitialize the application - await initializeApplication() - - logger.info('Application reinitialized after login') - } catch (error) { - logger.error('Error reinitializing application:', { error }) - } -} - -// Initialize immediately when imported on client -if (typeof window !== 'undefined') { - initializeApplication() -} - // Export all stores export { useWorkflowRegistry, @@ -243,6 +73,7 @@ export const resetAllStores = () => { useCustomToolsStore.getState().resetAll() useSkillsStore.getState().resetAll() useIndicatorsStore.getState().resetAll() + resetWorkspacePermissionsStore() // Variables store has no tracking to reset; registry hydrates useSubscriptionStore.getState().reset() // Reset subscription store } diff --git a/apps/tradinggoose/stores/organization/store.ts b/apps/tradinggoose/stores/organization/store.ts index 2edfa6245..edd2d234b 100644 --- a/apps/tradinggoose/stores/organization/store.ts +++ b/apps/tradinggoose/stores/organization/store.ts @@ -1,8 +1,8 @@ -import { createWithEqualityFn as create } from 'zustand/traditional' import { devtools } from 'zustand/middleware' +import { createWithEqualityFn as create } from 'zustand/traditional' import { client } from '@/lib/auth-client' import { createLogger } from '@/lib/logs/console/logger' -import { useGeneralStore } from '@/stores/settings/general/store' +import { stripLocaleFromPathname } from '@/i18n/utils' import type { OrganizationStore, WorkspaceInvitation } from '@/stores/organization/types' import { calculateSeatUsage, @@ -15,6 +15,12 @@ const logger = createLogger('OrganizationStore') const CACHE_DURATION = 30 * 1000 +function getCurrentRouteLocale() { + return typeof window === 'undefined' + ? 'en' + : stripLocaleFromPathname(window.location.pathname).locale +} + export const useOrganizationStore = create()( devtools( (set, get) => ({ @@ -291,18 +297,8 @@ export const useOrganizationStore = create()( if (permissionResponse.ok) { const permissionData = await permissionResponse.json() - // Check if current user has admin permission - // Use userId if provided, otherwise fall back to checking isOwner from workspace data - let hasAdminAccess = false - - if (userId && permissionData.users) { - const currentUserPermission = permissionData.users.find( - (user: any) => user.id === userId || user.userId === userId - ) - hasAdminAccess = currentUserPermission?.permissionType === 'admin' - } + const hasAdminAccess = permissionData.currentUserPermission === 'admin' - // Also check if user is the workspace owner const isOwner = workspace.isOwner || workspace.ownerId === userId if (hasAdminAccess || isOwner) { @@ -519,7 +515,7 @@ export const useOrganizationStore = create()( workspaceInvitations, }) - const locale = useGeneralStore.getState().preferredLocale + const locale = getCurrentRouteLocale() const inviteUrl = workspaceInvitations && workspaceInvitations.length > 0 ? `/api/organizations/${activeOrganization.id}/invitations?batch=true` diff --git a/apps/tradinggoose/stores/settings/environment/store.ts b/apps/tradinggoose/stores/settings/environment/store.ts index aebe01f50..3033553f3 100644 --- a/apps/tradinggoose/stores/settings/environment/store.ts +++ b/apps/tradinggoose/stores/settings/environment/store.ts @@ -1,6 +1,4 @@ import { createWithEqualityFn as create } from 'zustand/traditional' -import { handleAuthError } from '@/lib/auth/auth-error-handler' -import { fetchPersonalEnvironment, fetchWorkspaceEnvironment } from '@/lib/environment/api' import { createLogger } from '@/lib/logs/console/logger' import { API_ENDPOINTS } from '@/stores/constants' import type { EnvironmentStore, EnvironmentVariable } from '@/stores/settings/environment/types' @@ -16,10 +14,15 @@ export const useEnvironmentStore = create()((set, get) => ({ try { set({ isLoading: true, error: null }) - const data = await fetchPersonalEnvironment() + const response = await fetch(API_ENDPOINTS.ENVIRONMENT, { cache: 'no-store' }) + if (!response.ok) { + throw new Error(`Failed to load environment variables: ${response.statusText}`) + } + + const { data } = await response.json() set({ - variables: data, + variables: data && typeof data === 'object' ? data : {}, isLoading: false, }) } catch (error) { @@ -35,115 +38,6 @@ export const useEnvironmentStore = create()((set, get) => ({ set({ variables }) }, - saveEnvironmentVariables: async (variables: Record) => { - try { - set({ isLoading: true, error: null }) - - const transformedVariables = Object.fromEntries( - Object.entries(variables).map(([key, value]) => [key, { key, value }]) - ) - - set({ variables: transformedVariables }) - - const response = await fetch(API_ENDPOINTS.ENVIRONMENT, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ variables }), - }) - - if (!response.ok) { - if (response.status === 401) { - await handleAuthError('environment-store:save') - } - throw new Error(`Failed to save environment variables: ${response.statusText}`) - } - - set({ isLoading: false }) - } catch (error) { - logger.error('Error saving environment variables:', { error }) - set({ - error: error instanceof Error ? error.message : 'Unknown error', - isLoading: false, - }) - - get().loadEnvironmentVariables() - } - }, - - loadWorkspaceEnvironment: async (workspaceId: string) => { - try { - set({ isLoading: true, error: null }) - - const data = await fetchWorkspaceEnvironment(workspaceId) - set({ isLoading: false }) - return data as { - workspace: Record - personal: Record - conflicts: string[] - workspaceRows?: Array<{ - key: string - value: string - createdAt?: string | null - updatedAt?: string | null - }> - personalRows?: Array<{ - key: string - value: string - createdAt?: string | null - updatedAt?: string | null - }> - } - } catch (error) { - logger.error('Error loading workspace environment:', { error }) - set({ error: error instanceof Error ? error.message : 'Unknown error', isLoading: false }) - return { workspace: {}, personal: {}, conflicts: [] } - } - }, - - upsertWorkspaceEnvironment: async (workspaceId: string, variables: Record) => { - try { - set({ isLoading: true, error: null }) - const response = await fetch(API_ENDPOINTS.WORKSPACE_ENVIRONMENT(workspaceId), { - method: 'PUT', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ variables }), - }) - if (!response.ok) { - if (response.status === 401) { - await handleAuthError('environment-store:upsert-workspace') - } - throw new Error(`Failed to update workspace environment: ${response.statusText}`) - } - set({ isLoading: false }) - } catch (error) { - logger.error('Error updating workspace environment:', { error }) - set({ error: error instanceof Error ? error.message : 'Unknown error', isLoading: false }) - } - }, - - removeWorkspaceEnvironmentKeys: async (workspaceId: string, keys: string[]) => { - try { - set({ isLoading: true, error: null }) - const response = await fetch(API_ENDPOINTS.WORKSPACE_ENVIRONMENT(workspaceId), { - method: 'DELETE', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ keys }), - }) - if (!response.ok) { - if (response.status === 401) { - await handleAuthError('environment-store:remove-keys') - } - throw new Error(`Failed to remove workspace environment keys: ${response.statusText}`) - } - set({ isLoading: false }) - } catch (error) { - logger.error('Error removing workspace environment keys:', { error }) - set({ error: error instanceof Error ? error.message : 'Unknown error', isLoading: false }) - } - }, - getAllVariables: (): Record => { return get().variables }, diff --git a/apps/tradinggoose/stores/settings/environment/types.ts b/apps/tradinggoose/stores/settings/environment/types.ts index e4aa4a56b..9b60b786b 100644 --- a/apps/tradinggoose/stores/settings/environment/types.ts +++ b/apps/tradinggoose/stores/settings/environment/types.ts @@ -12,30 +12,5 @@ export interface EnvironmentState { export interface EnvironmentStore extends EnvironmentState { loadEnvironmentVariables: () => Promise setVariables: (variables: Record) => void - saveEnvironmentVariables: (variables: Record) => Promise - - loadWorkspaceEnvironment: (workspaceId: string) => Promise<{ - workspace: Record - personal: Record - conflicts: string[] - workspaceRows?: Array<{ - key: string - value: string - createdAt?: string | null - updatedAt?: string | null - }> - personalRows?: Array<{ - key: string - value: string - createdAt?: string | null - updatedAt?: string | null - }> - }> - upsertWorkspaceEnvironment: ( - workspaceId: string, - variables: Record - ) => Promise - removeWorkspaceEnvironmentKeys: (workspaceId: string, keys: string[]) => Promise - getAllVariables: () => Record } diff --git a/apps/tradinggoose/stores/settings/general/store.ts b/apps/tradinggoose/stores/settings/general/store.ts index 47504a6a3..d49f413c8 100644 --- a/apps/tradinggoose/stores/settings/general/store.ts +++ b/apps/tradinggoose/stores/settings/general/store.ts @@ -12,7 +12,6 @@ export const useGeneralStore = create()( (set, get) => { const store: General = { theme: 'system', - preferredLocale: 'en', telemetryEnabled: true, isLoading: false, error: null, @@ -44,6 +43,10 @@ export const useGeneralStore = create()( return { ...store, setSettings: (settings) => { + if (settings.theme) { + syncThemeToNextThemes(settings.theme) + } + set((state) => ({ ...state, ...settings, @@ -93,7 +96,11 @@ export const useGeneralStore = create()( throw new Error(`Failed to update setting: ${key}`) } - set({ [key]: value, error: null } as Partial) + if (key === 'preferredLocale') { + set({ error: null }) + } else { + set({ [key]: value, error: null } as Partial) + } } catch (error) { logger.error(`Error updating setting ${key}:`, error) set({ error: error instanceof Error ? error.message : 'Unknown error' }) diff --git a/apps/tradinggoose/stores/settings/general/types.ts b/apps/tradinggoose/stores/settings/general/types.ts index 0e598cfcd..30a83b966 100644 --- a/apps/tradinggoose/stores/settings/general/types.ts +++ b/apps/tradinggoose/stores/settings/general/types.ts @@ -1,6 +1,5 @@ export interface General { theme: 'system' | 'light' | 'dark' - preferredLocale: 'en' | 'es' | 'zh' telemetryEnabled: boolean isLoading: boolean error: string | null diff --git a/apps/tradinggoose/stores/workflows/json/importer.test.ts b/apps/tradinggoose/stores/workflows/json/importer.test.ts index e0e689f2f..d99e067fd 100644 --- a/apps/tradinggoose/stores/workflows/json/importer.test.ts +++ b/apps/tradinggoose/stores/workflows/json/importer.test.ts @@ -32,7 +32,6 @@ describe('workflow json importer', () => { { name: ' Primary Workflow ', description: ' Workflow used for trading ', - color: ' #3972F6 ', state: createWorkflowState(), }, ], @@ -48,7 +47,6 @@ describe('workflow json importer', () => { expect(data).toMatchObject({ name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', state: { blocks: { block_1: { @@ -77,7 +75,6 @@ describe('workflow json importer', () => { { name: ' Primary Workflow ', description: ' Workflow used for trading ', - color: ' #3972F6 ', state: createWorkflowState(), }, ], @@ -93,7 +90,6 @@ describe('workflow json importer', () => { expect(data).toMatchObject({ name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', skills: [ { name: 'Market Research', @@ -122,7 +118,6 @@ describe('workflow json importer', () => { { name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', state: createWorkflowState(), }, ], @@ -160,7 +155,6 @@ describe('workflow json importer', () => { { name: 'Primary Workflow', description: 'Workflow used for trading', - color: '#3972F6', state: createWorkflowState(), }, ], diff --git a/apps/tradinggoose/stores/workflows/json/importer.ts b/apps/tradinggoose/stores/workflows/json/importer.ts index 5486f8a8a..5f89c58c3 100644 --- a/apps/tradinggoose/stores/workflows/json/importer.ts +++ b/apps/tradinggoose/stores/workflows/json/importer.ts @@ -156,7 +156,6 @@ export function parseWorkflowJson( logger.info('Successfully parsed workflow JSON', { name: workflowData.name, description: workflowData.description, - color: workflowData.color, blocksCount: Object.keys(workflowData.state.blocks).length, edgesCount: workflowData.state.edges.length, loopsCount: Object.keys(workflowData.state.loops).length, diff --git a/apps/tradinggoose/stores/workflows/json/store.ts b/apps/tradinggoose/stores/workflows/json/store.ts index f75f85c43..5020fc270 100644 --- a/apps/tradinggoose/stores/workflows/json/store.ts +++ b/apps/tradinggoose/stores/workflows/json/store.ts @@ -86,7 +86,6 @@ export const useWorkflowJsonStore = create()( workflow: { name: currentWorkflow.name, description: currentWorkflow.description ?? '', - color: currentWorkflow.color ?? '', state: workflowSnapshot, }, skills: workspaceSkills, diff --git a/apps/tradinggoose/stores/workflows/registry/store.ts b/apps/tradinggoose/stores/workflows/registry/store.ts index 02611846d..94166bcf0 100644 --- a/apps/tradinggoose/stores/workflows/registry/store.ts +++ b/apps/tradinggoose/stores/workflows/registry/store.ts @@ -1,5 +1,5 @@ -import { createWithEqualityFn as create } from 'zustand/traditional' import { devtools } from 'zustand/middleware' +import { createWithEqualityFn as create } from 'zustand/traditional' import { getStableVibrantColor } from '@/lib/colors' import { createLogger } from '@/lib/logs/console/logger' import { generateCreativeWorkflowName } from '@/lib/naming' @@ -799,7 +799,7 @@ export const useWorkflowRegistry = create()( } logger.warn( - `Workflow ${workflowId} has no state in DB - this should not happen with server-side start block creation` + `Workflow ${workflowId} has no state in DB - this should not happen with server-side trigger block creation` ) } @@ -887,12 +887,6 @@ export const useWorkflowRegistry = create()( workspaceId, folderId: options.folderId || null, } - if (typeof options.color === 'string') { - requestBody.color = options.color - } - if (options.marketplaceId) { - requestBody.color = '#808080' - } const response = await fetch('/api/workflows', { method: 'POST', @@ -1069,9 +1063,7 @@ export const useWorkflowRegistry = create()( }, error: null, })) - logger.info( - `Duplicated workflow ${sourceId} to ${id} in workspace ${workspaceId}` - ) + logger.info(`Duplicated workflow ${sourceId} to ${id} in workspace ${workspaceId}`) return id }, @@ -1181,7 +1173,10 @@ export const useWorkflowRegistry = create()( }, // Update workflow metadata - updateWorkflow: async (id: string, metadata: Partial) => { + updateWorkflow: async ( + id: string, + metadata: Partial> + ) => { const { workflows } = get() const workflow = workflows[id] if (!workflow) { diff --git a/apps/tradinggoose/stores/workflows/registry/types.ts b/apps/tradinggoose/stores/workflows/registry/types.ts index 95670f699..d1214e0e8 100644 --- a/apps/tradinggoose/stores/workflows/registry/types.ts +++ b/apps/tradinggoose/stores/workflows/registry/types.ts @@ -63,14 +63,16 @@ export interface WorkflowRegistryActions { id: string, options?: { skipApi?: boolean; templateAction?: 'keep' | 'delete' } ) => Promise - updateWorkflow: (id: string, metadata: Partial) => Promise + updateWorkflow: ( + id: string, + metadata: Partial> + ) => Promise createWorkflow: (options?: { isInitial?: boolean marketplaceId?: string marketplaceState?: any name?: string description?: string - color?: string workspaceId?: string folderId?: string | null }) => Promise diff --git a/apps/tradinggoose/stores/workflows/workflow/store.ts b/apps/tradinggoose/stores/workflows/workflow/store.ts index d75319fc3..b7489586c 100644 --- a/apps/tradinggoose/stores/workflows/workflow/store.ts +++ b/apps/tradinggoose/stores/workflows/workflow/store.ts @@ -4,6 +4,7 @@ import { devtools } from 'zustand/middleware' import { createStore, type StoreApi } from 'zustand/vanilla' import { createLogger } from '@/lib/logs/console/logger' import { resolveBlockRuntimeState } from '@/lib/workflows/block-outputs' +import { buildInitialSubBlockStates } from '@/lib/workflows/subblock-values' import { getBlock } from '@/blocks' import { useWorkflowRegistry } from '@/stores/workflows/registry/store' import { useSubBlockStore } from '@/stores/workflows/subblock/store' @@ -118,15 +119,10 @@ const createWorkflowStoreState = ...(parentId && { parentId, extent: extent || 'parent' }), } - let subBlocks: Record = {} - blockConfig.subBlocks.forEach((subBlock) => { - const subBlockId = subBlock.id - subBlocks[subBlockId] = { - id: subBlockId, - type: subBlock.type, - value: null, - } - }) + let subBlocks = buildInitialSubBlockStates(blockConfig.subBlocks) as Record< + string, + SubBlockState + > const triggerMode = blockProperties?.triggerMode ?? false const runtimeState = resolveBlockRuntimeState({ diff --git a/apps/tradinggoose/stores/workflows/workflow/utils.test.ts b/apps/tradinggoose/stores/workflows/workflow/utils.test.ts index 01403f95d..96eb5dbe6 100644 --- a/apps/tradinggoose/stores/workflows/workflow/utils.test.ts +++ b/apps/tradinggoose/stores/workflows/workflow/utils.test.ts @@ -1,6 +1,43 @@ import { describe, expect, it } from 'vitest' import type { BlockState } from '@/stores/workflows/workflow/types' -import { convertLoopBlockToLoop } from '@/stores/workflows/workflow/utils' +import { + buildExecutableWorkflowData, + convertLoopBlockToLoop, +} from '@/stores/workflows/workflow/utils' + +const block = (id: string, type = 'agent', extra: Partial = {}): BlockState => ({ + id, + type, + name: id, + position: { x: 0, y: 0 }, + subBlocks: {}, + outputs: {}, + enabled: true, + ...extra, +}) + +describe('buildExecutableWorkflowData', () => { + it.concurrent('keeps blocks, edges, loops, and parallels consistent with enabled blocks', () => { + const blocks: Record = { + trigger: block('trigger', 'manual_trigger'), + loop: block('loop', 'loop'), + parallel: block('parallel', 'parallel'), + active: block('active', 'agent', { data: { parentId: 'loop' } }), + disabled: block('disabled', 'agent', { enabled: false, data: { parentId: 'parallel' } }), + } + + const result = buildExecutableWorkflowData(blocks, [ + { id: 'edge-1', source: 'trigger', target: 'active' }, + { id: 'edge-2', source: 'active', target: 'disabled' }, + { id: 'edge-3', source: 'disabled', target: 'parallel' }, + ]) + + expect(Object.keys(result.blocks).sort()).toEqual(['active', 'loop', 'parallel', 'trigger']) + expect(result.edges).toEqual([{ id: 'edge-1', source: 'trigger', target: 'active' }]) + expect(result.loops.loop.nodes).toEqual(['active']) + expect(result.parallels.parallel.nodes).toEqual([]) + }) +}) describe('convertLoopBlockToLoop', () => { it.concurrent('should parse JSON array string for forEach loops', () => { diff --git a/apps/tradinggoose/stores/workflows/workflow/utils.ts b/apps/tradinggoose/stores/workflows/workflow/utils.ts index 341d0d30a..f240527a1 100644 --- a/apps/tradinggoose/stores/workflows/workflow/utils.ts +++ b/apps/tradinggoose/stores/workflows/workflow/utils.ts @@ -1,3 +1,4 @@ +import type { Edge } from '@xyflow/react' import type { BlockState, Loop, Parallel } from '@/stores/workflows/workflow/types' const DEFAULT_LOOP_ITERATIONS = 5 @@ -200,3 +201,20 @@ export function generateParallelBlocks( return parallels } + +export function buildExecutableWorkflowData(blocks: Record, edges: Edge[]) { + const executableBlocks = Object.fromEntries( + Object.entries(blocks).filter(([, block]) => block?.type && block.enabled !== false) + ) + const executableBlockIds = new Set(Object.keys(executableBlocks)) + const executableEdges = edges.filter( + (edge) => executableBlockIds.has(edge.source) && executableBlockIds.has(edge.target) + ) + + return { + blocks: executableBlocks, + edges: executableEdges, + loops: generateLoopBlocks(executableBlocks), + parallels: generateParallelBlocks(executableBlocks), + } +} diff --git a/apps/tradinggoose/telemetry.config.ts b/apps/tradinggoose/telemetry.config.ts deleted file mode 100644 index 218a1080c..000000000 --- a/apps/tradinggoose/telemetry.config.ts +++ /dev/null @@ -1,118 +0,0 @@ -/** - * TradingGoose OpenTelemetry Configuration - * - * PRIVACY NOTICE: - * - Telemetry is enabled by default to help us improve the product - * - You can disable telemetry via: - * 1. Settings UI > Privacy tab > Toggle off "Allow anonymous telemetry" - * 2. Setting NEXT_TELEMETRY_DISABLED=1 environment variable - * - * This file allows you to configure OpenTelemetry collection for your - * TradingGoose instance. If you've forked the repository, you can modify - * this file to send telemetry to your own collector. - * - * We only collect anonymous usage data to improve the product: - * - Feature usage statistics - * - Error rates (always captured) - * - Performance metrics (sampled at 10%) - * - AI/LLM operation traces (always captured for workflows) - * - * We NEVER collect: - * - Personal information - * - Workflow content or outputs - * - API keys or tokens - * - IP addresses or geolocation data - */ -import { env } from './lib/env' - -const config = { - /** - * OTLP Endpoint URL where telemetry data is sent - * Change this if you want to send telemetry to your own collector - * Supports any OTLP-compatible backend (Jaeger, Grafana Tempo, etc.) - */ - endpoint: env.TELEMETRY_ENDPOINT || 'https://telemetry.tradinggoose.ai/v1/traces', - - /** - * Service name used to identify this instance - * You can change this for your fork - */ - serviceName: 'tradinggoose-studio', - - /** - * Version of the service, defaults to the app version - */ - serviceVersion: '0.1.0', - - /** - * Batch settings for OpenTelemetry BatchSpanProcessor - * Optimized for production use with minimal overhead - * - * - maxQueueSize: Max number of spans to buffer (increased from 100 to 2048) - * - maxExportBatchSize: Max number of spans per batch (increased from 10 to 512) - * - scheduledDelayMillis: Delay between batches (5 seconds) - * - exportTimeoutMillis: Timeout for exporting data (30 seconds) - */ - batchSettings: { - maxQueueSize: 2048, - maxExportBatchSize: 512, - scheduledDelayMillis: 5000, - exportTimeoutMillis: 30000, - }, - - /** - * Sampling configuration - * - Errors: Always sampled (100%) - * - AI/LLM operations: Always sampled (100%) - * - Other operations: Sampled at 10% - */ - sampling: { - defaultRate: 0.1, // 10% sampling for regular operations - alwaysSampleErrors: true, - alwaysSampleAI: true, - }, - - /** - * Categories of events that can be collected - * This is used for validation when events are sent - */ - allowedCategories: [ - 'page_view', - 'feature_usage', - 'performance', - 'error', - 'workflow', - 'consent', - 'batch', // Added for batched events - ], - - /** - * Client-side instrumentation settings - * Set enabled: false to disable client-side telemetry entirely - * - * Client-side telemetry now uses: - * - Event batching (send every 10s or 50 events) - * - Only critical Web Vitals (LCP, FID, CLS) - * - Unhandled errors only - */ - clientSide: { - enabled: true, - batchIntervalMs: 10000, // 10 seconds - maxBatchSize: 50, - }, - - /** - * Server-side instrumentation settings - * Set enabled: false to disable server-side telemetry entirely - * - * Server-side telemetry uses: - * - OpenTelemetry SDK with BatchSpanProcessor - * - Intelligent sampling (errors and AI ops always captured) - * - Semantic conventions for AI/LLM operations - */ - serverSide: { - enabled: true, - }, -} - -export default config diff --git a/apps/tradinggoose/triggers/resolution.ts b/apps/tradinggoose/triggers/resolution.ts index c73c725f5..6c4fe7aa6 100644 --- a/apps/tradinggoose/triggers/resolution.ts +++ b/apps/tradinggoose/triggers/resolution.ts @@ -1,10 +1,12 @@ import { getBlock } from '@/blocks' +import type { QueuedWorkflowTriggerType } from '@/services/queue' import { TRIGGER_REGISTRY } from '@/triggers/registry' type TriggerSubBlockValue = { value?: unknown } | unknown type TriggerResolvableBlock = { type: string + name?: string triggerMode?: boolean subBlocks?: Record } @@ -65,3 +67,24 @@ export function resolveTriggerIdForBlock(block: TriggerResolvableBlock): string return resolveTriggerIdFromSubBlocks(block.subBlocks, blockConfig.triggers?.available) } + +export function resolveTriggerExecutionIdentity(block: TriggerResolvableBlock): { + triggerSource: string + triggerType: QueuedWorkflowTriggerType +} { + const triggerSource = resolveTriggerIdForBlock(block) + if (!triggerSource) { + const blockConfig = getBlock(block.type) + throw new Error( + `${block.name || blockConfig?.name || block.type} requires a selected trigger type` + ) + } + + if (block.type === 'api_trigger') return { triggerSource, triggerType: 'api' } + if (block.type === 'chat_trigger') return { triggerSource, triggerType: 'chat' } + if (block.type === 'schedule') return { triggerSource, triggerType: 'schedule' } + if (block.type === 'input_trigger' || block.type === 'manual_trigger') { + return { triggerSource, triggerType: 'manual' } + } + return { triggerSource, triggerType: 'webhook' } +} diff --git a/apps/tradinggoose/tsconfig.json b/apps/tradinggoose/tsconfig.json index 105a2031c..ea2fd1e46 100644 --- a/apps/tradinggoose/tsconfig.json +++ b/apps/tradinggoose/tsconfig.json @@ -53,7 +53,6 @@ "**/*.tsx", ".next/types/**/*.ts", "../next-env.d.ts", - "telemetry.config.js", "trigger.config.ts", ".next/dev/types/**/*.ts" ], diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/copilot-app.tsx b/apps/tradinggoose/widgets/widgets/copilot/components/copilot-app.tsx index 10d1fdc63..b6bd45f6b 100644 --- a/apps/tradinggoose/widgets/widgets/copilot/components/copilot-app.tsx +++ b/apps/tradinggoose/widgets/widgets/copilot/components/copilot-app.tsx @@ -81,7 +81,7 @@ const CopilotApp = ({ : undefined return ( - + value.normalize('NFD').replace(/[\u0300-\u036f]/g, '') + +const createMentionSources = (): MentionSources => ({ + pastChats: [], + workspaceEntities: { + workflow: [], + skill: [], + indicator: [], + custom_tool: [], + mcp_server: [], + }, + knowledgeBases: [], + blocksList: [], + logsList: [], + workflowBlocks: [], +}) + +const loadingState: Record = { + chats: false, + workflow: false, + skill: false, + indicator: false, + custom_tool: false, + mcp_server: false, + workflow_blocks: false, + blocks: false, + knowledge: false, + logs: false, +} + +describe('MentionMenu i18n', () => { + let container: HTMLDivElement + let root: Root + + beforeEach(() => { + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true + container = document.createElement('div') + document.body.appendChild(container) + root = createRoot(container) + }) + + afterEach(() => { + act(() => { + root.unmount() + }) + container.remove() + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = false + }) + + const renderMenu = async ({ + locale, + messages, + openSubmenuFor = null, + sources = createMentionSources(), + mentionQuery = '', + submenuQuery = '', + }: { + locale: 'es' | 'zh' + messages: unknown + openSubmenuFor?: MentionSubmenu | null + sources?: MentionSources + mentionQuery?: string + submenuQuery?: string + }) => { + await act(async () => { + root.render( + + ()} + mentionPortalRef={createRef()} + mentionPortalStyle={{ + top: 48, + left: 24, + width: 320, + maxHeight: 240, + showBelow: true, + }} + mentionQuery={mentionQuery} + menuListRef={createRef()} + onAggregatedItemHover={() => {}} + onMainOptionHover={() => {}} + onSelectAggregatedItem={() => {}} + onSelectMainOption={() => {}} + onSelectSubmenuItem={() => {}} + onSubmenuItemHover={() => {}} + openSubmenuFor={openSubmenuFor} + showMentionMenu + sources={sources} + submenuActiveIndex={0} + submenuQuery={submenuQuery} + /> + + ) + }) + } + + it('renders localized main menu labels in spanish', async () => { + await renderMenu({ + locale: 'es', + messages: esMessages, + }) + + expect(document.body.textContent).toContain( + (esMessages as any).workspace.widgets.workflowLabels.workflows + ) + expect(document.body.textContent).not.toContain( + (esMessages as any).workspace.widgets.workflowLabels.allWorkflows + ) + expect(document.body.textContent).toContain('Bloques del flujo de trabajo') + expect(document.body.textContent).toContain('Documentación') + }) + + it('renders localized empty workflow state in spanish', async () => { + await renderMenu({ + locale: 'es', + messages: esMessages, + openSubmenuFor: 'workflow', + }) + + expect(document.body.textContent).toContain( + (esMessages as any).workspace.widgets.workflowLabels.allWorkflows + ) + expect(document.body.textContent).toContain('No se encontraron flujos de trabajo') + }) + + it('filters unnamed workflows using the localized spanish fallback label', async () => { + const untitledWorkflowLabel = (esMessages as any).workspace.widgets.workflowDropdown + .untitledWorkflow + const sources = createMentionSources() + sources.workspaceEntities.workflow = [ + { + entityKind: 'workflow', + id: 'workflow-1', + name: '', + color: '#3972F6', + }, + ] + + await renderMenu({ + locale: 'es', + messages: esMessages, + openSubmenuFor: 'workflow', + sources, + submenuQuery: stripAccents(untitledWorkflowLabel.toLowerCase()), + }) + + expect(document.body.textContent).toContain(untitledWorkflowLabel) + expect(document.body.textContent).not.toContain('No se encontraron flujos de trabajo') + }) + + it('renders localized block labels in the spanish blocks submenu', async () => { + const localizedBlockName = getLocalizedBlockNameWithCopy( + (esMessages as any).workspace.widgets, + 'condition' + ) + const sources = createMentionSources() + sources.blocksList = [ + { + id: 'condition', + name: localizedBlockName, + }, + ] + + await renderMenu({ + locale: 'es', + messages: esMessages, + openSubmenuFor: 'blocks', + sources, + }) + + expect(document.body.textContent).toContain(localizedBlockName) + }) + + it('filters and renders logs using localized chinese trigger labels', async () => { + const sources = createMentionSources() + sources.logsList = [ + { + id: 'log-1', + level: 'info', + trigger: 'schedule', + startedAt: '2026-04-17T00:00:00.000Z', + entityName: 'Alpha Workflow', + }, + ] + + await renderMenu({ + locale: 'zh', + messages: zhMessages, + openSubmenuFor: 'logs', + sources, + submenuQuery: '计划', + }) + + expect(document.body.textContent).toContain('Alpha Workflow') + expect(document.body.textContent).toContain('计划') + expect(document.body.textContent).not.toContain('未找到执行记录') + }) + + it('renders localized workflow block labels in the spanish workflow blocks submenu', async () => { + const localizedWorkflowBlockName = getLocalizedDefaultBlockNameWithCopy( + (esMessages as any).workspace.widgets, + 'condition', + 'Condition 2' + ) + const sources = createMentionSources() + sources.workflowBlocks = [ + { + id: 'workflow-block-1', + type: 'condition', + name: localizedWorkflowBlockName, + }, + ] + + await renderMenu({ + locale: 'es', + messages: esMessages, + openSubmenuFor: 'workflow_blocks', + sources, + submenuQuery: 'condicion 2', + }) + + expect(document.body.textContent).toContain(localizedWorkflowBlockName) + }) + + it('filters blocks using localized chinese block names', async () => { + const localizedBlockName = getLocalizedBlockNameWithCopy( + (zhMessages as any).workspace.widgets, + 'condition' + ) + const sources = createMentionSources() + sources.blocksList = [ + { + id: 'condition', + name: localizedBlockName, + }, + ] + + await renderMenu({ + locale: 'zh', + messages: zhMessages, + openSubmenuFor: 'blocks', + sources, + submenuQuery: localizedBlockName, + }) + + expect(document.body.textContent).toContain(localizedBlockName) + }) + + it('filters blocks using accentless spanish queries', async () => { + const localizedBlockName = getLocalizedBlockNameWithCopy( + (esMessages as any).workspace.widgets, + 'condition' + ) + const sources = createMentionSources() + sources.blocksList = [ + { + id: 'condition', + name: localizedBlockName, + }, + ] + + await renderMenu({ + locale: 'es', + messages: esMessages, + openSubmenuFor: 'blocks', + sources, + submenuQuery: 'condicion', + }) + + expect(document.body.textContent).toContain(localizedBlockName) + }) +}) diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/components/mention-menu.tsx b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/components/mention-menu.tsx index ebe3c0f19..5be76b5f9 100644 --- a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/components/mention-menu.tsx +++ b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/components/mention-menu.tsx @@ -20,11 +20,22 @@ import { import { createPortal } from 'react-dom' import { getIconTileStyle, sanitizeSolidIconColor } from '@/lib/ui/icon-colors' import { cn } from '@/lib/utils' +import { useMonitorCopy } from '@/app/workspace/[workspaceId]/monitor/copy' import { type CopilotWorkspaceEntityKind, getCopilotWorkspaceEntityKindFromMentionOption, isCopilotWorkspaceEntityMentionOption, } from '../../../workspace-entities' +import { + getKnowledgeBaseMentionLabel, + getLogMentionTriggerLabel, + getMentionOptionLabel, + getMentionSubmenuTitle, + getPastChatMentionLabel, + getWorkspaceEntityMentionEmptyState, + getWorkspaceEntityMentionLabel, + useCopilotMentionCopy, +} from '../mention-copy' import { buildAggregatedMentionItems, filterBlocks, @@ -34,7 +45,6 @@ import { filterPastChats, filterWorkflowBlocks, filterWorkspaceEntitiesForOption, - getMentionSubmenuTitle, } from '../mention-utils' import type { AggregatedMentionItem, @@ -50,7 +60,6 @@ import type { WorkflowBlockItem, WorkspaceEntityItem, } from '../types' -import { getWorkspaceEntityMentionEmptyState } from '../workspace-entity-mentions' interface MentionMenuProps { inAggregated: boolean @@ -198,30 +207,30 @@ const renderWorkspaceEntityMainOptionIcon = (entityKind: CopilotWorkspaceEntityK const WORKSPACE_ENTITY_ITEM_RENDERERS: Record< CopilotWorkspaceEntityKind, - (entity: WorkspaceEntityItem) => ReactNode + (entity: WorkspaceEntityItem, label: string) => ReactNode > = { - workflow: (entity) => ( + workflow: (entity, label) => ( <> {renderWorkflowBadge(entity.color)} - {entity.name} + {label} ), - skill: (entity) => ( + skill: (entity, label) => ( <> {renderSkillBadge()} - {entity.name} + {label} ), - indicator: (entity) => ( + indicator: (entity, label) => ( <> {renderIndicatorBadge(entity.color)} - {entity.name} + {label} ), - custom_tool: (entity) => ( + custom_tool: (entity, label) => ( <> {renderCustomToolBadge()} - {entity.name} + {label} {entity.functionName ? ( <> · @@ -230,10 +239,10 @@ const WORKSPACE_ENTITY_ITEM_RENDERERS: Record< ) : null} ), - mcp_server: (entity) => ( + mcp_server: (entity, label) => ( <> {renderMcpServerBadge(entity.connectionStatus)} - {entity.name} + {label} {entity.transport ? ( <> · @@ -245,7 +254,7 @@ const WORKSPACE_ENTITY_ITEM_RENDERERS: Record< } const renderMainOptionIcon = (option: MentionOption) => { - if (option === 'Chats') { + if (option === 'chats') { return } @@ -255,58 +264,66 @@ const renderMainOptionIcon = (option: MentionOption) => { ) } - if (option === 'Blocks') { + if (option === 'blocks') { return } - if (option === 'Workflow Blocks') { + if (option === 'workflow_blocks') { return } - if (option === 'Knowledge') { + if (option === 'knowledge') { return } - if (option === 'Docs') { + if (option === 'docs') { return } - if (option === 'Logs') { + if (option === 'logs') { return } return
} -const renderMentionItemContent = (type: MentionSubmenu, item: MentionItem) => { - if (type === 'Chats') { +const renderMentionItemContent = ( + type: MentionSubmenu, + item: MentionItem, + mentionCopy: ReturnType, + monitorCopy: ReturnType['copy'] +) => { + if (type === 'chats') { const chat = item as PastChatItem return ( <>
- {chat.title || 'Untitled Chat'} + {getPastChatMentionLabel(mentionCopy, chat)} ) } if (isCopilotWorkspaceEntityMentionOption(type)) { const entity = item as WorkspaceEntityItem - return WORKSPACE_ENTITY_ITEM_RENDERERS[entity.entityKind](entity) + return WORKSPACE_ENTITY_ITEM_RENDERERS[entity.entityKind]( + entity, + getWorkspaceEntityMentionLabel(mentionCopy, entity) + ) } - if (type === 'Knowledge') { + if (type === 'knowledge') { const knowledgeBase = item as KnowledgeBaseItem return ( <> - {knowledgeBase.name || 'Untitled'} + {getKnowledgeBaseMentionLabel(mentionCopy, knowledgeBase)} ) } - if (type === 'Blocks') { + if (type === 'blocks') { const block = item as BlockItem return ( <> @@ -316,7 +333,7 @@ const renderMentionItemContent = (type: MentionSubmenu, item: MentionItem) => { ) } - if (type === 'Workflow Blocks') { + if (type === 'workflow_blocks') { const block = item as WorkflowBlockItem return ( <> @@ -326,7 +343,7 @@ const renderMentionItemContent = (type: MentionSubmenu, item: MentionItem) => { ) } - if (type === 'Logs') { + if (type === 'logs') { const log = item as LogItem return ( <> @@ -339,7 +356,7 @@ const renderMentionItemContent = (type: MentionSubmenu, item: MentionItem) => { · {formatTimestamp(log.startedAt)} · - {(log.trigger || 'manual').toLowerCase()} + {getLogMentionTriggerLabel(monitorCopy, log)} ) } @@ -350,55 +367,61 @@ const renderMentionItemContent = (type: MentionSubmenu, item: MentionItem) => { const getSubmenuItems = ( submenu: MentionSubmenu, query: string, - sources: MentionSources + sources: MentionSources, + mentionCopy: ReturnType, + monitorCopy: ReturnType['copy'] ): MentionItem[] => { - if (submenu === 'Chats') { - return filterPastChats(sources.pastChats, query) + if (submenu === 'chats') { + return filterPastChats(sources.pastChats, query, mentionCopy) } if (isCopilotWorkspaceEntityMentionOption(submenu)) { - return filterWorkspaceEntitiesForOption(submenu, sources, query) + return filterWorkspaceEntitiesForOption(submenu, sources, query, mentionCopy) } - if (submenu === 'Knowledge') { - return filterKnowledgeBases(sources.knowledgeBases, query) + if (submenu === 'knowledge') { + return filterKnowledgeBases(sources.knowledgeBases, query, mentionCopy) } - if (submenu === 'Blocks') { + if (submenu === 'blocks') { return filterBlocks(sources.blocksList, query) } - if (submenu === 'Workflow Blocks') { + if (submenu === 'workflow_blocks') { return filterWorkflowBlocks(sources.workflowBlocks, query) } - return filterLogs(sources.logsList, query) + return filterLogs(sources.logsList, query, monitorCopy) } -const getSubmenuEmptyState = (submenu: MentionSubmenu) => { - if (submenu === 'Chats') { - return 'No past chats' +const getSubmenuEmptyState = ( + submenu: MentionSubmenu, + mentionCopy: ReturnType +) => { + if (submenu === 'chats') { + return mentionCopy.emptyStates.chats } if (isCopilotWorkspaceEntityMentionOption(submenu)) { return getWorkspaceEntityMentionEmptyState( + mentionCopy, getCopilotWorkspaceEntityKindFromMentionOption(submenu) ) } - if (submenu === 'Knowledge') { - return 'No knowledge bases' + if (submenu === 'knowledge') { + return mentionCopy.emptyStates.knowledge } - if (submenu === 'Blocks') { - return 'No blocks found' + if (submenu === 'blocks') { + return mentionCopy.emptyStates.blocks } - if (submenu === 'Workflow Blocks') { - return 'No blocks in this workflow' + if (submenu === 'workflow_blocks') { + return mentionCopy.emptyStates.workflow_blocks } - return 'No executions found' + return mentionCopy.emptyStates.logs } const isSubmenuLoading = (submenu: MentionSubmenu, loading: MentionMenuProps['loading']) => { @@ -430,14 +453,24 @@ export function MentionMenu({ submenuActiveIndex, submenuQuery, }: MentionMenuProps) { + const mentionCopy = useCopilotMentionCopy() + const { copy: monitorCopy } = useMonitorCopy() + if (!showMentionMenu || !mentionPortalStyle) { return null } - const filteredOptions = filterMentionOptions(mentionQuery) - const aggregatedItems = buildAggregatedMentionItems(mentionQuery, sources) + const filteredOptions = filterMentionOptions(mentionQuery, mentionCopy) + const aggregatedItems = buildAggregatedMentionItems( + mentionQuery, + sources, + mentionCopy, + monitorCopy + ) const showAggregatedSearch = mentionQuery.length > 0 && filteredOptions.length === 0 - const submenuItems = openSubmenuFor ? getSubmenuItems(openSubmenuFor, submenuQuery, sources) : [] + const submenuItems = openSubmenuFor + ? getSubmenuItems(openSubmenuFor, submenuQuery, sources, mentionCopy, monitorCopy) + : [] return createPortal(
- {getMentionSubmenuTitle(openSubmenuFor)} + {getMentionSubmenuTitle(mentionCopy, openSubmenuFor)}
{isSubmenuLoading(openSubmenuFor, loading) ? ( -
Loading...
+
{mentionCopy.loading}
) : submenuItems.length === 0 ? (
- {getSubmenuEmptyState(openSubmenuFor)} + {getSubmenuEmptyState(openSubmenuFor, mentionCopy)}
) : ( submenuItems.map((item, index) => ( @@ -493,7 +526,7 @@ export function MentionMenu({ onMouseEnter={() => onSubmenuItemHover(index)} onClick={() => onSelectSubmenuItem(openSubmenuFor, item)} > - {renderMentionItemContent(openSubmenuFor, item)} + {renderMentionItemContent(openSubmenuFor, item, mentionCopy, monitorCopy)}
)) )} @@ -502,7 +535,7 @@ export function MentionMenu({ ) : showAggregatedSearch ? (
{aggregatedItems.length === 0 ? ( -
No matches
+
{mentionCopy.noMatches}
) : ( aggregatedItems.map((item, index) => (
onAggregatedItemHover(index)} onClick={() => onSelectAggregatedItem(item)} > - {renderMentionItemContent(item.type, item.value)} + {renderMentionItemContent(item.type, item.value, mentionCopy, monitorCopy)}
)) )} @@ -541,13 +574,9 @@ export function MentionMenu({ >
{renderMainOptionIcon(option)} - - {isCopilotWorkspaceEntityMentionOption(option) - ? getMentionSubmenuTitle(option) - : option} - + {getMentionOptionLabel(mentionCopy, option)}
- {option !== 'Docs' && ( + {option !== 'docs' && ( )}
@@ -556,7 +585,9 @@ export function MentionMenu({ {mentionQuery.length > 0 && aggregatedItems.length > 0 && ( <>
-
Matches
+
+ {mentionCopy.matches} +
{aggregatedItems.map((item, index) => (
onAggregatedItemHover(index)} onClick={() => onSelectAggregatedItem(item)} > - {renderMentionItemContent(item.type, item.value)} + {renderMentionItemContent(item.type, item.value, mentionCopy, monitorCopy)}
))} diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/constants.ts b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/constants.ts index 01d6d3b51..c805464bf 100644 --- a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/constants.ts +++ b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/constants.ts @@ -14,17 +14,17 @@ export const ANTHROPIC_MODELS: readonly CopilotRuntimeModel[] = [ export const OPENAI_MODELS: readonly CopilotRuntimeModel[] = ['gpt-5.4', 'gpt-5.4-mini'] export const MENTION_OPTIONS: readonly MentionOption[] = [ - 'Chats', + 'chats', ...COPILOT_WORKSPACE_ENTITY_MENTION_OPTIONS, - 'Workflow Blocks', - 'Blocks', - 'Knowledge', - 'Docs', - 'Logs', + 'workflow_blocks', + 'blocks', + 'knowledge', + 'docs', + 'logs', ] export const MENTION_SUBMENUS: readonly MentionSubmenu[] = MENTION_OPTIONS.filter( - (option): option is MentionSubmenu => option !== 'Docs' + (option): option is MentionSubmenu => option !== 'docs' ) export const MAX_TEXTAREA_HEIGHT = 120 diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.test.tsx b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.test.tsx new file mode 100644 index 000000000..74287ae78 --- /dev/null +++ b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.test.tsx @@ -0,0 +1,237 @@ +/** + * @vitest-environment jsdom + */ + +import { act, useEffect } from 'react' +import { NextIntlClientProvider } from 'next-intl' +import { createRoot, type Root } from 'react-dom/client' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + getLocalizedBlockNameWithCopy, + getLocalizedDefaultBlockNameWithCopy, +} from '@/i18n/workflow-inspector-core' +import esMessages from '../../../../../../i18n/messages/es.json' +import zhMessages from '../../../../../../i18n/messages/zh.json' +import { useUserInputMentionSources } from './use-user-input-mention-sources' + +const reactActEnvironment = globalThis as typeof globalThis & { + IS_REACT_ACT_ENVIRONMENT?: boolean +} + +const mockBlocks = [ + { + type: 'condition', + name: 'Condition', + category: 'blocks', + hideFromToolbar: false, + bgColor: '#3972F6', + }, +] +const mockWorkflowBlocks: Record = {} +let mockWorkflowId: string | null = null + +const mockGetAllBlocks = vi.fn(() => mockBlocks) +const mockGetBlock = vi.fn((blockType: string) => + mockBlocks.find((block) => block.type === blockType) +) +const mockFetch = vi.fn(async (input: string | URL | Request) => { + const url = + typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url + + if (url.startsWith('/api/workflows?')) { + return { + ok: true, + json: async () => ({ + data: [ + { + id: 'workflow-1', + name: 'Workflow One', + color: '#3972F6', + }, + ], + }), + } as any + } + + throw new Error(`Unexpected fetch in mention sources test: ${url}`) +}) + +vi.mock('@/blocks', () => ({ + getAllBlocks: () => mockGetAllBlocks(), + getBlock: (blockType: string) => mockGetBlock(blockType), +})) + +vi.mock('@/blocks/registry', () => ({ + registry: { + condition: { + bgColor: '#3972F6', + icon: null, + name: 'Condition', + }, + }, +})) + +vi.mock('@/lib/yjs/use-workflow-doc', () => ({ + useWorkflowBlocks: () => mockWorkflowBlocks, +})) + +vi.mock('@/lib/yjs/workflow-session-host', () => ({ + useOptionalWorkflowSession: () => + mockWorkflowId + ? { + workflowId: mockWorkflowId, + } + : null, +})) + +type MentionSourcesHookResult = ReturnType + +function MentionSourcesHarness({ + onRender, + workspaceId, +}: { + onRender: (value: MentionSourcesHookResult) => void + workspaceId: string +}) { + const result = useUserInputMentionSources({ workspaceId }) + + useEffect(() => { + onRender(result) + }, [onRender, result]) + + return null +} + +describe('useUserInputMentionSources', () => { + let container: HTMLDivElement + let root: Root + let latestResult: MentionSourcesHookResult | null + + const renderHarness = async ({ + locale, + messages, + }: { + locale: 'es' | 'zh' + messages: unknown + }) => { + await act(async () => { + root.render( + + { + latestResult = value + }} + /> + + ) + }) + } + + beforeEach(() => { + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = true + vi.stubGlobal('fetch', mockFetch) + container = document.createElement('div') + document.body.appendChild(container) + root = createRoot(container) + latestResult = null + mockWorkflowId = null + for (const key of Object.keys(mockWorkflowBlocks)) { + delete mockWorkflowBlocks[key] + } + mockGetAllBlocks.mockClear() + mockGetBlock.mockClear() + mockFetch.mockClear() + }) + + afterEach(() => { + act(() => { + root.unmount() + }) + container.remove() + reactActEnvironment.IS_REACT_ACT_ENVIRONMENT = false + vi.unstubAllGlobals() + }) + + it('reloads localized block mention labels after a locale change', async () => { + const spanishBlockName = getLocalizedBlockNameWithCopy( + (esMessages as any).workspace.widgets, + mockBlocks[0] + ) + const chineseBlockName = getLocalizedBlockNameWithCopy( + (zhMessages as any).workspace.widgets, + mockBlocks[0] + ) + + await renderHarness({ + locale: 'es', + messages: esMessages, + }) + + await act(async () => { + await latestResult?.ensureBlocksLoaded() + }) + + expect(latestResult?.blocksList.map((item) => item.name)).toEqual([spanishBlockName]) + + await renderHarness({ + locale: 'zh', + messages: zhMessages, + }) + + expect(latestResult?.blocksList).toEqual([]) + + await act(async () => { + await latestResult?.ensureBlocksLoaded() + }) + + expect(latestResult?.blocksList.map((item) => item.name)).toEqual([chineseBlockName]) + expect(mockGetAllBlocks).toHaveBeenCalledTimes(2) + }) + + it('reloads localized workflow block mention labels after a locale change', async () => { + mockWorkflowId = 'workflow-1' + mockWorkflowBlocks['workflow-block-1'] = { + id: 'workflow-block-1', + type: 'condition', + name: 'Condition 2', + } + + const spanishWorkflowBlockName = getLocalizedDefaultBlockNameWithCopy( + (esMessages as any).workspace.widgets, + 'condition', + 'Condition 2' + ) + const chineseWorkflowBlockName = getLocalizedDefaultBlockNameWithCopy( + (zhMessages as any).workspace.widgets, + 'condition', + 'Condition 2' + ) + + await renderHarness({ + locale: 'es', + messages: esMessages, + }) + + await act(async () => { + await latestResult?.ensureWorkflowBlocksLoaded() + }) + + expect(latestResult?.workflowBlocks.map((item) => item.name)).toEqual([ + spanishWorkflowBlockName, + ]) + + await renderHarness({ + locale: 'zh', + messages: zhMessages, + }) + + await act(async () => { + await latestResult?.ensureWorkflowBlocksLoaded() + }) + + expect(latestResult?.workflowBlocks.map((item) => item.name)).toEqual([ + chineseWorkflowBlockName, + ]) + }) +}) diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.ts b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.ts index 207a3c3c1..3168a2857 100644 --- a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.ts +++ b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mention-sources.ts @@ -1,11 +1,17 @@ 'use client' -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import { useLocale } from 'next-intl' import { createLogger } from '@/lib/logs/console/logger' import { sanitizeSolidIconColor } from '@/lib/ui/icon-colors' import { useWorkflowBlocks } from '@/lib/yjs/use-workflow-doc' import { useOptionalWorkflowSession } from '@/lib/yjs/workflow-session-host' import { fetchKnowledgeBases as fetchWorkspaceKnowledgeBases } from '@/hooks/queries/knowledge' +import { + getLocalizedBlockNameWithCopy, + getLocalizedDefaultBlockNameWithCopy, +} from '@/i18n/workflow-inspector-core' +import { useWorkflowInspectorMessages } from '@/i18n/workspace-widget-hooks' import { getSubflowBlockConfig } from '@/widgets/widgets/editor_workflow/components/subflows/config' import { type CopilotWorkspaceEntityKind, @@ -50,6 +56,7 @@ const createEmptyWorkspaceEntityLoading = (): Record([]) const [isLoadingPastChats, setIsLoadingPastChats] = useState(false) const [workspaceEntities, setWorkspaceEntities] = useState(createEmptyWorkspaceEntities) @@ -67,6 +74,18 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS const workflowSession = useOptionalWorkflowSession() const workflowId = workflowSession?.workflowId ?? null const workflowStoreBlocks = useWorkflowBlocks() + const workflowInspectorCopy = useWorkflowInspectorMessages() + const latestBlocksLocaleRef = useRef(locale) + const latestWorkflowBlocksKeyRef = useRef(`${locale}:${workflowId ?? ''}`) + const workflowBlocksLoadingRef = useRef(false) + + useEffect(() => { + latestBlocksLocaleRef.current = locale + }, [locale]) + + useEffect(() => { + latestWorkflowBlocksKeyRef.current = `${locale}:${workflowId ?? ''}` + }, [locale, workflowId]) const ensurePastChatsLoaded = useCallback(async () => { if (isLoadingPastChats || pastChats.length > 0) { @@ -141,7 +160,7 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS setKnowledgeBases( sorted.map((item: any) => ({ id: item.id, - name: item.name || 'Untitled', + name: item.name || '', })) ) } catch { @@ -155,6 +174,8 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS return } + const loadLocale = locale + try { setIsLoadingBlocks(true) const { getAllBlocks } = await import('@/blocks') @@ -163,7 +184,7 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS .filter((block: any) => !block.hideFromToolbar && block.category === 'blocks') .map((block: any) => ({ id: block.type, - name: block.name || block.type, + name: getLocalizedBlockNameWithCopy(workflowInspectorCopy, block), iconComponent: block.icon, bgColor: sanitizeSolidIconColor(block.bgColor), })) @@ -173,18 +194,24 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS .filter((block: any) => !block.hideFromToolbar && block.category === 'tools') .map((block: any) => ({ id: block.type, - name: block.name || block.type, + name: getLocalizedBlockNameWithCopy(workflowInspectorCopy, block), iconComponent: block.icon, bgColor: sanitizeSolidIconColor(block.bgColor), })) .sort((a: any, b: any) => a.name.localeCompare(b.name)) + if (latestBlocksLocaleRef.current !== loadLocale) { + return + } + setBlocksList([...regularBlocks, ...toolBlocks]) } catch { } finally { - setIsLoadingBlocks(false) + if (latestBlocksLocaleRef.current === loadLocale) { + setIsLoadingBlocks(false) + } } - }, [blocksList.length, isLoadingBlocks]) + }, [blocksList.length, isLoadingBlocks, locale, workflowInspectorCopy]) const ensureLogsLoaded = useCallback(async () => { if (isLoadingLogs || logsList.length > 0) { @@ -226,7 +253,7 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS }, [isLoadingLogs, logsList.length, workspaceId]) const ensureWorkflowBlocksLoaded = useCallback(async () => { - if (isLoadingWorkflowBlocks) { + if (workflowBlocksLoadingRef.current) { return } @@ -235,7 +262,10 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS return } + const loadKey = `${locale}:${workflowId ?? ''}` + try { + workflowBlocksLoadingRef.current = true setIsLoadingWorkflowBlocks(true) const { registry: blockRegistry } = await import('@/blocks/registry') const mapped = Object.values(workflowStoreBlocks).map((block: any) => { @@ -245,24 +275,35 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS return { id: block.id, - name: block.name || presentation?.name || block.id, + name: getLocalizedDefaultBlockNameWithCopy( + workflowInspectorCopy, + block.type, + block.name || presentation?.name + ), type: block.type, iconComponent: presentation?.icon, bgColor: sanitizeSolidIconColor(presentation?.bgColor) || '#6B7280', } }) + if (latestWorkflowBlocksKeyRef.current !== loadKey) { + return + } + setWorkflowBlocks(mapped) } catch (error) { logger.error('Failed to sync workflow blocks:', error) } finally { - setIsLoadingWorkflowBlocks(false) + workflowBlocksLoadingRef.current = false + if (latestWorkflowBlocksKeyRef.current === loadKey) { + setIsLoadingWorkflowBlocks(false) + } } - }, [isLoadingWorkflowBlocks, workflowId, workflowStoreBlocks]) + }, [locale, workflowId, workflowInspectorCopy, workflowStoreBlocks]) const ensureSubmenuLoaded = useCallback( async (submenu: MentionSubmenu) => { - if (submenu === 'Chats') { + if (submenu === 'chats') { await ensurePastChatsLoaded() return } @@ -272,17 +313,17 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS return } - if (submenu === 'Knowledge') { + if (submenu === 'knowledge') { await ensureKnowledgeLoaded() return } - if (submenu === 'Blocks') { + if (submenu === 'blocks') { await ensureBlocksLoaded() return } - if (submenu === 'Workflow Blocks') { + if (submenu === 'workflow_blocks') { await ensureWorkflowBlocksLoaded() return } @@ -301,12 +342,18 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS useEffect(() => { setWorkflowBlocks([]) + workflowBlocksLoadingRef.current = false setIsLoadingWorkflowBlocks(false) }, [workflowId]) + useEffect(() => { + setBlocksList([]) + setIsLoadingBlocks(false) + }, [locale]) + useEffect(() => { void ensureWorkflowBlocksLoaded() - }, [ensureWorkflowBlocksLoaded]) + }, [locale, workflowId, workflowStoreBlocks]) useEffect(() => { if (workflowId && workspaceEntities.workflow.length === 0) { @@ -335,16 +382,16 @@ export function useUserInputMentionSources({ workspaceId }: UseUserInputMentionS } const mentionLoading: Record = { - Chats: isLoadingPastChats, - Workflows: workspaceEntityLoading.workflow, - Skills: workspaceEntityLoading.skill, - Indicators: workspaceEntityLoading.indicator, - 'Custom Tools': workspaceEntityLoading.custom_tool, - 'MCP Servers': workspaceEntityLoading.mcp_server, - 'Workflow Blocks': isLoadingWorkflowBlocks, - Blocks: isLoadingBlocks, - Knowledge: isLoadingKnowledge, - Logs: isLoadingLogs, + chats: isLoadingPastChats, + workflow: workspaceEntityLoading.workflow, + skill: workspaceEntityLoading.skill, + indicator: workspaceEntityLoading.indicator, + custom_tool: workspaceEntityLoading.custom_tool, + mcp_server: workspaceEntityLoading.mcp_server, + workflow_blocks: isLoadingWorkflowBlocks, + blocks: isLoadingBlocks, + knowledge: isLoadingKnowledge, + logs: isLoadingLogs, } return { diff --git a/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mentions.test.tsx b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mentions.test.tsx new file mode 100644 index 000000000..b22a58e6d --- /dev/null +++ b/apps/tradinggoose/widgets/widgets/copilot/components/user-input/hooks/use-user-input-mentions.test.tsx @@ -0,0 +1,287 @@ +/** + * @vitest-environment jsdom + */ + +import { act, useEffect, useRef, useState } from 'react' +import { NextIntlClientProvider } from 'next-intl' +import { createRoot, type Root } from 'react-dom/client' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import enMessages from '../../../../../../i18n/messages/en.json' +import esMessages from '../../../../../../i18n/messages/es.json' +import { MENTION_SUBMENUS } from '../constants' +import type { MentionSources } from '../types' +import { useUserInputMentions } from './use-user-input-mentions' + +const reactActEnvironment = globalThis as typeof globalThis & { + IS_REACT_ACT_ENVIRONMENT?: boolean +} + +type MentionHookSnapshot = ReturnType & { + message: string + textarea: HTMLTextAreaElement | null +} + +const createMentionSources = (): MentionSources => ({ + pastChats: [], + workspaceEntities: { + workflow: [], + skill: [], + indicator: [], + custom_tool: [], + mcp_server: [], + }, + knowledgeBases: [], + blocksList: [], + logsList: [], + workflowBlocks: [], +}) + +function MentionsHarness({ + mentionSources, + onRender, + workspaceId, + ensureSubmenuLoaded, +}: { + mentionSources: MentionSources + onRender: (value: MentionHookSnapshot) => void + workspaceId: string + ensureSubmenuLoaded: (submenu: any) => Promise +}) { + const [message, setMessage] = useState('') + const menuListRef = useRef(null) + const textareaRef = useRef(null) + const result = useUserInputMentions({ + disabled: false, + isLoading: false, + menuListRef, + message, + mentionSources, + setMessage, + textareaRef, + workspaceId, + loaders: { + ensureSubmenuLoaded, + }, + }) + + useEffect(() => { + onRender({ + ...result, + message, + textarea: textareaRef.current, + }) + }, [message, onRender, result]) + + return ( + <> +