@@ -454,6 +454,8 @@ class SSHClientBase(api.ExecHelper):
454454 :type sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None
455455 :param keepalive: keepalive period
456456 :type keepalive: int | bool
457+ :param allow_ssh_agent: use SSH Agent if available
458+ :type allow_ssh_agent: bool
457459
458460 .. note:: auth has priority over username/password/private_keys
459461 .. note::
@@ -471,6 +473,7 @@ class SSHClientBase(api.ExecHelper):
471473 .. versionchanged:: 7.0.0 private_keys is removed
472474 .. versionchanged:: 7.0.0 keepalive_mode is removed
473475 .. versionchanged:: 7.4.0 return of keepalive_mode to prevent mix with keepalive period. Default is `False`
476+ .. versionchanged:: 8.0.0 expose SSH Agent usage override
474477 """
475478
476479 __slots__ = (
@@ -486,6 +489,7 @@ class SSHClientBase(api.ExecHelper):
486489 "__ssh_config" ,
487490 "__sock" ,
488491 "__conn_chain" ,
492+ "__allow_agent" ,
489493 )
490494
491495 def __hash__ (self ) -> int :
@@ -509,8 +513,19 @@ def __init__(
509513 ssh_auth_map : dict [str , ssh_auth .SSHAuth ] | ssh_auth .SSHAuthMapping | None = None ,
510514 sock : paramiko .ProxyCommand | paramiko .Channel | socket .socket | None = None ,
511515 keepalive : KeepAlivePeriodT = 1 ,
516+ allow_ssh_agent : bool = True ,
512517 ) -> None :
513518 """Main SSH Client helper."""
519+ self .__sudo_mode = False
520+ self .__keepalive_period : int = int (keepalive )
521+ self .__keepalive_mode = False
522+ self .__verbose : bool = verbose
523+ self .__sock = sock
524+
525+ self .__ssh : paramiko .SSHClient
526+ self .__sftp : paramiko .SFTPClient | None = None
527+ self .__allow_agent = allow_ssh_agent
528+
514529 # Init ssh config. It's main source for connection parameters
515530 if isinstance (ssh_config , _ssh_helpers .HostsSSHConfigs ):
516531 self .__ssh_config : _ssh_helpers .HostsSSHConfigs = ssh_config
@@ -533,35 +548,25 @@ def __init__(
533548 if self .hostname not in self .__auth_mapping and host in self .__auth_mapping :
534549 self .__auth_mapping [self .hostname ] = self .__auth_mapping [host ]
535550
536- self .__sudo_mode = False
537- self .__keepalive_period : int = int (keepalive )
538- self .__keepalive_mode = False
539- self .__verbose : bool = verbose
540- self .__sock = sock
541-
542- self .__ssh : paramiko .SSHClient
543- self .__sftp : paramiko .SFTPClient | None = None
544-
545551 # Rebuild SSHAuth object if required.
546552 # Priority: auth > credentials > auth mapping
547- if auth is not None :
548- self .__auth_mapping [self .hostname ] = real_auth = copy .copy (auth )
549- elif self .hostname not in self .__auth_mapping or any ((username , password )):
550- self .__auth_mapping [self .hostname ] = real_auth = ssh_auth .SSHAuth (
551- username = username if username is not None else config .user ,
552- password = password ,
553- key_filename = config .identityfile ,
554- )
555- else :
556- real_auth = self .__auth_mapping [self .hostname ]
553+ real_auth = self .__handle_explicit_auth (
554+ username = username ,
555+ config_username = config .user ,
556+ password = password ,
557+ auth = auth ,
558+ key_filename = config .identityfile ,
559+ )
557560
558561 # Init super with host and real port and username
559562 mod_name = "exec_helpers" if self .__module__ .startswith ("exec_helpers" ) else self .__module__
560563 log_username : str = real_auth .username if real_auth .username is not None else getpass .getuser ()
561564
562565 super ().__init__ (
563- logger = logging .getLogger (f"{ mod_name } .{ self .__class__ .__name__ } " ).getChild (
564- f"({ log_username } @{ host } :{ self .port } )"
566+ logger = logging .getLogger (
567+ f"{ mod_name } .{ self .__class__ .__name__ } " ,
568+ ).getChild (
569+ f"({ log_username } @{ host } :{ self .port } )" ,
565570 )
566571 )
567572
@@ -577,6 +582,26 @@ def __init__(
577582
578583 self .__connect ()
579584
585+ def __handle_explicit_auth (
586+ self ,
587+ * ,
588+ username : str | None ,
589+ config_username : str | None ,
590+ password : str | None ,
591+ auth : ssh_auth .SSHAuth | None ,
592+ key_filename : Iterable [str ] | None ,
593+ ) -> ssh_auth .SSHAuth :
594+ if auth is not None :
595+ self .__auth_mapping [self .hostname ] = auth
596+ elif self .hostname not in self .__auth_mapping or any ((username , password )):
597+ self .__auth_mapping [self .hostname ] = ssh_auth .SSHAuth (
598+ username = username if username is not None else config_username ,
599+ password = password ,
600+ key_filename = key_filename ,
601+ )
602+
603+ return self .__auth_mapping [self .hostname ]
604+
580605 def __rebuild_ssh_config (self ) -> None :
581606 """Rebuild main ssh config from available information."""
582607 self .__ssh_config [self .hostname ] = self .__ssh_config [self .hostname ].overridden_by (
@@ -598,7 +623,11 @@ def __build_connection_chain(self) -> list[tuple[_ssh_helpers.SSHConfig, ssh_aut
598623
599624 config = self .ssh_config [self .hostname ]
600625 default_auth = ssh_auth .SSHAuth (username = config .user , key_filename = config .identityfile )
601- auth = self .__auth_mapping .get_with_alt_hostname (config .hostname , self .hostname , default = default_auth )
626+ auth = self .__auth_mapping .get_with_alt_hostname (
627+ config .hostname ,
628+ self .hostname ,
629+ default = default_auth ,
630+ )
602631 conn_chain .append ((config , auth ))
603632
604633 while config .proxyjump is not None :
@@ -621,6 +650,15 @@ def auth(self) -> ssh_auth.SSHAuth:
621650 """
622651 return self .__auth_mapping [self .hostname ]
623652
653+ @property
654+ def allow_ssh_agent (self ) -> bool :
655+ """Use SSH Agent if available.
656+
657+ :return: SSH Agent usage allowed
658+ :rtype: bool
659+ """
660+ return self .__allow_agent
661+
624662 @property
625663 def hostname (self ) -> str :
626664 """Connected remote host name.
@@ -714,16 +752,15 @@ def __connect(self) -> None:
714752 """Main method for connection open."""
715753 with self .lock :
716754 if self .__sock is not None :
717- sock = self .__sock
718-
719755 self .__ssh = paramiko .SSHClient ()
720756 self .__ssh .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
721757 self .auth .connect (
722758 client = self .__ssh ,
723759 hostname = self .hostname ,
724760 port = self .port ,
725761 log = self .__verbose ,
726- sock = sock ,
762+ sock = self .__sock ,
763+ allow_ssh_agent = self .allow_ssh_agent ,
727764 )
728765 else :
729766 self .__ssh = self .__get_client ()
@@ -745,15 +782,14 @@ def __get_client(self) -> paramiko.SSHClient:
745782 last_ssh_client .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
746783
747784 config , auth = self .__conn_chain [0 ]
748- if config .proxycommand :
749- auth .connect (
750- last_ssh_client ,
751- hostname = config .hostname ,
752- port = config .port or 22 ,
753- sock = paramiko .ProxyCommand (config .proxycommand ),
754- )
755- else :
756- auth .connect (last_ssh_client , hostname = config .hostname , port = config .port or 22 )
785+
786+ auth .connect (
787+ last_ssh_client ,
788+ hostname = config .hostname ,
789+ port = config .port or 22 ,
790+ sock = paramiko .ProxyCommand (config .proxycommand ) if config .proxycommand else None ,
791+ allow_ssh_agent = self .allow_ssh_agent ,
792+ )
757793
758794 for config , auth in self .__conn_chain [1 :]: # start has another logic, so do it out of cycle
759795 ssh = paramiko .SSHClient ()
@@ -768,7 +804,13 @@ def __get_client(self) -> paramiko.SSHClient:
768804 dest_addr = (config .hostname , config .port or 22 ),
769805 src_addr = (config .proxyjump , 0 ),
770806 )
771- auth .connect (ssh , hostname = config .hostname , port = config .port or 22 , sock = sock )
807+ auth .connect (
808+ ssh ,
809+ hostname = config .hostname ,
810+ port = config .port or 22 ,
811+ sock = sock ,
812+ allow_ssh_agent = self .allow_ssh_agent ,
813+ )
772814 last_ssh_client = ssh
773815 continue
774816
@@ -1421,33 +1463,6 @@ def check_stderr(
14211463 ** kwargs ,
14221464 )
14231465
1424- def _get_proxy_channel (
1425- self ,
1426- port : int | None ,
1427- ssh_config : _ssh_helpers .SSHConfig ,
1428- ) -> paramiko .Channel :
1429- """Get ssh proxy channel.
1430-
1431- :param port: target port
1432- :type port: int | None
1433- :param ssh_config: pre-parsed ssh config
1434- :type ssh_config: SSHConfig
1435- :return: ssh channel for usage as socket for new connection over it
1436- :rtype: paramiko.Channel
1437-
1438- .. versionadded:: 6.0.0
1439- """
1440- if port is not None :
1441- dest_port : int = port
1442- else :
1443- dest_port = ssh_config .port if ssh_config .port is not None else 22
1444-
1445- return self ._ssh_transport .open_channel (
1446- kind = "direct-tcpip" ,
1447- dest_addr = (ssh_config .hostname , dest_port ),
1448- src_addr = (self .hostname , 0 ),
1449- )
1450-
14511466 def proxy_to (
14521467 self ,
14531468 host : str ,
@@ -1498,13 +1513,25 @@ def proxy_to(
14981513 else :
14991514 parsed_ssh_config = _ssh_helpers .parse_ssh_config (ssh_config , host )
15001515
1501- hostname = parsed_ssh_config [host ].hostname
1516+ host_config = parsed_ssh_config [host ]
1517+
1518+ if port is not None :
1519+ dest_port : int = port
1520+ elif host_config .port is not None :
1521+ dest_port = host_config .port
1522+ else :
1523+ dest_port = 22
1524+
1525+ sock : paramiko .Channel = self ._ssh_transport .open_channel (
1526+ kind = "direct-tcpip" ,
1527+ dest_addr = (host_config .hostname , dest_port ),
1528+ src_addr = (self .hostname , 0 ),
1529+ )
15021530
1503- sock : paramiko .Channel = self ._get_proxy_channel (port = port , ssh_config = parsed_ssh_config [hostname ])
15041531 cls : type [Self ] = self .__class__
15051532 return cls (
15061533 host = host ,
1507- port = port ,
1534+ port = dest_port ,
15081535 username = username ,
15091536 password = password ,
15101537 auth = auth ,
0 commit comments