diff --git a/index.d.ts b/index.d.ts index 77679869..0142bcd6 100644 --- a/index.d.ts +++ b/index.d.ts @@ -45,6 +45,7 @@ export interface Size { } export const MAX_U16_VALUE: number; export const MIN_U16_VALUE: number; +export declare function getSyntheticEofSequence(): Buffer; /** Resize the terminal. */ export declare function ptyResize(fd: number, size: Size): void; /** @@ -66,5 +67,10 @@ export declare class Pty { * once (it will error the second time). The caller is responsible for closing the file * descriptor. */ - takeFd(): c_int; + takeControllerFd(): c_int; + /** + * Closes the owned file descriptor for the PTY controller. The Nodejs side must call this + * when it is done with the file descriptor to avoid leaking FDs. + */ + dropUserFd(): void; } diff --git a/index.js b/index.js index 5b9f1d8b..64824e59 100644 --- a/index.js +++ b/index.js @@ -326,6 +326,7 @@ const { Operation, MAX_U16_VALUE, MIN_U16_VALUE, + getSyntheticEofSequence, ptyResize, setCloseOnExec, getCloseOnExec, @@ -335,6 +336,7 @@ module.exports.Pty = Pty; module.exports.Operation = Operation; module.exports.MAX_U16_VALUE = MAX_U16_VALUE; module.exports.MIN_U16_VALUE = MIN_U16_VALUE; +module.exports.getSyntheticEofSequence = getSyntheticEofSequence; module.exports.ptyResize = ptyResize; module.exports.setCloseOnExec = setCloseOnExec; module.exports.getCloseOnExec = getCloseOnExec; diff --git a/npm/darwin-arm64/package.json b/npm/darwin-arm64/package.json index bf9f0bf5..c72485d4 100644 --- a/npm/darwin-arm64/package.json +++ b/npm/darwin-arm64/package.json @@ -1,6 +1,6 @@ { "name": "@replit/ruspty-darwin-arm64", - "version": "3.5.3", + "version": "3.6.0", "os": [ "darwin" ], diff --git a/npm/darwin-x64/package.json b/npm/darwin-x64/package.json index 8dceb4ea..d0396f5d 100644 --- a/npm/darwin-x64/package.json +++ b/npm/darwin-x64/package.json @@ -1,6 +1,6 @@ { "name": "@replit/ruspty-darwin-x64", - "version": "3.5.3", + "version": "3.6.0", "os": [ "darwin" ], diff --git a/npm/linux-x64-gnu/package.json b/npm/linux-x64-gnu/package.json index a8baa81a..5a4bed3a 100644 --- a/npm/linux-x64-gnu/package.json +++ b/npm/linux-x64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@replit/ruspty-linux-x64-gnu", - "version": "3.5.3", + "version": "3.6.0", "os": [ "linux" ], diff --git a/package-lock.json b/package-lock.json index 2d1b14f4..f8c7190a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/ruspty", - "version": "3.5.3", + "version": "3.6.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@replit/ruspty", - "version": "3.5.3", + "version": "3.6.0", "license": "MIT", "devDependencies": { "@napi-rs/cli": "^2.18.4", diff --git a/package.json b/package.json index 1ad2399e..0fc4e87f 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@replit/ruspty", - "version": "3.5.3", + "version": "3.6.0", "main": "dist/wrapper.js", "types": "dist/wrapper.d.ts", "author": "Szymon Kaliski ", @@ -40,7 +40,7 @@ "build:wrapper": "tsup", "prepublishOnly": "napi prepublish -t npm", "test": "vitest run", - "test:ci": "vitest --reporter=verbose --reporter=github-actions run", + "test:ci": "vitest --reporter=verbose --reporter=github-actions --allowOnly run", "test:hang": "vitest run --reporter=hanging-process", "universal": "napi universal", "version": "napi version", diff --git a/src/lib.rs b/src/lib.rs index a75c33ed..8397c505 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,17 +7,14 @@ use std::os::fd::{FromRawFd, IntoRawFd, RawFd}; use std::os::unix::process::CommandExt; use std::process::{Command, Stdio}; use std::thread; -use std::time::Duration; -use backoff::backoff::Backoff; -use backoff::ExponentialBackoffBuilder; -use napi::bindgen_prelude::JsFunction; +use napi::bindgen_prelude::{Buffer, JsFunction}; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}; use napi::Status::GenericFailure; use napi::{self, Env}; use nix::errno::Errno; use nix::fcntl::{fcntl, FcntlArg, FdFlag, OFlag}; -use nix::libc::{self, c_int, ioctl, FIONREAD, TIOCOUTQ, TIOCSCTTY, TIOCSWINSZ}; +use nix::libc::{self, c_int, TIOCSCTTY, TIOCSWINSZ}; use nix::pty::{openpty, Winsize}; use nix::sys::termios::{self, SetArg}; @@ -31,6 +28,7 @@ mod sandbox; #[allow(dead_code)] struct Pty { controller_fd: Option, + user_fd: Option, /// The pid of the forked process. pub pid: u32, } @@ -89,61 +87,15 @@ pub const MAX_U16_VALUE: u16 = u16::MAX; #[napi] pub const MIN_U16_VALUE: u16 = u16::MIN; -fn cast_to_napi_error(err: Errno) -> napi::Error { - napi::Error::new(GenericFailure, err) -} - -// if the child process exits before the controller fd is fully read or the user fd is fully -// flushed, we might accidentally end in a case where onExit is called but js hasn't had -// the chance to fully read the controller fd -// let's wait until the controller fd is fully read before we call onExit -fn poll_pty_fds_until_read(controller_fd: RawFd, user_fd: RawFd) { - let mut backoff = ExponentialBackoffBuilder::default() - .with_initial_interval(Duration::from_millis(1)) - .with_max_interval(Duration::from_millis(100)) - .with_max_elapsed_time(Some(Duration::from_secs(1))) - .build(); - - loop { - // check both input and output queues for both FDs - let mut controller_inq: i32 = 0; - let mut controller_outq: i32 = 0; - let mut user_inq: i32 = 0; - let mut user_outq: i32 = 0; - - // safe because we're passing valid file descriptors and properly sized integers - unsafe { - // check bytes waiting to be read (FIONREAD, equivalent to TIOCINQ on Linux) - if ioctl(controller_fd, FIONREAD, &mut controller_inq) == -1 - || ioctl(user_fd, FIONREAD, &mut user_inq) == -1 - { - // break if we can't read - break; - } - - // check bytes waiting to be written (TIOCOUTQ) - if ioctl(controller_fd, TIOCOUTQ, &mut controller_outq) == -1 - || ioctl(user_fd, TIOCOUTQ, &mut user_outq) == -1 - { - // break if we can't read - break; - } - } +const SYNTHETIC_EOF: &[u8] = b"\x1B]7878\x1B\\"; - // if all queues are empty, we're done - if controller_inq == 0 && controller_outq == 0 && user_inq == 0 && user_outq == 0 { - break; - } +#[napi] +pub fn get_synthetic_eof_sequence() -> Buffer { + SYNTHETIC_EOF.into() +} - // apply backoff strategy - if let Some(d) = backoff.next_backoff() { - thread::sleep(d); - continue; - } else { - // we have exhausted our attempts - break; - } - } +fn cast_to_napi_error(err: Errno) -> napi::Error { + napi::Error::new(GenericFailure, err) } #[napi] @@ -347,9 +299,10 @@ impl Pty { thread::spawn(move || { let wait_result = child.wait(); - // try to wait for the controller fd to be fully read - poll_pty_fds_until_read(raw_controller_fd, raw_user_fd); - drop(user_fd); + // by this point, child has closed its copy of the user_fd + // lets inject our synthetic EOF OSC into the user_fd + // its ok to ignore the result here as we have a timeout on the nodejs side to handle if this write fails + let _ = write_syn_eof_to_fd(raw_user_fd); match wait_result { Ok(status) => { @@ -379,6 +332,7 @@ impl Pty { Ok(Pty { controller_fd: Some(controller_fd), + user_fd: Some(user_fd), pid, }) } @@ -388,7 +342,7 @@ impl Pty { /// descriptor. #[napi] #[allow(dead_code)] - pub fn take_fd(&mut self) -> Result { + pub fn take_controller_fd(&mut self) -> Result { if let Some(fd) = self.controller_fd.take() { Ok(fd.into_raw_fd()) } else { @@ -398,6 +352,15 @@ impl Pty { )) } } + + /// Closes the owned file descriptor for the PTY controller. The Nodejs side must call this + /// when it is done with the file descriptor to avoid leaking FDs. + #[napi] + #[allow(dead_code)] + pub fn drop_user_fd(&mut self) -> Result<(), napi::Error> { + self.user_fd.take(); + Ok(()) + } } /// Resize the terminal. @@ -492,3 +455,35 @@ fn set_nonblocking(fd: i32) -> Result<(), napi::Error> { } Ok(()) } + +fn write_syn_eof_to_fd(fd: libc::c_int) -> std::io::Result<()> { + let mut remaining = SYNTHETIC_EOF; + while !remaining.is_empty() { + match unsafe { + libc::write( + fd, + remaining.as_ptr() as *const libc::c_void, + remaining.len(), + ) + } { + -1 => { + let err = std::io::Error::last_os_error(); + if err.kind() == std::io::ErrorKind::Interrupted { + continue; + } + + return Err(err); + } + 0 => { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "write returned 0", + )); + } + n => { + remaining = &remaining[n as usize..]; + } + } + } + Ok(()) +} diff --git a/syntheticEof.ts b/syntheticEof.ts new file mode 100644 index 00000000..581b4127 --- /dev/null +++ b/syntheticEof.ts @@ -0,0 +1,85 @@ +import { Transform } from 'node:stream'; +import { getSyntheticEofSequence } from './index.js'; + +// keep in sync with lib.rs::SYNTHETIC_EOF +export const SYNTHETIC_EOF = getSyntheticEofSequence(); +export const EOF_EVENT = 'synthetic-eof'; + +// get the longest suffix of buffer that is a prefix of SYNTHETIC_EOF +function getBufferEndPrefixLength(buffer: Buffer) { + const maxLen = Math.min(buffer.length, SYNTHETIC_EOF.length); + for (let len = maxLen; len > 0; len--) { + let match = true; + for (let i = 0; i < len; i++) { + if (buffer[buffer.length - len + i] !== SYNTHETIC_EOF[i]) { + match = false; + break; + } + } + + if (match) { + return len; + } + } + + return 0; +} + +export class SyntheticEOFDetector extends Transform { + buffer: Buffer; + + constructor(options = {}) { + super(options); + this.buffer = Buffer.alloc(0); + } + + _transform(chunk: Buffer, _encoding: string, callback: () => void) { + const searchData = Buffer.concat([this.buffer, chunk]); + const eofIndex = searchData.indexOf(SYNTHETIC_EOF); + + if (eofIndex !== -1) { + // found EOF - emit everything before it + if (eofIndex > 0) { + this.push(searchData.subarray(0, eofIndex)); + } + + this.emit(EOF_EVENT); + + // emit everything after EOF (if any) and clear buffer + const afterEOF = searchData.subarray(eofIndex + SYNTHETIC_EOF.length); + if (afterEOF.length > 0) { + this.push(afterEOF); + } + + this.buffer = Buffer.alloc(0); + } else { + // no EOF - buffer potential partial match at end + + // get the longest suffix of buffer that is a prefix of SYNTHETIC_EOF + // and emit everything before it + // this is done for the case which the eof happened to be split across multiple chunks + const commonPrefixLen = getBufferEndPrefixLength(searchData); + + if (commonPrefixLen > 0) { + const emitSize = searchData.length - commonPrefixLen; + if (emitSize > 0) { + this.push(searchData.subarray(0, emitSize)); + } + this.buffer = searchData.subarray(emitSize); + } else { + this.push(searchData); + this.buffer = Buffer.alloc(0); + } + } + + callback(); + } + + _flush(callback: () => void) { + if (this.buffer.length > 0) { + this.push(this.buffer); + } + + callback(); + } +} diff --git a/tests/index.test.ts b/tests/index.test.ts index ae87f5bb..7d55bb10 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -77,9 +77,9 @@ describe('PTY', { repeats: 500 }, () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(buffer.trim()).toBe(message); - expect(getOpenFds()).toStrictEqual(oldFds); expect(pty.write.writable).toBe(false); expect(pty.read.readable).toBe(false); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('captures an exit code', async () => { @@ -132,7 +132,7 @@ describe('PTY', { repeats: 500 }, () => { const expectedResult = 'hello cat\r\nhello cat\r\n'; expect(result.trim()).toStrictEqual(expectedResult.trim()); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('can be started in non-interactive fashion', async () => { @@ -157,7 +157,7 @@ describe('PTY', { repeats: 500 }, () => { let result = buffer.toString(); const expectedResult = '\r\n'; expect(result.trim()).toStrictEqual(expectedResult.trim()); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('can be resized', async () => { @@ -211,7 +211,7 @@ describe('PTY', { repeats: 500 }, () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(state).toBe('done'); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('respects working directory', async () => { @@ -234,7 +234,7 @@ describe('PTY', { repeats: 500 }, () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(buffer.trim()).toBe(cwd); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('respects env', async () => { @@ -260,7 +260,7 @@ describe('PTY', { repeats: 500 }, () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(buffer.trim()).toBe(message); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test("resize after exit shouldn't throw", async () => { @@ -330,7 +330,7 @@ describe('PTY', { repeats: 500 }, () => { ).toBe(i); } - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('doesnt miss large output from fast commands', async () => { @@ -354,6 +354,29 @@ describe('PTY', { repeats: 500 }, () => { expect(buffer.toString().length).toBe(payload.length); }); + test('doesnt miss lots of lines from bash', async () => { + const payload = Array.from({ length: 5000 }, (_, i) => i).join('\n'); + let buffer = Buffer.from(''); + const onExit = vi.fn(); + + const pty = new Pty({ + command: 'bash', + args: ['-c', `echo -n "${payload}"`], + onExit, + }); + + const readStream = pty.read; + readStream.on('data', (data) => { + buffer = Buffer.concat([buffer, data]); + }); + + await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); + expect(onExit).toHaveBeenCalledWith(null, 0); + expect(buffer.toString().trim().replace(/\r/g, '').length).toBe( + payload.length, + ); + }); + testSkipOnDarwin('does not leak files', async () => { const oldFds = getOpenFds(); const promises = []; @@ -394,7 +417,7 @@ describe('PTY', { repeats: 500 }, () => { } await Promise.all(promises); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('can run concurrent shells', async () => { @@ -451,7 +474,7 @@ describe('PTY', { repeats: 500 }, () => { expect(result).toStrictEqual(expectedResult); } - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test("doesn't break when executing non-existing binary", async () => { @@ -464,7 +487,7 @@ describe('PTY', { repeats: 500 }, () => { }); }).rejects.toThrow('No such file or directory'); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('cannot be written to after closing', async () => { @@ -492,7 +515,7 @@ describe('PTY', { repeats: 500 }, () => { } }); await vi.waitFor(() => receivedError); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); test('cannot resize when out of range', async () => { @@ -531,7 +554,7 @@ describe('PTY', { repeats: 500 }, () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, -1); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); }); @@ -722,7 +745,7 @@ describe('cgroup opts', async () => { // Verify that the process was placed in the correct cgroup by // checking its output contains our unique slice name expect(buffer).toContain(cgroupState.sliceName); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); testOnlyOnDarwin('cgroup is not supported on darwin', async () => { @@ -799,7 +822,7 @@ describe('sandbox opts', { repeats: 10 }, async () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(buffer).toContain('hello'); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); testSkipOnDarwin('basic protection against git-yeetage', async () => { @@ -844,7 +867,7 @@ describe('sandbox opts', { repeats: 10 }, async () => { expect(buffer.trimEnd()).toBe( `Tried to delete a forbidden path: ${gitPath}`, ); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); testSkipOnDarwin('can exclude prefixes', async () => { @@ -890,7 +913,7 @@ describe('sandbox opts', { repeats: 10 }, async () => { await vi.waitFor(() => expect(onExit).toHaveBeenCalledTimes(1)); expect(onExit).toHaveBeenCalledWith(null, 0); expect(buffer.trimEnd()).toBe(''); - expect(getOpenFds()).toStrictEqual(oldFds); + await vi.waitFor(() => expect(getOpenFds()).toStrictEqual(oldFds)); }); }); diff --git a/tests/syntheticEOF.test.ts b/tests/syntheticEOF.test.ts new file mode 100644 index 00000000..df5837c1 --- /dev/null +++ b/tests/syntheticEOF.test.ts @@ -0,0 +1,137 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { + SyntheticEOFDetector, + SYNTHETIC_EOF, + EOF_EVENT, +} from '../syntheticEof'; + +describe('sequence', () => { + it('should have correct EOF sequence', () => { + expect(SYNTHETIC_EOF).toEqual( + Buffer.from([0x1b, 0x5d, 0x37, 0x38, 0x37, 0x38, 0x1b, 0x5c]), + ); + expect(SYNTHETIC_EOF.length).toBe(8); + }); +}); + +describe('SyntheticEOFDetector', () => { + let detector: SyntheticEOFDetector; + let onData: (data: Buffer) => void; + let onEOF: () => void; + let output: Buffer; + + beforeEach(() => { + detector = new SyntheticEOFDetector(); + output = Buffer.alloc(0); + onData = vi.fn((data: Buffer) => (output = Buffer.concat([output, data]))); + onEOF = vi.fn(); + + detector.on('data', onData); + detector.on(EOF_EVENT, onEOF); + }); + + it('should handle EOF at the end of stream', async () => { + detector.write('Before EOF'); + detector.write(SYNTHETIC_EOF); + detector.end(); + + expect(output.toString()).toBe('Before EOF'); + expect(onEOF).toHaveBeenCalledTimes(1); + }); + + it('should handle EOF split across chunks', async () => { + detector.write('Data1'); + detector.write('\x1B]78'); // Partial EOF + detector.write('78\x1B\\'); // Complete EOF + detector.write('Data2'); + detector.end(); + + expect(output.toString()).toBe('Data1Data2'); + expect(onEOF).toHaveBeenCalledTimes(1); + }); + + it('should pass through data when no EOF is present', async () => { + detector.write('Just normal data'); + detector.write(' with no EOF'); + detector.end(); + + expect(output.toString()).toBe('Just normal data with no EOF'); + expect(onEOF).not.toHaveBeenCalled(); + }); + + it('should not trigger on partial EOF at end', async () => { + detector.write('Data'); + detector.write('\x1B]78'); // Incomplete EOF + detector.end(); + + expect(output.toString()).toBe('Data\x1B]78'); + expect(onEOF).not.toHaveBeenCalled(); + }); + + it('should handle EOF split after escape', async () => { + detector.write('\x1B'); + detector.write(']7878\x1B\\'); + detector.write('data1'); + detector.end(); + + expect(output.toString()).toBe('data1'); + expect(onEOF).toHaveBeenCalledTimes(1); + }); + + it('should handle EOF split in the middle', async () => { + detector.write('\x1B]78'); + detector.write('78\x1B\\'); + detector.write('data2'); + detector.end(); + + expect(output.toString()).toBe('data2'); + expect(onEOF).toHaveBeenCalledTimes(1); + }); + + it('should not hold up data that isnt a prefix of EOF', async () => { + detector.write('Data that is definitely not an EOF prefix'); + + expect(output.toString()).toBe('Data that is definitely not an EOF prefix'); + expect(onEOF).not.toHaveBeenCalled(); + + detector.end(); + expect(onEOF).not.toHaveBeenCalled(); + }); + + it('should emit events in correct order', async () => { + const detector = new SyntheticEOFDetector(); + const events: Array< + | { + type: 'eof'; + } + | { + type: 'data'; + data: string; + } + > = []; + + detector.on('data', (chunk) => { + events.push({ type: 'data', data: chunk.toString() }); + }); + detector.on(EOF_EVENT, () => { + events.push({ type: 'eof' }); + }); + + const finished = new Promise((resolve) => { + detector.on('end', resolve); + }); + + detector.write('before'); + detector.write(SYNTHETIC_EOF); + detector.write('after'); + detector.end(); + + await finished; + + expect(events).toEqual([ + { type: 'data', data: 'before' }, + { type: 'eof' }, + { type: 'data', data: 'after' }, + ]); + }); +}); diff --git a/vitest.config.ts b/vitest.config.ts index 9a4c46f9..d0ba3102 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -1,5 +1,7 @@ export default { test: { exclude: ['node_modules', 'dist', '.direnv'], + fileParallelism: false, + pool: 'forks', }, }; diff --git a/wrapper.ts b/wrapper.ts index 2ce68c50..8598e125 100644 --- a/wrapper.ts +++ b/wrapper.ts @@ -1,4 +1,4 @@ -import type { Readable, Writable } from 'node:stream'; +import { type Readable, type Writable } from 'node:stream'; import { ReadStream } from 'node:tty'; import { Pty as RawPty, @@ -15,6 +15,7 @@ import { type SandboxRule, type SandboxOptions, } from './index.js'; +import { EOF_EVENT, SyntheticEOFDetector } from './syntheticEof.js'; export { Operation, type SandboxRule, type SandboxOptions, type PtyOptions }; @@ -53,17 +54,13 @@ export class Pty { #fd: number; #handledClose: boolean = false; - #fdClosed: boolean = false; + #socketClosed: boolean = false; + #userFdDropped: boolean = false; + #fdDropTimeout: ReturnType | null = null; #socket: ReadStream; - - get read(): Readable { - return this.#socket; - } - - get write(): Writable { - return this.#socket; - } + read: Readable; + write: Writable; constructor(options: PtyOptions) { const realExit = options.onExit; @@ -77,27 +74,33 @@ export class Pty { let readFinished = new Promise((resolve) => { markReadFinished = resolve; }); - const mockedExit = (error: NodeJS.ErrnoException | null, code: number) => { - markExited({ error, code }); - }; // when pty exits, we should wait until the fd actually ends (end OR error) // before closing the pty // we use a mocked exit function to capture the exit result // and then call the real exit function after the fd is fully read - this.#pty = new RawPty({ ...options, onExit: mockedExit }); - // Transfer ownership of the FD to us. - this.#fd = this.#pty.takeFd(); - + this.#pty = new RawPty({ + ...options, + onExit: (error, code) => { + // give nodejs a max of 1s to read the fd before + // dropping the fd to avoid leaking it + this.#fdDropTimeout = setTimeout(() => { + this.dropUserFd(); + }, 1000); + + markExited({ error, code }); + }, + }); + this.#fd = this.#pty.takeControllerFd(); this.#socket = new ReadStream(this.#fd); // catch end events const handleClose = async () => { - if (this.#fdClosed) { + if (this.#socketClosed) { return; } - this.#fdClosed = true; + this.#socketClosed = true; // must wait for fd close and exit result before calling real exit await readFinished; @@ -105,9 +108,6 @@ export class Pty { realExit(result.error, result.code); }; - this.read.once('end', markReadFinished); - this.read.once('close', handleClose); - // PTYs signal their done-ness with an EIO error. we therefore need to filter them out (as well as // cleaning up other spurious errors) so that the user doesn't need to handle them and be in // blissful peace. @@ -123,10 +123,10 @@ export class Pty { // EIO only happens when the child dies. It is therefore our only true signal that there // is nothing left to read and we can start tearing things down. If we hadn't received an // error so far, we are considered to be in good standing. - this.read.off('error', handleError); + this.#socket.off('error', handleError); // emit 'end' to signal no more data // this will trigger our 'end' handler which marks readFinished - this.read.emit('end'); + this.#socket.emit('end'); return; } } @@ -135,7 +135,37 @@ export class Pty { throw err; }; - this.read.on('error', handleError); + // we need this synthetic eof detector as the pty stream has no way + // of distinguishing the program exiting vs the data being fully read + // this is injected on the rust side after the .wait on the child process + // returns + // more details: https://github.com/replit/ruspty/pull/93 + this.read = this.#socket.pipe(new SyntheticEOFDetector()); + this.write = this.#socket; + + this.#socket.on('error', handleError); + this.#socket.once('end', markReadFinished); + this.#socket.once('close', handleClose); + this.read.once(EOF_EVENT, async () => { + // even if the program accidentally emits our synthetic eof + // we dont yank the user fd away from them until the program actually exits + // (and drops its copy of the user fd) + await exitResult; + this.dropUserFd(); + }); + } + + private dropUserFd() { + if (this.#userFdDropped) { + return; + } + + if (this.#fdDropTimeout) { + clearTimeout(this.#fdDropTimeout); + } + + this.#userFdDropped = true; + this.#pty.dropUserFd(); } close() { @@ -144,10 +174,11 @@ export class Pty { // end instead of destroy so that the user can read the last bits of data // and allow graceful close event to mark the fd as ended this.#socket.end(); + this.dropUserFd(); } resize(size: Size) { - if (this.#handledClose || this.#fdClosed) { + if (this.#handledClose || this.#socketClosed) { return; }