From 3c074674065e542f6ad233038784bb400e683554 Mon Sep 17 00:00:00 2001 From: Victor Carreras <34163765+vicajilau@users.noreply.github.com> Date: Thu, 16 Apr 2026 10:34:47 +0200 Subject: [PATCH] Refactor KEX handling to async and offload X25519 crypto --- lib/src/kex/kex_x25519.dart | 43 +++++++++++-- lib/src/ssh_transport.dart | 91 ++++++++++++++++----------- lib/src/utils/compute.dart | 7 +++ lib/src/utils/compute_io.dart | 63 +++++++++++++++++++ lib/src/utils/compute_stub.dart | 6 ++ test/src/ssh_transport_aead_test.dart | 38 ++++++----- 6 files changed, 192 insertions(+), 56 deletions(-) create mode 100644 lib/src/utils/compute.dart create mode 100644 lib/src/utils/compute_io.dart create mode 100644 lib/src/utils/compute_stub.dart diff --git a/lib/src/kex/kex_x25519.dart b/lib/src/kex/kex_x25519.dart index 77a8ff6..18e32f4 100644 --- a/lib/src/kex/kex_x25519.dart +++ b/lib/src/kex/kex_x25519.dart @@ -1,21 +1,36 @@ import 'dart:typed_data'; import 'package:dartssh2/src/ssh_kex.dart'; +import 'package:dartssh2/src/utils/compute.dart'; import 'package:dartssh2/src/utils/bigint.dart'; import 'package:dartssh2/src/utils/list.dart'; import 'package:pinenacl/tweetnacl.dart'; class SSHKexX25519 implements SSHKexECDH { /// Randomly generated private key. - late final Uint8List privateKey; + final Uint8List privateKey; /// Public key computed from the private key. @override - late final Uint8List publicKey; + final Uint8List publicKey; + + factory SSHKexX25519() { + final privateKey = randomBytes(32); + final publicKey = _ScalarMult.scalseMultBase(privateKey); + return SSHKexX25519._( + privateKey: privateKey, + publicKey: publicKey, + ); + } + + SSHKexX25519._({required this.privateKey, required this.publicKey}); - SSHKexX25519() { - privateKey = randomBytes(32); - publicKey = _ScalarMult.scalseMultBase(privateKey); + static Future createAsync() async { + final keyPair = await sshCompute(_computeX25519KeyPair, null); + return SSHKexX25519._( + privateKey: keyPair[0], + publicKey: keyPair[1], + ); } @override @@ -23,6 +38,24 @@ class SSHKexX25519 implements SSHKexECDH { final secret = _ScalarMult.scalseMult(privateKey, remotePublicKey); return decodeBigIntWithSign(1, secret); } + + Future computeSecretAsync(Uint8List remotePublicKey) async { + final secret = await sshCompute( + _computeX25519Secret, + [privateKey, remotePublicKey], + ); + return decodeBigIntWithSign(1, secret); + } +} + +List _computeX25519KeyPair(void _) { + final privateKey = randomBytes(32); + final publicKey = _ScalarMult.scalseMultBase(privateKey); + return [privateKey, publicKey]; +} + +Uint8List _computeX25519Secret(List data) { + return _ScalarMult.scalseMult(data[0], data[1]); } /// Scalar multiplication, Implements curve25519. diff --git a/lib/src/ssh_transport.dart b/lib/src/ssh_transport.dart index d66a245..38e28cf 100644 --- a/lib/src/ssh_transport.dart +++ b/lib/src/ssh_transport.dart @@ -114,6 +114,9 @@ class SSHTransport { /// transport is closed. StreamSubscription? _socketSubscription; + /// Guards asynchronous packet processing to preserve message order. + var _isProcessingData = false; + /// Identification string sent by us without trailing \r\n. For example, /// "SSH-2.0-DartSSH_2.0". String get _localVersion => 'SSH-2.0-$version'; @@ -408,13 +411,7 @@ class SSHTransport { void _onSocketData(Uint8List data) { _buffer.add(data); - try { - _processData(); - } on SSHError catch (e, stackTrace) { - closeWithError(e, stackTrace); - } catch (e) { - rethrow; - } + _scheduleProcessData(); } void _onSocketError(Object error, StackTrace stackTrace) { @@ -427,11 +424,33 @@ class SSHTransport { close(); } - void _processData() { + void _scheduleProcessData() { + if (_isProcessingData || isClosed) { + return; + } + + _isProcessingData = true; + + _processDataAsync().catchError((error, stackTrace) { + if (error is SSHError) { + closeWithError(error, stackTrace); + } else { + closeWithError(SSHInternalError(error), stackTrace); + } + }).whenComplete(() { + _isProcessingData = false; + if (_buffer.isNotEmpty && !isClosed) { + _scheduleProcessData(); + } + }); + } + + Future _processDataAsync() async { if (_remoteVersion == null) { _processVersionExchange(); - } else { - _processPackets(); + } + if (_remoteVersion != null) { + await _processPackets(); } } @@ -473,12 +492,12 @@ class SSHTransport { _sendKexInit(); } - // There maybe more data in the buffer, so process it. - _processPackets(); + // There maybe more data in the buffer, so it will be consumed by the + // asynchronous packet processing queue. } /// Process one or more SSH packets queued in [_buffer]. - void _processPackets() { + Future _processPackets() async { printDebug?.call('SSHTransport._processPackets'); while (_buffer.isNotEmpty && !isClosed) { @@ -491,7 +510,7 @@ class SSHTransport { // throw SSHPacketError('Packet too long: ${payload.length}'); // } - _handleMessage(payload); + await _handleMessage(payload); _remotePacketSN.increase(); } @@ -980,7 +999,7 @@ class SSHTransport { sendPacket(message.encode()); } - void _handleMessage(Uint8List message) { + Future _handleMessage(Uint8List message) async { final messageId = SSHMessage.readMessageId(message); switch (messageId) { case SSH_Message_KexInit.messageId: @@ -995,7 +1014,7 @@ class SSHTransport { } } - void _handleMessageKexInit(Uint8List payload) { + Future _handleMessageKexInit(Uint8List payload) async { printDebug?.call('SSHTransport._handleMessageKexInit'); // If this message initiates a new key-exchange round from the remote @@ -1073,7 +1092,7 @@ class SSHTransport { switch (_kexType) { case SSHKexType.x25519: - _kex = SSHKexX25519(); + _kex = await SSHKexX25519.createAsync(); break; case SSHKexType.nistp256: _kex = SSHKexNist.p256(); @@ -1107,7 +1126,7 @@ class SSHTransport { /// When client receives [SSH_Message_KexECDH_Reply], it should verify the /// server's signature with the server's public key. Then send NEW_KEYS /// message back to the server. - void _handleMessageKexReply(Uint8List payload) { + Future _handleMessageKexReply(Uint8List payload) async { printDebug?.call('SSHTransport._handleMessageKexReply'); if (isServer) throw SSHStateError('Unexpected KEX_REPLY'); @@ -1149,7 +1168,11 @@ class SSHTransport { hostSignature = message.signature; serverKexKey = message.ecdhPublicKey; clientKexKey = kex.publicKey; - sharedSecret = kex.computeSecret(message.ecdhPublicKey); + if (kex is SSHKexX25519) { + sharedSecret = await kex.computeSecretAsync(message.ecdhPublicKey); + } else { + sharedSecret = kex.computeSecret(message.ecdhPublicKey); + } } else { throw UnimplementedError('$kex'); } @@ -1189,27 +1212,21 @@ class SSHTransport { } final userVerified = onVerifyHostKey != null - ? onVerifyHostKey!(_hostkeyType!.name, fingerprint) + ? await Future.value(onVerifyHostKey!(_hostkeyType!.name, fingerprint)) : true; - Future.value(userVerified).then( - (verified) { - if (!verified) { - closeWithError(SSHHostkeyError('Hostkey verification failed')); - } else { - _hostkeyVerified = true; - _sendNewKeys(); - _applyLocalKeys(); - onReady?.call(); - } - }, - onError: (error) { - closeWithError(error); - }, - ); + if (!userVerified) { + closeWithError(SSHHostkeyError('Hostkey verification failed')); + return; + } + + _hostkeyVerified = true; + _sendNewKeys(); + _applyLocalKeys(); + onReady?.call(); } - void _handleMessageKexGexReply(Uint8List payload) { + Future _handleMessageKexGexReply(Uint8List payload) async { printDebug?.call('SSHTransport._handleMessageKexGexReply'); if (isServer) throw SSHStateError('Unexpected KEX_GEX_REPLY'); @@ -1220,7 +1237,7 @@ class SSHTransport { _sendKexDHGexInit(); } - void _handleMessageNewKeys(Uint8List message) { + Future _handleMessageNewKeys(Uint8List message) async { printDebug?.call('SSHTransport._handleMessageNewKeys'); printTrace?.call('<- $socket: SSH_Message_NewKeys'); diff --git a/lib/src/utils/compute.dart b/lib/src/utils/compute.dart new file mode 100644 index 0000000..9bc081c --- /dev/null +++ b/lib/src/utils/compute.dart @@ -0,0 +1,7 @@ +import 'compute_stub.dart' if (dart.library.isolate) 'compute_io.dart'; + +typedef SSHComputeCallback = R Function(M message); + +Future sshCompute(SSHComputeCallback callback, M message) { + return sshComputeImpl(callback, message); +} diff --git a/lib/src/utils/compute_io.dart b/lib/src/utils/compute_io.dart new file mode 100644 index 0000000..f0c4154 --- /dev/null +++ b/lib/src/utils/compute_io.dart @@ -0,0 +1,63 @@ +import 'dart:async'; +import 'dart:isolate'; + +class _ComputeConfiguration { + final R Function(M message) callback; + final M message; + final SendPort resultPort; + + const _ComputeConfiguration({ + required this.callback, + required this.message, + required this.resultPort, + }); +} + +class _ComputeError { + final String error; + final String stackTrace; + + const _ComputeError(this.error, this.stackTrace); +} + +void _spawn(_ComputeConfiguration configuration) { + try { + final result = configuration.callback(configuration.message); + Isolate.exit(configuration.resultPort, result); + } catch (error, stackTrace) { + Isolate.exit( + configuration.resultPort, + _ComputeError(error.toString(), stackTrace.toString()), + ); + } +} + +Future sshComputeImpl( + R Function(M message) callback, + M message, +) async { + final resultPort = RawReceivePort(); + final completer = Completer(); + + resultPort.handler = (response) { + resultPort.close(); + if (response is _ComputeError) { + completer.completeError( + RemoteError(response.error, response.stackTrace), + ); + return; + } + completer.complete(response as R); + }; + + await Isolate.spawn<_ComputeConfiguration>( + _spawn, + _ComputeConfiguration( + callback: callback, + message: message, + resultPort: resultPort.sendPort, + ), + ); + + return completer.future; +} diff --git a/lib/src/utils/compute_stub.dart b/lib/src/utils/compute_stub.dart new file mode 100644 index 0000000..5fdd1a2 --- /dev/null +++ b/lib/src/utils/compute_stub.dart @@ -0,0 +1,6 @@ +Future sshComputeImpl( + R Function(M message) callback, + M message, +) { + return Future.sync(() => callback(message)); +} diff --git a/test/src/ssh_transport_aead_test.dart b/test/src/ssh_transport_aead_test.dart index 6718738..84c39e4 100644 --- a/test/src/ssh_transport_aead_test.dart +++ b/test/src/ssh_transport_aead_test.dart @@ -283,7 +283,7 @@ void main() { transport.close(); }); - test('kexinit allows missing MAC when AEAD cipher is selected', () { + test('kexinit allows missing MAC when AEAD cipher is selected', () async { final socket = _CaptureSSHSocket(); final transport = SSHTransport( socket, @@ -308,11 +308,9 @@ void main() { firstKexPacketFollows: false, ).encode(); - expect( - () => reflect(transport) - .invoke(privateSymbol('_handleMessageKexInit'), [payload]), - returnsNormally, - ); + final result = reflect(transport) + .invoke(privateSymbol('_handleMessageKexInit'), [payload]).reflectee; + await expectLater(result, completes); transport.close(); }); @@ -402,7 +400,8 @@ void main() { transport.close(); }); - test('kexinit requires client MAC when non-AEAD cipher is selected', () { + test('kexinit requires client MAC when non-AEAD cipher is selected', + () async { final socket = _CaptureSSHSocket(); final transport = SSHTransport( socket, @@ -427,16 +426,22 @@ void main() { firstKexPacketFollows: false, ).encode(); - expect( - () => reflect(transport) - .invoke(privateSymbol('_handleMessageKexInit'), [payload]), + await expectLater( + () async { + final result = reflect(transport).invoke( + privateSymbol('_handleMessageKexInit'), [payload]).reflectee; + if (result is Future) { + await result; + } + }, throwsA(isA()), ); transport.close(); }); - test('kexinit requires server MAC when non-AEAD cipher is selected', () { + test('kexinit requires server MAC when non-AEAD cipher is selected', + () async { final socket = _CaptureSSHSocket(); final transport = SSHTransport( socket, @@ -461,9 +466,14 @@ void main() { firstKexPacketFollows: false, ).encode(); - expect( - () => reflect(transport) - .invoke(privateSymbol('_handleMessageKexInit'), [payload]), + await expectLater( + () async { + final result = reflect(transport).invoke( + privateSymbol('_handleMessageKexInit'), [payload]).reflectee; + if (result is Future) { + await result; + } + }, throwsA(isA()), );