def close(self):
    """Close the SSH connection to the remote host, and its gateway if any.

    Fabric/paramiko connections are recommended to be closed explicitly.
    Closing is best-effort: a failure while closing either connection is
    logged as a warning and not propagated. When no connection is set,
    nothing is done.
    """
    connection = self.connection
    if connection is None:
        return

    # The jump host is a separate Connection object that the main connection
    # only references, so it has to be closed on its own or its Transport
    # would be leaked.
    jumpHost = getattr(connection, "gateway", None)

    try:
        connection.close()
    except Exception as exc:
        self.log.warn("Failed to close SSH connection", str(exc))

    # Only Connection gateways are closed here; anything else is ignored.
    if isinstance(jumpHost, Connection):
        try:
            jumpHost.close()
        except Exception as exc:
            self.log.warn("Failed to close SSH gateway connection", str(exc))

    # Forget the connection so that a later close() is a no-op
    self.connection = None
def _getCEForQueue(self, vo, site, ce, queue, gridType):
    """Return a cached ComputingElement for the given queue.

    The queue parameters are read with ``getQueue(site, ce, queue)`` and the
    CE is then obtained from ``self.ceCache``, which reuses the CE (and its
    connection) across cycles and rebuilds it when the queue parameters change.

    :param str vo: VO name, first component of the cache key
    :param str site: site name
    :param str ce: CE name
    :param str queue: queue name
    :param str gridType: CE type handed to the cache (and from there to the CE factory)
    :return: S_OK(ce)/S_ERROR
    """
    result = getQueue(site, ce, queue)
    if not result["OK"]:
        return result
    # Must stay identical to the key built in _killPilots, which uses it to
    # drop the cached CE after a failed kill.
    queueKey = "@@@".join([vo, site, ce, queue])
    return self.ceCache.getCE(queueKey, gridType, ce, result["Value"])


def _killPilots(self, acc):
    """Kill the given pilots on their CEs.

    Pilots are grouped per queue and killed with a single ``killJob`` call per
    queue, reusing a cached CE/connection per queue across cycles. Problems
    are logged and the affected pilot or queue is skipped; nothing is raised
    for an incomplete pilot record and nothing is returned.

    :param dict acc: pilot reference -> pilot dictionary providing the
        ``VO``, ``GridSite``, ``DestinationSite``, ``Queue`` and ``GridType`` keys
    """
    queueFieldNames = ("VO", "GridSite", "DestinationSite", "Queue")

    # Group the pilots to kill per queue, keyed by the tuple of queue fields
    # so that the fields never have to be parsed back out of a joined string.
    pilotsByQueue = {}
    for pRef, pilotDict in acc.items():
        # A pilot with an incomplete queue definition cannot be located on a
        # CE; skip it rather than letting it abort the whole batch. This
        # covers both missing keys and empty values.
        try:
            queueFields = tuple(pilotDict[name] for name in queueFieldNames)
            gridType = pilotDict["GridType"]
        except KeyError as exc:
            self.log.warn("Cannot determine queue for pilot, skipping kill", f"{pRef} : missing {exc}")
            continue
        if not all(queueFields):
            self.log.warn("Cannot determine queue for pilot, skipping kill", f"{pRef} : {list(queueFields)}")
            continue
        # The GridType of the first pilot seen for a queue is used for the queue
        queueData = pilotsByQueue.setdefault(queueFields, {"GridType": gridType, "PilotList": []})
        queueData["PilotList"].append(pRef)

    for (vo, site, ce, queue), queueData in pilotsByQueue.items():
        # Same key as the one _getCEForQueue stores the CE under
        queueKey = "@@@".join((vo, site, ce, queue))
        result = self._getCEForQueue(vo, site, ce, queue, queueData["GridType"])
        if not result["OK"]:
            self.log.error("Failed to get CE for queue", f"{queueKey} : {result['Message']}")
            continue
        computingElement = result["Value"]

        # The connection is reused, but pilot credentials (proxy/token)
        # expire and must be refreshed every cycle.
        result = setPilotCredentials(computingElement, {"VO": vo})
        if not result["OK"]:
            self.log.error("Failed to set pilot credentials", f"{queueKey} : {result['Message']}")
            continue

        result = computingElement.killJob(queueData["PilotList"])
        if not result["OK"]:
            self.log.error("Failed to delete pilots", f"{queueKey} : {result['Message']}")
            # Drop the (possibly stale) connection so it is rebuilt next cycle
            self.ceCache.drop(queueKey)
            continue
        self.log.info("Successfully deleted pilots", f": {len(queueData['PilotList'])} in queue {queueKey}")
getPilotFilesCompressedEncodedDict, pilotWrapperScript, ) -from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache MAX_PILOTS_TO_SUBMIT = 100 @@ -59,8 +59,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.queueDict = {} - # self.queueCECache aims at saving CEs information over the cycles to avoid to create the exact same CEs each cycle - self.queueCECache = {} + # self.queueCECache saves CE instances (and their connections) over the cycles + # to avoid re-creating the exact same CEs -- and re-opening connections -- each cycle + self.queueCECache = QueueCECache() self.failedQueues = defaultdict(int) self.maxPilotsToSubmit = MAX_PILOTS_TO_SUBMIT @@ -575,7 +576,8 @@ def _getExecutable(self, queue: str, proxy: X509Chain, jobExecDir: str = "", env # in your machine, the executable files will be in the same place # but it does not matter since they are very temporary - ce = self.queueCECache[queue]["CE"] + # Same CE instance as the cache holds; queueDict is the canonical accessor + ce = self.queueDict[queue]["CE"] workingDirectory = getattr(ce, "workingDirectory", self.workingDirectory) executable = self._writePilotScript( diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py index 58311f3d3b3..0915cce1f3a 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py @@ -2,7 +2,7 @@ import hashlib -from DIRAC import S_OK, S_ERROR +from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Utilities.List import fromChar from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getDIRACPlatform @@ -16,7 +16,6 @@ def getQueuesResolved(siteDict, queueCECache, vo=None, 
class QueueCECache:
    """Keep one ComputingElement instance per queue and reuse it over the cycles.

    Reusing the CE also reuses whatever connection it holds (e.g. for the
    SSHComputingElement) instead of re-creating it on every use. An entry is
    rebuilt only when the hash of its parameters dictionary changes, which is
    the same invalidation strategy used by the SiteDirector
    (see :func:`getQueuesResolved`).
    """

    def __init__(self):
        # Maps a queue key to {"Hash": hash of the queue parameters, "CE": the CE instance}
        self._cache = {}
        self._ceFactory = ComputingElementFactory()
        self.log = gLogger.getSubLogger(self.__class__.__name__)

    def getCE(self, queueKey, ceType, ceName, ceParametersDict):
        """Give back the CE cached under ``queueKey``, building it first if required.

        :param str queueKey: unique identifier of the queue, used as cache key
        :param str ceType: CE type passed to the ComputingElementFactory
        :param str ceName: CE name passed to the ComputingElementFactory
        :param dict ceParametersDict: queue/CE parameters; a change triggers a rebuild
        :return: S_OK(ce)/S_ERROR
        """
        currentHash = generateQueueHash(ceParametersDict)
        entry = self._cache.get(queueKey)
        upToDate = entry is not None and entry["Hash"] == currentHash
        if upToDate:
            return S_OK(entry["CE"])

        # Nothing cached yet, or the parameters changed: release the old CE
        # (and its connection) before asking the factory for a new one.
        self.drop(queueKey)
        buildResult = self._ceFactory.getCE(ceType=ceType, ceName=ceName, ceParametersDict=ceParametersDict)
        if buildResult["OK"]:
            freshCE = buildResult["Value"]
            self._cache[queueKey] = {"Hash": currentHash, "CE": freshCE}
            return S_OK(freshCE)
        # A failed build is returned as is and leaves no cache entry behind
        return buildResult

    def drop(self, queueKey):
        """Forget the CE cached under ``queueKey`` and release its resources (e.g. close connections)."""
        entry = self._cache.pop(queueKey, None)
        if entry is None:
            return

        staleCE = entry["CE"]
        # CEs that do not define close() have nothing to release
        if not hasattr(staleCE, "close"):
            return
        try:
            staleCE.close()
        except Exception as exc:
            self.log.warn("Failed to close CE", f"{queueKey} : {str(exc)}")
# Patch target: the factory method QueueCECache calls to build a CE. Every test
# makes it hand out a different CE per call, so the identity of the returned CE
# tells a cache hit (first CE again) apart from a rebuild (next CE).
GET_CE = "DIRAC.Resources.Computing.ComputingElementFactory.ComputingElementFactory.getCE"


def test_QueueCECache_cacheHitDoesNotRebuild(mocker):
    """Asking twice with the same parameters builds the CE once and returns it both times."""
    cachedCE = MagicMock()
    spareCE = MagicMock()
    factoryGetCE = mocker.patch(GET_CE, side_effect=[S_OK(cachedCE), S_OK(spareCE)])

    queueCache = QueueCECache()
    ceParams = {"CEType": "SSH", "Host": "host1"}

    firstResult = queueCache.getCE("queue1", "SSH", "ce1", ceParams)
    secondResult = queueCache.getCE("queue1", "SSH", "ce1", ceParams)

    # A single build, with the arguments forwarded untouched
    factoryGetCE.assert_called_once_with(ceType="SSH", ceName="ce1", ceParametersDict=ceParams)
    # A broken cache would have rebuilt and returned spareCE on the second call
    assert firstResult["Value"] is cachedCE
    assert secondResult["Value"] is cachedCE


def test_QueueCECache_parameterChangeRebuildsAndClosesStaleCE(mocker):
    """A different parameters dictionary triggers a rebuild and closes the previous CE."""
    staleCE = MagicMock()
    freshCE = MagicMock()
    factoryGetCE = mocker.patch(GET_CE, side_effect=[S_OK(staleCE), S_OK(freshCE)])

    queueCache = QueueCECache()
    beforeChange = queueCache.getCE("queue1", "SSH", "ce1", {"Host": "host1"})
    afterChange = queueCache.getCE("queue1", "SSH", "ce1", {"Host": "host2"})

    # The parameter hash changed, hence two builds
    assert factoryGetCE.call_count == 2
    assert beforeChange["Value"] is staleCE
    assert afterChange["Value"] is freshCE
    # Only the replaced CE gets released
    staleCE.close.assert_called_once()
    freshCE.close.assert_not_called()


def test_QueueCECache_dropClosesAndForcesRebuild(mocker):
    """drop() closes and forgets the cached CE; the following getCE() builds a new one."""
    droppedCE = MagicMock()
    replacementCE = MagicMock()
    mocker.patch(GET_CE, side_effect=[S_OK(droppedCE), S_OK(replacementCE)])

    queueCache = QueueCECache()
    ceParams = {"Host": "host1"}

    queueCache.getCE("queue1", "SSH", "ce1", ceParams)
    queueCache.drop("queue1")
    droppedCE.close.assert_called_once()

    # Nothing is cached any more, so the factory is asked again
    afterDrop = queueCache.getCE("queue1", "SSH", "ce1", ceParams)
    assert afterDrop["Value"] is replacementCE


def test_QueueCECache_failedBuildIsNotCached(mocker):
    """A build error is not remembered: the next call asks the factory again."""
    workingCE = MagicMock()
    factoryGetCE = mocker.patch(GET_CE, side_effect=[S_ERROR("boom"), S_OK(workingCE)])

    queueCache = QueueCECache()
    ceParams = {"Host": "host1"}

    firstAttempt = queueCache.getCE("queue1", "SSH", "ce1", ceParams)
    assert not firstAttempt["OK"]

    secondAttempt = queueCache.getCE("queue1", "SSH", "ce1", ceParams)
    assert secondAttempt["OK"]
    assert secondAttempt["Value"] is workingCE
    # Two factory calls: the error left nothing in the cache
    assert factoryGetCE.call_count == 2


def test_QueueCECache_dropMissingKeyIsNoOp():
    """Dropping a queue key that was never cached is silently ignored."""
    queueCache = QueueCECache()
    queueCache.drop("does-not-exist")  # must not raise
""" from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getQueues - from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved + from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache result = getQueues( community=vo, @@ -142,7 +142,7 @@ def buildQueues(vo, sites, ces, ceTypes): result = getQueuesResolved( siteDict=result["Value"], - queueCECache={}, + queueCECache=QueueCECache(), vo=vo, instantiateCEs=True, ) diff --git a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py index 14a4ff5c81c..04c2dc27fef 100644 --- a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py +++ b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py @@ -39,7 +39,7 @@ def main(): from DIRAC.Core.Utilities.PrettyPrint import printTable from DIRAC.ResourceStatusSystem.Client.ResourceStatus import ResourceStatus from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus - from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, matchQueue + from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, matchQueue, QueueCECache with open(args[0]) as f: jdl = f.read() @@ -56,7 +56,7 @@ def main(): gLogger.error("Failed to get CE information") DIRACExit(-1) siteDict = resultQueues["Value"] - result = getQueuesResolved(siteDict, {}, checkPlatform=True) + result = getQueuesResolved(siteDict, QueueCECache(), checkPlatform=True) if not resultQueues["OK"]: gLogger.error("Failed to get CE information") DIRACExit(-1)