From ecbd6d713edbe88620e41106565feec09ca33535 Mon Sep 17 00:00:00 2001 From: jdalton Date: Tue, 8 Jul 2025 16:28:33 -0400 Subject: [PATCH] Don't gate PRs on open or closed state --- src/commands/fix/agent-fix.mts | 176 ++++++++---------- src/commands/fix/fix-branch-helpers.mts | 42 ++--- src/commands/fix/fix-env-helpers.mts | 61 +++--- src/commands/fix/git.mts | 20 +- src/commands/fix/npm-fix.mts | 48 +++-- src/commands/fix/pnpm-fix.mts | 24 +-- .../fix/{open-pr.mts => pull-request.mts} | 112 ++++++----- src/commands/fix/shared.mts | 2 +- src/shadow/npm/arborist-helpers.mts | 4 +- src/utils/alerts-map.mts | 2 - src/utils/socket-package-alert.mts | 21 +-- 11 files changed, 244 insertions(+), 268 deletions(-) rename src/commands/fix/{open-pr.mts => pull-request.mts} (82%) diff --git a/src/commands/fix/agent-fix.mts b/src/commands/fix/agent-fix.mts index 1c6ff0ad8..9ceb531ec 100644 --- a/src/commands/fix/agent-fix.mts +++ b/src/commands/fix/agent-fix.mts @@ -15,8 +15,10 @@ import { } from '@socketsecurity/registry/lib/packages' import { naturalCompare } from '@socketsecurity/registry/lib/sorts' import { isNonEmptyString } from '@socketsecurity/registry/lib/strings' +import { pluralize } from '@socketsecurity/registry/lib/words' -import { getActiveBranchesForPackage } from './fix-branch-helpers.mts' +import { getPrsForPurl } from './fix-branch-helpers.mts' +import { getFixEnv } from './fix-env-helpers.mts' import { getActualTree } from './get-actual-tree.mts' import { getSocketBranchName, @@ -28,12 +30,11 @@ import { gitUnstagedModifiedFiles, } from './git.mts' import { - cleanupOpenPrs, + cleanupPrs, enablePrAutoMerge, openPr, - prExistForBranch, setGitRemoteGithubRepoUrl, -} from './open-pr.mts' +} from './pull-request.mts' import constants from '../../constants.mts' import { findBestPatchVersion, @@ -50,14 +51,15 @@ import { getCveInfoFromAlertsMap } from '../../utils/socket-package-alert.mts' import { idToPurl } from '../../utils/spec.mts' import { getOverridesData } from '../optimize/get-overrides-by-agent.mts' -import type { CiEnv } from './fix-env-helpers.mts' -import type { PrMatch } from './open-pr.mts' import type { NodeClass } from '../../shadow/npm/arborist/types.mts' import type { CResult } from '../../types.mts' import type { EnvDetails } from '../../utils/package-environment.mts' import type { RangeStyle } from '../../utils/semver.mts' import type { AlertsByPurl } from '../../utils/socket-package-alert.mts' -import type { EditablePackageJson } from '@socketsecurity/registry/lib/packages' +import type { + EditablePackageJson, + Packument, +} from '@socketsecurity/registry/lib/packages' import type { Spinner } from '@socketsecurity/registry/lib/spinner' export type FixConfig = { @@ -80,7 +82,7 @@ export type InstallOptions = { export type InstallPhaseHandler = ( editablePkgJson: EditablePackageJson, - name: string, + packument: Packument, oldVersion: string, newVersion: string, vulnerableVersionRange: string, @@ -109,11 +111,12 @@ export async function agentFix( afterInstall?: InstallPhaseHandler | undefined revertInstall?: InstallPhaseHandler | undefined }, - ciEnv: CiEnv | null, - openPrs: PrMatch[], fixConfig: FixConfig, ): Promise> { const { pkgPath: rootPath } = pkgEnvDetails + + const fixEnv = await getFixEnv() + const { autoMerge, cwd, @@ -128,7 +131,7 @@ export async function agentFix( let count = 0 const infoByPartialPurl = getCveInfoFromAlertsMap(alertsMap, { - limit: Math.max(limit, openPrs.length), + exclude: { upgradable: true }, }) if (!infoByPartialPurl) { spinner?.stop() @@ -141,8 +144,14 @@ export async function agentFix( return { ok: true, data: { fixed: false } } } - if (isDebug('notice')) { - debugFn('notice', 'found: cves for', Array.from(infoByPartialPurl.keys())) + if (isDebug('notice,inspect')) { + const partialPurls = Array.from(infoByPartialPurl.keys()) + const { length: purlsCount } = partialPurls + debugFn( + 'notice', + `found: ${purlsCount} ${pluralize('PURL', purlsCount)} with CVEs`, + ) + debugDir('inspect', { partialPurls }) } // Lazily access constants.packumentCache. @@ -190,10 +199,11 @@ export async function agentFix( const infos = Array.from(infoEntry[1].values()) if (!infos.length) { + debugFn('notice', `miss: CVEs expected, but not found, for ${name}`) continue infoEntriesLoop } - logger.log(`Processing vulns for ${name}:`) + logger.log(`Processing vulns for ${name}`) logger.indent() spinner?.indent() @@ -208,12 +218,8 @@ export async function agentFix( continue infoEntriesLoop } - const activeBranches = getActiveBranchesForPackage( - ciEnv, - infoEntry[0], - openPrs, - ) const availableVersions = Object.keys(packument.versions) + const prs = getPrsForPurl(fixEnv, infoEntry[0]) const warningsForAfter = new Set() // eslint-disable-next-line no-unused-labels @@ -230,18 +236,17 @@ export async function agentFix( const workspace = isWorkspaceRoot ? 'root' : path.relative(rootPath, pkgPath) - const branchWorkspace = ciEnv + const branchWorkspace = fixEnv.isCi ? getSocketBranchWorkspaceComponent(workspace) : '' - // actualTree may not be defined on the first iteration of pkgJsonPathsLoop. if (!actualTree) { - if (!ciEnv) { + if (!fixEnv.isCi) { // eslint-disable-next-line no-await-in-loop await removeNodeModules(cwd) } const maybeActualTree = - ciEnv && existsSync(path.join(rootPath, 'node_modules')) + fixEnv.isCi && existsSync(path.join(rootPath, 'node_modules')) ? // eslint-disable-next-line no-await-in-loop await getActualTree(cwd) : // eslint-disable-next-line no-await-in-loop @@ -282,7 +287,7 @@ export async function agentFix( let hasAnnouncedWorkspace = false let workspaceLogCallCount = logger.logCallCount - if (isDebug()) { + if (isDebug('notice')) { debugFn('notice', `check: workspace ${workspace}`) hasAnnouncedWorkspace = true workspaceLogCallCount = logger.logCallCount @@ -319,23 +324,34 @@ export async function agentFix( continue infosLoop } if (semver.gte(oldVersion, newVersion)) { - debugFn('notice', `skip: ${oldId} is >= ${newVersion}`) + debugFn('silly', `skip: ${oldId} is >= ${newVersion}`) + continue infosLoop + } + const branch = getSocketBranchName(oldPurl, newVersion, workspace) + const pr = prs.find( + ({ parsedBranch: b }) => + b.workspace === branchWorkspace && b.newVersion === newVersion, + ) + if (pr) { + debugFn('notice', `skip: PR #${pr.number} for ${name} exists`) + if (++count >= limit) { + cleanupInfoEntriesLoop() + break infoEntriesLoop + } continue infosLoop } if ( - activeBranches.find( - b => - b.workspace === branchWorkspace && b.newVersion === newVersion, - ) + fixEnv.isCi && + // eslint-disable-next-line no-await-in-loop + (await gitRemoteBranchExists(branch, cwd)) ) { - debugFn('notice', `skip: open PR found for ${name}@${newVersion}`) + debugFn('notice', `skip: remote branch "${branch}" exists`) if (++count >= limit) { cleanupInfoEntriesLoop() break infoEntriesLoop } continue infosLoop } - const { overrides: oldOverrides } = getOverridesData( pkgEnvDetails, editablePkgJson.content, @@ -351,12 +367,13 @@ export async function agentFix( // eslint-disable-next-line no-await-in-loop await beforeInstall( editablePkgJson, - name, + packument, oldVersion, newVersion, vulnerableVersionRange, fixConfig, ) + updatePackageJsonFromNode( editablePkgJson, actualTree, @@ -364,13 +381,26 @@ export async function agentFix( newVersion, rangeStyle, ) + // eslint-disable-next-line no-await-in-loop - if (!(await editablePkgJson.save({ ignoreWhitespace: true }))) { - debugFn('notice', `skip: ${workspace}/package.json unchanged`) + const unstagedCResult = await gitUnstagedModifiedFiles(cwd) + const moddedFilepaths = unstagedCResult.ok + ? unstagedCResult.data.filter(filepath => { + const basename = path.basename(filepath) + return ( + basename === 'package.json' || + basename === pkgEnvDetails.lockName + ) + }) + : [] + if (!moddedFilepaths.length) { + logger.warn( + 'Unexpected condition: Nothing to commit, skipping PR creation.', + ) // Reset things just in case. - if (ciEnv) { + if (fixEnv.isCi) { // eslint-disable-next-line no-await-in-loop - await gitResetAndClean(ciEnv.baseBranch, cwd) + await gitResetAndClean(fixEnv.baseBranch, cwd) } continue infosLoop } @@ -402,7 +432,7 @@ export async function agentFix( // eslint-disable-next-line no-await-in-loop await afterInstall( editablePkgJson, - name, + packument, oldVersion, newVersion, vulnerableVersionRange, @@ -426,48 +456,9 @@ export async function agentFix( spinner?.stop() // Check repoInfo to make TypeScript happy. - if (!errored && ciEnv?.repoInfo) { + if (!errored && fixEnv.isCi && fixEnv.repoInfo) { try { - // eslint-disable-next-line no-await-in-loop - const unstagedCResult = await gitUnstagedModifiedFiles(cwd) - if (!unstagedCResult.ok) { - logger.warn( - 'Unexpected condition: Nothing to commit, skipping PR creation.', - ) - continue - } - const moddedFilepaths = unstagedCResult.data.filter(filepath => { - const basename = path.basename(filepath) - return ( - basename === 'package.json' || - basename === pkgEnvDetails.lockName - ) - }) - if (!moddedFilepaths.length) { - logger.warn( - 'Unexpected condition: Nothing to commit, skipping PR creation.', - ) - continue infosLoop - } - - const branch = getSocketBranchName(oldPurl, newVersion, workspace) - let skipPr = false if ( - // eslint-disable-next-line no-await-in-loop - await prExistForBranch( - ciEnv.repoInfo.owner, - ciEnv.repoInfo.repo, - branch, - ) - ) { - skipPr = true - debugFn('notice', `skip: branch "${branch}" exists`) - } - // eslint-disable-next-line no-await-in-loop - else if (await gitRemoteBranchExists(branch, cwd)) { - skipPr = true - debugFn('notice', `skip: remote branch "${branch}" exists`) - } else if ( // eslint-disable-next-line no-await-in-loop !(await gitCreateAndPushBranch( branch, @@ -475,19 +466,16 @@ export async function agentFix( moddedFilepaths, { cwd, - email: ciEnv.gitEmail, - user: ciEnv.gitUser, + email: fixEnv.gitEmail, + user: fixEnv.gitUser, }, )) ) { - skipPr = true logger.warn( 'Unexpected condition: Push failed, skipping PR creation.', ) - } - if (skipPr) { // eslint-disable-next-line no-await-in-loop - await gitResetAndClean(ciEnv.baseBranch, cwd) + await gitResetAndClean(fixEnv.baseBranch, cwd) // eslint-disable-next-line no-await-in-loop const maybeActualTree = await installer(pkgEnvDetails, { cwd, @@ -508,12 +496,12 @@ export async function agentFix( // eslint-disable-next-line no-await-in-loop await Promise.allSettled([ setGitRemoteGithubRepoUrl( - ciEnv.repoInfo.owner, - ciEnv.repoInfo.repo, - ciEnv.githubToken!, + fixEnv.repoInfo.owner, + fixEnv.repoInfo.repo, + fixEnv.githubToken!, cwd, ), - cleanupOpenPrs(ciEnv.repoInfo.owner, ciEnv.repoInfo.repo, { + cleanupPrs(fixEnv.repoInfo.owner, fixEnv.repoInfo.repo, { newVersion, purl: oldPurl, workspace, @@ -521,13 +509,13 @@ export async function agentFix( ]) // eslint-disable-next-line no-await-in-loop const prResponse = await openPr( - ciEnv.repoInfo.owner, - ciEnv.repoInfo.repo, + fixEnv.repoInfo.owner, + fixEnv.repoInfo.repo, branch, oldPurl, newVersion, { - baseBranch: ciEnv.baseBranch, + baseBranch: fixEnv.baseBranch, cwd, workspace, }, @@ -561,10 +549,10 @@ export async function agentFix( } } - if (ciEnv) { + if (fixEnv.isCi) { spinner?.start() // eslint-disable-next-line no-await-in-loop - await gitResetAndClean(ciEnv.baseBranch, cwd) + await gitResetAndClean(fixEnv.baseBranch, cwd) // eslint-disable-next-line no-await-in-loop const maybeActualTree = await installer(pkgEnvDetails, { cwd, @@ -578,12 +566,12 @@ export async function agentFix( } } if (errored) { - if (!ciEnv) { + if (!fixEnv.isCi) { spinner?.start() // eslint-disable-next-line no-await-in-loop await revertInstall( editablePkgJson, - name, + packument, oldVersion, newVersion, vulnerableVersionRange, diff --git a/src/commands/fix/fix-branch-helpers.mts b/src/commands/fix/fix-branch-helpers.mts index bd5954b32..6387a2ea5 100644 --- a/src/commands/fix/fix-branch-helpers.mts +++ b/src/commands/fix/fix-branch-helpers.mts @@ -1,52 +1,48 @@ -import { debugFn, isDebug } from '@socketsecurity/registry/lib/debug' +import { debugDir, debugFn, isDebug } from '@socketsecurity/registry/lib/debug' import { resolvePackageName } from '@socketsecurity/registry/lib/packages' import { + genericSocketBranchParser, getSocketBranchFullNameComponent, getSocketBranchPurlTypeComponent, } from './git.mts' import { getPurlObject } from '../../utils/purl.mts' -import type { CiEnv } from './fix-env-helpers.mts' -import type { SocketBranchParseResult } from './git.mts' -import type { PrMatch } from './open-pr.mts' +import type { FixEnv } from './fix-env-helpers.mts' +import type { PrMatch } from './pull-request.mts' -export function getActiveBranchesForPackage( - ciEnv: CiEnv | null | undefined, +export function getPrsForPurl( + fixEnv: FixEnv | null | undefined, partialPurl: string, - openPrs: PrMatch[], -): SocketBranchParseResult[] { - if (!ciEnv) { +): PrMatch[] { + if (!fixEnv) { return [] } - const activeBranches: SocketBranchParseResult[] = [] + const prs: PrMatch[] = [] const partialPurlObj = getPurlObject(partialPurl) const branchFullName = getSocketBranchFullNameComponent(partialPurlObj) const branchPurlType = getSocketBranchPurlTypeComponent(partialPurlObj) - for (const pr of openPrs) { - const parsedBranch = ciEnv.branchParser(pr.headRefName) + for (const pr of fixEnv.prs) { + const parsedBranch = genericSocketBranchParser(pr.headRefName) if ( branchPurlType === parsedBranch?.type && branchFullName === parsedBranch?.fullName ) { - activeBranches.push(parsedBranch) + prs.push(pr) } } - if (isDebug('notice')) { + if (isDebug('notice,inspect')) { const fullName = resolvePackageName(partialPurlObj) - if (activeBranches.length) { - debugFn( - 'notice', - `found: ${activeBranches.length} active branches for ${fullName}\n`, - activeBranches, - ) - } else if (openPrs.length) { - debugFn('notice', `miss: 0 active branches found for ${fullName}`) + if (prs.length) { + debugFn('notice', `found: ${prs.length} PRs for ${fullName}`) + debugDir('inspect', { prs }) + } else if (fixEnv.prs.length) { + debugFn('notice', `miss: 0 PRs found for ${fullName}`) } } - return activeBranches + return prs } diff --git a/src/commands/fix/fix-env-helpers.mts b/src/commands/fix/fix-env-helpers.mts index 9b16254e3..ffc0768b2 100644 --- a/src/commands/fix/fix-env-helpers.mts +++ b/src/commands/fix/fix-env-helpers.mts @@ -1,15 +1,11 @@ import { debugFn } from '@socketsecurity/registry/lib/debug' -import { - createSocketBranchParser, - getBaseGitBranch, - gitRepoInfo, -} from './git.mts' -import { getOpenSocketPrs } from './open-pr.mts' +import { getBaseGitBranch, gitRepoInfo } from './git.mts' +import { getSocketPrs } from './pull-request.mts' import constants from '../../constants.mts' -import type { RepoInfo, SocketBranchParser } from './git.mts' -import type { PrMatch } from './open-pr.mts' +import type { RepoInfo } from './git.mts' +import type { PrMatch } from './pull-request.mts' async function getEnvRepoInfo( cwd?: string | undefined, @@ -27,50 +23,41 @@ async function getEnvRepoInfo( repo: ownerSlashRepo.slice(slashIndex + 1), } } + debugFn('notice', 'falling back to `git remote get-url origin`') return await gitRepoInfo(cwd) } -export interface CiEnv { +export interface FixEnv { + baseBranch: string gitEmail: string - gitUser: string githubToken: string - repoInfo: RepoInfo - baseBranch: string - branchParser: SocketBranchParser + gitUser: string + isCi: boolean + prs: PrMatch[] + repoInfo: RepoInfo | null } -export async function getCiEnv(): Promise { +export async function getFixEnv(): Promise { + const baseBranch = await getBaseGitBranch() const gitEmail = constants.ENV.SOCKET_CLI_GIT_USER_EMAIL const gitUser = constants.ENV.SOCKET_CLI_GIT_USER_NAME const githubToken = constants.ENV.SOCKET_CLI_GITHUB_TOKEN const isCi = !!(constants.ENV.CI && gitEmail && gitUser && githubToken) - if (!isCi) { - return null - } - const baseBranch = await getBaseGitBranch() - if (!baseBranch) { - return null - } const repoInfo = await getEnvRepoInfo() - if (!repoInfo) { - return null - } + const prs = + isCi && repoInfo + ? await getSocketPrs(repoInfo.owner, repoInfo.repo, { + author: gitUser, + states: 'all', + }) + : [] return { + baseBranch, gitEmail, - gitUser, githubToken, + gitUser, + isCi, + prs, repoInfo, - baseBranch, - branchParser: createSocketBranchParser(), } } - -export async function getOpenPrsForEnvironment( - env: CiEnv | null | undefined, -): Promise { - return env - ? await getOpenSocketPrs(env.repoInfo.owner, env.repoInfo.repo, { - author: env.gitUser, - }) - : [] -} diff --git a/src/commands/fix/git.mts b/src/commands/fix/git.mts index d1060c5d4..9e93633d1 100644 --- a/src/commands/fix/git.mts +++ b/src/commands/fix/git.mts @@ -73,6 +73,8 @@ export function createSocketBranchParser( } } +export const genericSocketBranchParser = createSocketBranchParser() + export async function getBaseGitBranch(cwd = process.cwd()): Promise { // Lazily access constants.ENV properties. const { GITHUB_BASE_REF, GITHUB_REF_NAME, GITHUB_REF_TYPE } = constants.ENV @@ -109,10 +111,10 @@ export function getSocketBranchFullNameComponent( ? PackageURL.fromString(`pkg:unknown/${pkgName}`) : pkgName, ) - const fmtMaybeNamespace = purlObj.namespace + const branchMaybeNamespace = purlObj.namespace ? `${formatBranchName(purlObj.namespace)}--` : '' - return `${fmtMaybeNamespace}${formatBranchName(purlObj.name)}` + return `${branchMaybeNamespace}${formatBranchName(purlObj.name)}` } export function getSocketBranchName( @@ -121,12 +123,12 @@ export function getSocketBranchName( workspace?: string | undefined, ): string { const purlObj = getPurlObject(purl) - const fmtType = getSocketBranchPurlTypeComponent(purlObj) - const fmtWorkspace = getSocketBranchWorkspaceComponent(workspace) - const fmtFullName = getSocketBranchFullNameComponent(purlObj) - const fmtVersion = getSocketBranchPackageVersionComponent(purlObj.version!) - const fmtNewVersion = formatBranchName(newVersion) - return `socket/${fmtType}/${fmtWorkspace}/${fmtFullName}_${fmtVersion}_${fmtNewVersion}` + const branchType = getSocketBranchPurlTypeComponent(purlObj) + const branchWorkspace = getSocketBranchWorkspaceComponent(workspace) + const branchFullName = getSocketBranchFullNameComponent(purlObj) + const branchVersion = getSocketBranchPackageVersionComponent(purlObj.version!) + const branchNewVersion = formatBranchName(newVersion) + return `socket/${branchType}/${branchWorkspace}/${branchFullName}_${branchVersion}_${branchNewVersion}` } export function getSocketBranchPackageVersionComponent( @@ -288,7 +290,7 @@ export async function gitRepoInfo( debugFn('error', 'git: unmatched git remote URL format') debugDir('inspect', { remoteUrl }) } catch (e) { - debugFn('error', 'caught: git remote get-url origin failed') + debugFn('error', 'caught: `git remote get-url origin` failed') debugDir('inspect', { error: e }) } return null diff --git a/src/commands/fix/npm-fix.mts b/src/commands/fix/npm-fix.mts index 232b29d2f..a3f54619f 100644 --- a/src/commands/fix/npm-fix.mts +++ b/src/commands/fix/npm-fix.mts @@ -1,20 +1,23 @@ import { debugDir, debugFn, isDebug } from '@socketsecurity/registry/lib/debug' import { agentFix } from './agent-fix.mts' -import { getCiEnv, getOpenPrsForEnvironment } from './fix-env-helpers.mts' import { getActualTree } from './get-actual-tree.mts' -import { getAlertsMapOptions } from './shared.mts' +import { getFixAlertsMapOptions } from './shared.mts' +import { Arborist } from '../../shadow/npm/arborist/index.mts' import { - Arborist, - SAFE_ARBORIST_REIFY_OPTIONS_OVERRIDES, -} from '../../shadow/npm/arborist/index.mts' -import { getAlertsMapFromArborist } from '../../shadow/npm/arborist-helpers.mts' + findPackageNode, + getAlertsMapFromArborist, + updateNode, +} from '../../shadow/npm/arborist-helpers.mts' import { runAgentInstall } from '../../utils/agent.mts' import { getAlertsMapFromPurls } from '../../utils/alerts-map.mts' import { getNpmConfig } from '../../utils/npm-config.mts' import type { FixConfig, InstallOptions } from './agent-fix.mts' -import type { NodeClass } from '../../shadow/npm/arborist/types.mts' +import type { + ArboristInstance, + NodeClass, +} from '../../shadow/npm/arborist/types.mts' import type { CResult } from '../../types.mts' import type { EnvDetails } from '../../utils/package-environment.mts' import type { PackageJson } from '@socketsecurity/registry/lib/packages' @@ -42,38 +45,28 @@ export async function npmFix( pkgEnvDetails: EnvDetails, fixConfig: FixConfig, ): Promise> { - const { limit, purls, spinner } = fixConfig + const { purls, spinner } = fixConfig spinner?.start() - const ciEnv = await getCiEnv() - const openPrs = ciEnv ? await getOpenPrsForEnvironment(ciEnv) : [] - + let arb: ArboristInstance let actualTree: NodeClass | undefined let alertsMap try { if (purls.length) { - alertsMap = await getAlertsMapFromPurls( - purls, - getAlertsMapOptions({ limit: Math.max(limit, openPrs.length) }), - ) + alertsMap = await getAlertsMapFromPurls(purls, getFixAlertsMapOptions()) } else { const flatConfig = await getNpmConfig({ npmVersion: pkgEnvDetails.agentVersion, }) - - const arb = new Arborist({ + arb = new Arborist({ path: pkgEnvDetails.pkgPath, ...flatConfig, - ...SAFE_ARBORIST_REIFY_OPTIONS_OVERRIDES, }) actualTree = await arb.reify() // Calling arb.reify() creates the arb.diff object, nulls-out arb.idealTree, // and populates arb.actualTree. - alertsMap = await getAlertsMapFromArborist( - arb, - getAlertsMapOptions({ limit: Math.max(limit, openPrs.length) }), - ) + alertsMap = await getAlertsMapFromArborist(arb, getFixAlertsMapOptions()) } } catch (e) { spinner?.stop() @@ -94,7 +87,7 @@ export async function npmFix( alertsMap, install, { - async beforeInstall(editablePkgJson) { + async beforeInstall(editablePkgJson, packument, oldVersion, newVersion) { revertData = { ...(editablePkgJson.content.dependencies && { dependencies: { ...editablePkgJson.content.dependencies }, @@ -108,6 +101,13 @@ export async function npmFix( peerDependencies: { ...editablePkgJson.content.peerDependencies }, }), } as PackageJson + + const idealTree = await arb.buildIdealTree() + const node = findPackageNode(idealTree, packument.name, oldVersion) + if (node) { + updateNode(node, newVersion, packument.versions[newVersion]!) + await arb.reify() + } }, async revertInstall(editablePkgJson) { if (revertData) { @@ -115,8 +115,6 @@ export async function npmFix( } }, }, - ciEnv, - openPrs, fixConfig, ) } diff --git a/src/commands/fix/pnpm-fix.mts b/src/commands/fix/pnpm-fix.mts index ee8be5895..3e60a2bee 100644 --- a/src/commands/fix/pnpm-fix.mts +++ b/src/commands/fix/pnpm-fix.mts @@ -4,9 +4,8 @@ import { debugDir, debugFn, isDebug } from '@socketsecurity/registry/lib/debug' import { hasKeys } from '@socketsecurity/registry/lib/objects' import { agentFix } from './agent-fix.mts' -import { getCiEnv, getOpenPrsForEnvironment } from './fix-env-helpers.mts' import { getActualTree } from './get-actual-tree.mts' -import { getAlertsMapOptions } from './shared.mts' +import { getFixAlertsMapOptions } from './shared.mts' import constants from '../../constants.mts' import { runAgentInstall } from '../../utils/agent.mts' import { @@ -61,7 +60,7 @@ export async function pnpmFix( pkgEnvDetails: EnvDetails, fixConfig: FixConfig, ): Promise> { - const { cwd, limit, purls, spinner } = fixConfig + const { cwd, purls, spinner } = fixConfig spinner?.start() @@ -102,20 +101,11 @@ export async function pnpmFix( } } - const ciEnv = await getCiEnv() - const openPrs = ciEnv ? await getOpenPrsForEnvironment(ciEnv) : [] - let alertsMap try { alertsMap = purls.length - ? await getAlertsMapFromPurls( - purls, - getAlertsMapOptions({ limit: Math.max(limit, openPrs.length) }), - ) - : await getAlertsMapFromPnpmLockfile( - lockfile, - getAlertsMapOptions({ limit: Math.max(limit, openPrs.length) }), - ) + ? await getAlertsMapFromPurls(purls, getFixAlertsMapOptions()) + : await getAlertsMapFromPnpmLockfile(lockfile, getFixAlertsMapOptions()) } catch (e) { spinner?.stop() debugFn('error', 'caught: PURL API') @@ -139,7 +129,7 @@ export async function pnpmFix( { async beforeInstall( editablePkgJson, - name, + packument, oldVersion, newVersion, vulnerableVersionRange, @@ -155,7 +145,7 @@ export async function pnpmFix( const oldPnpmSection = editablePkgJson.content[PNPM] as | StringKeyValueObject | undefined - const overrideKey = `${name}@${vulnerableVersionRange}` + const overrideKey = `${packument.name}@${vulnerableVersionRange}` revertOverrides = undefined revertOverridesSrc = extractOverridesFromPnpmLockSrc(lockSrc) @@ -228,8 +218,6 @@ export async function pnpmFix( } }, }, - ciEnv, - openPrs, fixConfig, ) } diff --git a/src/commands/fix/open-pr.mts b/src/commands/fix/pull-request.mts similarity index 82% rename from src/commands/fix/open-pr.mts rename to src/commands/fix/pull-request.mts index 7738a7d71..e4ba13ebc 100644 --- a/src/commands/fix/open-pr.mts +++ b/src/commands/fix/pull-request.mts @@ -17,7 +17,7 @@ import { isNonEmptyString } from '@socketsecurity/registry/lib/strings' import { createSocketBranchParser, - getSocketBranchPattern, + genericSocketBranchParser, getSocketPullRequestBody, getSocketPullRequestTitle, } from './git.mts' @@ -25,6 +25,7 @@ import constants from '../../constants.mts' import { safeStatsSync } from '../../utils/fs.mts' import { getPurlObject } from '../../utils/purl.mts' +import type { SocketBranchParseResult } from './git.mts' import type { SocketArtifact } from '../../utils/alert/artifact.mts' import type { components } from '@octokit/openapi-types' import type { OctokitResponse } from '@octokit/types' @@ -110,7 +111,7 @@ async function writeCache(key: string, data: JsonContent): Promise { export type Pr = components['schemas']['pull-request'] -export type MERGE_STATE_STATUS = +export type GQL_MERGE_STATE_STATUS = | 'BEHIND' | 'BLOCKED' | 'CLEAN' @@ -120,12 +121,16 @@ export type MERGE_STATE_STATUS = | 'UNKNOWN' | 'UNSTABLE' +export type GQL_PR_STATE = 'OPEN' | 'CLOSED' | 'MERGED' + export type PrMatch = { author: string baseRefName: string headRefName: string - mergeStateStatus: MERGE_STATE_STATUS + mergeStateStatus: GQL_MERGE_STATE_STATUS number: number + parsedBranch: SocketBranchParseResult + state: GQL_PR_STATE title: string } @@ -135,16 +140,12 @@ export type CleanupPrsOptions = { workspace?: string | undefined } -export async function cleanupOpenPrs( +export async function cleanupPrs( owner: string, repo: string, options?: CleanupPrsOptions | undefined, ): Promise { - const contextualMatches = await getOpenSocketPrsWithContext( - owner, - repo, - options, - ) + const contextualMatches = await getSocketPrsWithContext(owner, repo, options) if (!contextualMatches.length) { return [] @@ -272,21 +273,20 @@ export async function enablePrAutoMerge({ return { enabled: false } } -export type GetOpenSocketPrsOptions = { +export type SocketPrsOptions = { author?: string | undefined newVersion?: string | undefined purl?: string | undefined + states?: string[] | string | undefined workspace?: string | undefined } -export async function getOpenSocketPrs( +export async function getSocketPrs( owner: string, repo: string, - options?: GetOpenSocketPrsOptions | undefined, + options?: SocketPrsOptions | undefined, ): Promise { - return (await getOpenSocketPrsWithContext(owner, repo, options)).map( - d => d.match, - ) + return (await getSocketPrsWithContext(owner, repo, options)).map(d => d.match) } type ContextualPrMatch = { @@ -301,19 +301,26 @@ type ContextualPrMatch = { match: PrMatch } -async function getOpenSocketPrsWithContext( +async function getSocketPrsWithContext( owner: string, repo: string, - options_?: GetOpenSocketPrsOptions | undefined, + options?: SocketPrsOptions | undefined, ): Promise { - const options = { __proto__: null, ...options_ } as GetOpenSocketPrsOptions - const { author } = options + const { author, states: statesValue = 'all' } = { + __proto__: null, + ...options, + } as SocketPrsOptions const checkAuthor = isNonEmptyString(author) const octokit = getOctokit() const octokitGraphql = getOctokitGraphql() - const branchPattern = getSocketBranchPattern(options) - const contextualMatches: ContextualPrMatch[] = [] + const states = ( + typeof statesValue === 'string' + ? statesValue.toLowerCase() === 'all' + ? ['OPEN', 'CLOSED', 'MERGED'] + : [statesValue] + : statesValue + ).map(s => s.toUpperCase()) try { // Optimistically fetch only the first 50 open PRs using GraphQL to minimize // API quota usage. Fallback to REST if no matching PRs are found. @@ -321,9 +328,9 @@ async function getOpenSocketPrsWithContext( const gqlResp = await cacheFetch(gqlCacheKey, () => octokitGraphql( ` - query($owner: String!, $repo: String!) { + query($owner: String!, $repo: String!, $states: [PullRequestState!]) { repository(owner: $owner, name: $repo) { - pullRequests(first: 50, states: OPEN, orderBy: {field: CREATED_AT, direction: DESC}) { + pullRequests(first: 50, states: $states, orderBy: {field: CREATED_AT, direction: DESC}) { nodes { author { login @@ -332,13 +339,18 @@ async function getOpenSocketPrsWithContext( headRefName mergeStateStatus number + state title } } } } `, - { owner, repo }, + { + owner, + repo, + states, + }, ), ) @@ -348,8 +360,9 @@ async function getOpenSocketPrsWithContext( } baseRefName: string headRefName: string - mergeStateStatus: MERGE_STATE_STATUS + mergeStateStatus: GQL_MERGE_STATE_STATUS number: number + state: GQL_PR_STATE title: string } const nodes: GqlPrNode[] = @@ -358,8 +371,8 @@ async function getOpenSocketPrsWithContext( const node = nodes[i]! const login = node.author?.login const matchesAuthor = checkAuthor ? login === author : true - const matchesBranch = branchPattern.test(node.headRefName) - if (matchesAuthor && matchesBranch) { + const parsedBranch = genericSocketBranchParser(node.headRefName) + if (matchesAuthor && parsedBranch) { contextualMatches.push({ context: { apiType: 'graphql', @@ -372,6 +385,7 @@ async function getOpenSocketPrsWithContext( match: { ...node, author: login ?? '', + parsedBranch, }, }) } @@ -383,49 +397,59 @@ async function getOpenSocketPrsWithContext( } // Fallback to REST if GraphQL found no matching PRs. - let allOpenPrs: Pr[] | undefined - const cacheKey = `${repo}-open-prs` + let allPrs: Pr[] | undefined + const cacheKey = `${repo}-pull-requests` try { - allOpenPrs = await cacheFetch( + allPrs = await cacheFetch( cacheKey, async () => (await octokit.paginate(octokit.pulls.list, { owner, repo, - state: 'open', + state: 'all', per_page: 100, })) as Pr[], ) } catch {} - if (!allOpenPrs) { + if (!allPrs) { return contextualMatches } - for (let i = 0, { length } = allOpenPrs; i < length; i += 1) { - const pr = allOpenPrs[i]! + for (let i = 0, { length } = allPrs; i < length; i += 1) { + const pr = allPrs[i]! const login = pr.user?.login + const headRefName = pr.head.ref const matchesAuthor = checkAuthor ? login === author : true - const matchesBranch = branchPattern.test(pr.head.ref) - if (matchesAuthor && matchesBranch) { + const parsedBranch = genericSocketBranchParser(headRefName) + if (matchesAuthor && parsedBranch) { + // Upper cased mergeable_state is equivalent to mergeStateStatus. + // https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#get-a-pull-request + const mergeStateStatus = (pr.mergeable_state?.toUpperCase?.() ?? + 'UNKNOWN') as GQL_MERGE_STATE_STATUS + // The REST API does not have a distinct merged state for pull requests. + // Instead, a merged pull request is represented as a closed pull request + // with a non-null merged_at timestamp. + const state = ( + pr.merged_at ? 'MERGED' : pr.state.toUpperCase() + ) as GQL_PR_STATE contextualMatches.push({ context: { apiType: 'rest', cacheKey, - data: allOpenPrs, + data: allPrs, entry: pr, index: i, - parent: allOpenPrs, + parent: allPrs, }, match: { author: login ?? '', baseRefName: pr.base.ref, - headRefName: pr.head.ref, - // Upper cased mergeable_state is equivalent to mergeStateStatus. - // https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#get-a-pull-request - mergeStateStatus: (pr.mergeable_state?.toUpperCase?.() ?? - 'UNKNOWN') as MERGE_STATE_STATUS, + headRefName, + mergeStateStatus, number: pr.number, + parsedBranch, + state, title: pr.title, }, }) @@ -494,7 +518,7 @@ export async function prExistForBranch( owner, repo, head: `${owner}:${branch}`, - state: 'open', + state: 'all', per_page: 1, }) return prs.length > 0 diff --git a/src/commands/fix/shared.mts b/src/commands/fix/shared.mts index b6c70351e..e4f14cc3e 100644 --- a/src/commands/fix/shared.mts +++ b/src/commands/fix/shared.mts @@ -3,7 +3,7 @@ import type { Remap } from '@socketsecurity/registry/lib/objects' export const CMD_NAME = 'socket fix' -export function getAlertsMapOptions( +export function getFixAlertsMapOptions( options: GetAlertsMapFromPurlsOptions = {}, ) { return { diff --git a/src/shadow/npm/arborist-helpers.mts b/src/shadow/npm/arborist-helpers.mts index 3ce442102..28fd80edc 100644 --- a/src/shadow/npm/arborist-helpers.mts +++ b/src/shadow/npm/arborist-helpers.mts @@ -1,11 +1,13 @@ import semver from 'semver' +import { PackageURL } from '@socketregistry/packageurl-js' import { getManifestData } from '@socketsecurity/registry' import { debugFn } from '@socketsecurity/registry/lib/debug' import { hasOwn } from '@socketsecurity/registry/lib/objects' import { fetchPackagePackument } from '@socketsecurity/registry/lib/packages' import constants from '../../constants.mts' +import { Edge } from './arborist/index.mts' import { DiffAction } from './arborist/types.mts' import { getAlertsMapFromPurls } from '../../utils/alerts-map.mts' import { type AliasResult, npa } from '../../utils/npm-package-arg.mts' @@ -15,6 +17,7 @@ import { idToNpmPurl } from '../../utils/spec.mts' import type { ArboristInstance, Diff, + EdgeClass, LinkClass, NodeClass, } from './arborist/types.mts' @@ -182,7 +185,6 @@ export async function getAlertsMapFromArborist( __proto__: null, consolidate: false, include: undefined, - limit: Infinity, nothrow: false, ...options_, } as GetAlertsMapFromArboristOptions diff --git a/src/utils/alerts-map.mts b/src/utils/alerts-map.mts index b93e5ab95..4a8e67426 100644 --- a/src/utils/alerts-map.mts +++ b/src/utils/alerts-map.mts @@ -36,7 +36,6 @@ export async function getAlertsMapFromPnpmLockfile( export type GetAlertsMapFromPurlsOptions = { consolidate?: boolean | undefined include?: AlertIncludeFilter | undefined - limit?: number | undefined overrides?: { [key: string]: string } | undefined nothrow?: boolean | undefined spinner?: Spinner | undefined @@ -50,7 +49,6 @@ export async function getAlertsMapFromPurls( __proto__: null, consolidate: false, include: undefined, - limit: Infinity, nothrow: false, ...options_, } as GetAlertsMapFromPurlsOptions diff --git a/src/utils/socket-package-alert.mts b/src/utils/socket-package-alert.mts index fae2731df..744631504 100644 --- a/src/utils/socket-package-alert.mts +++ b/src/utils/socket-package-alert.mts @@ -338,27 +338,23 @@ export type CveInfoByPartialPurl = Map export type GetCveInfoByPackageOptions = { exclude?: CveExcludeFilter | undefined - limit?: number | undefined } export function getCveInfoFromAlertsMap( alertsMap: AlertsByPurl, - options_?: GetCveInfoByPackageOptions | undefined, + options?: GetCveInfoByPackageOptions | undefined, ): CveInfoByPartialPurl | null { - const options = { + const { exclude: exclude_ } = { __proto__: null, - exclude: undefined, - limit: Infinity, - ...options_, + ...options, } as GetCveInfoByPackageOptions - - options.exclude = { + const exclude = { __proto__: null, - ...options.exclude, + ...exclude_, } as CveExcludeFilter - let count = 0 let infoByPartialPurl: CveInfoByPartialPurl | null = null + // eslint-disable-next-line no-unused-labels alertsMapLoop: for (const { 0: purl, 1: sockPkgAlerts } of alertsMap) { const purlObj = getPurlObject(purl) const partialPurl = new PackageURL( @@ -371,7 +367,7 @@ export function getCveInfoFromAlertsMap( const alert = sockPkgAlert.raw if ( alert.fix?.type !== ALERT_FIX_TYPE.cve || - (options.exclude.upgradable && + (exclude.upgradable && getManifestData(sockPkgAlert.ecosystem as any, name)) ) { continue sockPkgAlertsLoop @@ -406,9 +402,6 @@ export function getCveInfoFromAlertsMap( .replace(/; +/g, ' || '), ).format(), }) - if (++count >= options.limit!) { - break alertsMapLoop - } continue sockPkgAlertsLoop } catch (e) { error = e