diff --git a/tfjs-node/src/io/file_system.ts b/tfjs-node/src/io/file_system.ts index 74b4cc582d..9990e0f248 100644 --- a/tfjs-node/src/io/file_system.ts +++ b/tfjs-node/src/io/file_system.ts @@ -17,7 +17,7 @@ import * as tf from '@tensorflow/tfjs'; import * as fs from 'fs'; -import {dirname, join, resolve} from 'path'; +import {dirname, join, resolve, sep} from 'path'; import {promisify} from 'util'; import {toArrayBuffer} from './io_utils'; @@ -25,6 +25,7 @@ const stat = promisify(fs.stat); const writeFile = promisify(fs.writeFile); const readFile = promisify(fs.readFile); const mkdir = promisify(fs.mkdir); +const realpath = promisify(fs.realpath); function doesNotExistHandler(name: string): (e: NodeJS.ErrnoException) => never { @@ -185,11 +186,26 @@ export class NodeFileSystem implements tf.io.IOHandler { weightsManifest: tf.io.WeightsManifestConfig, path: string): Promise<[tf.io.WeightsManifestEntry[], ArrayBuffer]> { const dirName = dirname(path); + // Resolve the model directory to its canonical path so that symlinks + // inside the directory cannot be used to escape the boundary check. + const resolvedBase = + await realpath(resolve(dirName)).catch(() => resolve(dirName)); const buffers: Buffer[] = []; const weightSpecs: tf.io.WeightsManifestEntry[] = []; for (const group of weightsManifest) { - for (const path of group.paths) { - const weightFilePath = join(dirName, path); + for (const weightPath of group.paths) { + const weightFilePath = join(dirName, weightPath); + // Verify that the resolved weight path stays within the model + // directory. This prevents path-traversal attacks where a malicious + // model.json uses "../" sequences to read arbitrary files. + const resolvedWeight = await realpath(weightFilePath) + .catch(() => resolve(dirName, weightPath)); + if (!resolvedWeight.startsWith(resolvedBase + sep)) { + throw new Error( + `Weight file path "${weightPath}" is outside the model ` + + `directory. Loading weights from paths outside the model ` + + `directory is not supported.`); + } const buffer = await readFile(weightFilePath) .catch(doesNotExistHandler('Weight file')); buffers.push(buffer);