Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/DIRAC/Resources/Computing/SSHComputingElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ def _get(self, connection: Connection, remote: str, local: str, preserveMode: bo
except SSHException as e:
return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}")

def close(self):
"""Close the SSH connection(s) to the remote host.

Fabric/paramiko connections recommended to be closed explicitly.
"""
if self.connection is None:
return
# The gateway (jump host) is a distinct Connection referenced by the
# main connection; close it as well to avoid leaking its Transport.
gateway = getattr(self.connection, "gateway", None)
try:
self.connection.close()
except Exception as e:
self.log.warn("Failed to close SSH connection", str(e))
if isinstance(gateway, Connection):
try:
gateway.close()
except Exception as e:
self.log.warn("Failed to close SSH gateway connection", str(e))
self.connection = None

#############################################################################

def _getBatchSystem(self):
Expand Down
78 changes: 63 additions & 15 deletions src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
"""
import datetime

from DIRAC import S_OK, gConfig
from DIRAC import S_OK
from DIRAC.AccountingSystem.Client.DataStoreClient import gDataStoreClient
from DIRAC.AccountingSystem.Client.Types.Pilot import Pilot as PilotAccounting
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getCESiteMapping
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getCESiteMapping, getQueue
from DIRAC.Core.Base.AgentModule import AgentModule
from DIRAC.Core.Utilities import TimeUtilities
from DIRAC.Interfaces.API.DiracAdmin import DiracAdmin
from DIRAC.WorkloadManagementSystem.Client import PilotStatus
from DIRAC.WorkloadManagementSystem.Client.PilotManagerClient import PilotManagerClient
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import setPilotCredentials
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache


class PilotStatusAgent(AgentModule):
Expand All @@ -39,14 +40,17 @@ def __init__(self, *args, **kwargs):

self.jobDB = None
self.pilotDB = None
self.diracadmin = None
# Cache of ComputingElement instances keyed by queue. CEs (and their
# SSH connections) are reused across cycles to avoid opening a new
# connection for every pilot kill, following the SiteDirector pattern.
self.ceCache = None

#############################################################################
def initialize(self):
"""Sets defaults"""

self.pilotDB = PilotAgentsDB()
self.diracadmin = DiracAdmin()
self.ceCache = QueueCECache()
self.jobDB = JobDB()
self.clearPilotsDelay = self.am_getOption("ClearPilotsDelay", 30)
self.clearAbortedDelay = self.am_getOption("ClearAbortedPilotsDelay", 7)
Expand Down Expand Up @@ -199,17 +203,61 @@ def _addPilotsAccountingReport(self, pilotsData):
return retVal
return S_OK()

def _getCEForQueue(self, vo, site, ce, queue, gridType):
"""Return a cached ComputingElement for the given queue.

The CE (and its SSH connection) is reused across cycles via the
:class:`QueueCECache`, and rebuilt automatically if the queue
configuration changes.
"""
result = getQueue(site, ce, queue)
if not result["OK"]:
return result
queueKey = "@@@".join([vo, site, ce, queue])
return self.ceCache.getCE(queueKey, gridType, ce, result["Value"])

def _killPilots(self, acc):
for i in sorted(acc.keys()):
result = self.diracadmin.getPilotInfo(i)
if result["OK"] and i in result["Value"] and "Status" in result["Value"][i]:
ret = self.diracadmin.killPilot(str(i))
if ret["OK"]:
self.log.info("Successfully deleted", f": {i} (Status : {result['Value'][i]['Status']})")
else:
self.log.error("Failed to delete pilot: ", f"{i} : {ret['Message']}")
else:
self.log.error("Failed to get pilot info", f"{i} : {str(result)}")
"""Declare the given pilots killed on their CEs.

Pilots are grouped per queue and killed in a single call per queue,
reusing a cached CE/connection per queue across cycles.
"""
# Group the pilots to kill per queue
pilotsByQueue = {}
for pRef in acc:
pilotDict = acc[pRef]
queueFields = [pilotDict["VO"], pilotDict["GridSite"], pilotDict["DestinationSite"], pilotDict["Queue"]]
# A pilot with an incomplete queue definition cannot be located on a
# CE; skip it rather than letting it abort the whole batch.
if not all(queueFields):
self.log.warn("Cannot determine queue for pilot, skipping kill", f"{pRef} : {queueFields}")
continue
queueKey = "@@@".join(queueFields)
queueData = pilotsByQueue.setdefault(queueKey, {"GridType": pilotDict["GridType"], "PilotList": []})
queueData["PilotList"].append(pRef)

for queueKey, queueData in pilotsByQueue.items():
vo, site, ce, queue = queueKey.split("@@@")
result = self._getCEForQueue(vo, site, ce, queue, queueData["GridType"])
if not result["OK"]:
self.log.error("Failed to get CE for queue", f"{queueKey} : {result['Message']}")
continue
computingElement = result["Value"]

# The connection is reused, but pilot credentials (proxy/token)
# expire and must be refreshed every cycle.
result = setPilotCredentials(computingElement, {"VO": vo})
if not result["OK"]:
self.log.error("Failed to set pilot credentials", f"{queueKey} : {result['Message']}")
continue

result = computingElement.killJob(queueData["PilotList"])
if not result["OK"]:
self.log.error("Failed to delete pilots", f"{queueKey} : {result['Message']}")
# Drop the (possibly stale) connection so it is rebuilt next cycle
self.ceCache.drop(queueKey)
continue
self.log.info("Successfully deleted pilots", f": {len(queueData['PilotList'])} in queue {queueKey}")

def _checkJobLastUpdateTime(self, joblist, StalledDays):
timeLimitToConsider = datetime.datetime.utcnow() - TimeUtilities.day * StalledDays
Expand Down
4 changes: 2 additions & 2 deletions src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
transferInputSandbox,
)
from DIRAC.WorkloadManagementSystem.private.ConfigHelper import findGenericPilotCredentials
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache
from DIRAC.WorkloadManagementSystem.Utilities.Utils import createJobWrapper

MAX_JOBS_MANAGED = 100
Expand All @@ -61,7 +61,7 @@ def __init__(self, agentName, loadName, baseAgentName=False, properties=None):
self.firstPass = True
self.maxJobsToSubmit = MAX_JOBS_MANAGED
self.queueDict = {}
self.queueCECache = {}
self.queueCECache = QueueCECache()

self.pilotDN = ""
self.vo = ""
Expand Down
10 changes: 6 additions & 4 deletions src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
getPilotFilesCompressedEncodedDict,
pilotWrapperScript,
)
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache

MAX_PILOTS_TO_SUBMIT = 100

Expand All @@ -59,8 +59,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.queueDict = {}
# self.queueCECache aims at saving CEs information over the cycles to avoid to create the exact same CEs each cycle
self.queueCECache = {}
# self.queueCECache saves CE instances (and their connections) over the cycles
# to avoid re-creating the exact same CEs -- and re-opening connections -- each cycle
self.queueCECache = QueueCECache()
self.failedQueues = defaultdict(int)
self.maxPilotsToSubmit = MAX_PILOTS_TO_SUBMIT

Expand Down Expand Up @@ -575,7 +576,8 @@ def _getExecutable(self, queue: str, proxy: X509Chain, jobExecDir: str = "", env
# in your machine, the executable files will be in the same place
# but it does not matter since they are very temporary

ce = self.queueCECache[queue]["CE"]
# Same CE instance as the cache holds; queueDict is the canonical accessor
ce = self.queueDict[queue]["CE"]
workingDirectory = getattr(ce, "workingDirectory", self.workingDirectory)

executable = self._writePilotScript(
Expand Down
81 changes: 63 additions & 18 deletions src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import hashlib

from DIRAC import S_OK, S_ERROR
from DIRAC import S_OK, S_ERROR, gLogger
from DIRAC.Core.Utilities.List import fromChar
from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getDIRACPlatform
Expand All @@ -16,7 +16,6 @@ def getQueuesResolved(siteDict, queueCECache, vo=None, checkPlatform=False, inst
The main goal of this method is to return a dictionary of queues
"""
queueDict = {}
ceFactory = ComputingElementFactory()

for site in siteDict:
for ce in siteDict[site]:
Expand Down Expand Up @@ -50,27 +49,20 @@ def getQueuesResolved(siteDict, queueCECache, vo=None, checkPlatform=False, inst
ceQueueDict.update(queueDict[queueName]["ParametersDict"])

if instantiateCEs:
# Generate the CE object for the queue or pick the already existing one
# if the queue definition did not change
queueHash = generateQueueHash(ceQueueDict)
if queueName in queueCECache and queueCECache[queueName]["Hash"] == queueHash:
queueCE = queueCECache[queueName]["CE"]
else:
result = ceFactory.getCE(ceName=ce, ceType=ceDict["CEType"], ceParametersDict=ceQueueDict)
if not result["OK"]:
queueDict.pop(queueName)
continue
queueCECache.setdefault(queueName, {})
queueCECache[queueName]["Hash"] = queueHash
queueCECache[queueName]["CE"] = result["Value"]
queueCE = queueCECache[queueName]["CE"]
# Get the CE object for the queue, reusing the cached one if the
# queue definition did not change, or (re)building it otherwise.
result = queueCECache.getCE(queueName, ceDict["CEType"], ce, ceQueueDict)
if not result["OK"]:
queueDict.pop(queueName)
continue
queueCE = result["Value"]

queueDict[queueName]["ParametersDict"].update(queueCE.ceParameters)
queueDict[queueName]["CE"] = queueCE
result = queueDict[queueName]["CE"].isValid()
result = queueCE.isValid()
if not result["OK"]:
queueDict.pop(queueName)
queueCECache.pop(queueName)
queueCECache.drop(queueName)
continue

queueDict[queueName]["CEName"] = ce
Expand Down Expand Up @@ -141,6 +133,59 @@ def generateQueueHash(queueDict):
return hexstring


class QueueCECache:
"""A cache of ComputingElement instances keyed by queue.

CEs -- and, for connection-based CEs such as the SSHComputingElement, their
underlying connections -- are reused across cycles instead of being
re-created on every use. A CE is rebuilt only when its queue parameters
change, detected through a hash of the parameters dictionary: the same
invalidation strategy used by the SiteDirector (see :func:`getQueuesResolved`).
"""

def __init__(self):
# queueKey -> {"Hash": <str>, "CE": <ComputingElement>}
self._cache = {}
self._ceFactory = ComputingElementFactory()
self.log = gLogger.getSubLogger(self.__class__.__name__)

def getCE(self, queueKey, ceType, ceName, ceParametersDict):
"""Return a cached CE for ``queueKey``, (re)building it when needed.

:param str queueKey: unique identifier of the queue, used as cache key
:param str ceType: CE type passed to the ComputingElementFactory
:param str ceName: CE name passed to the ComputingElementFactory
:param dict ceParametersDict: queue/CE parameters; a change triggers a rebuild
:return: S_OK(ce)/S_ERROR
"""
queueHash = generateQueueHash(ceParametersDict)
cached = self._cache.get(queueKey)
if cached is not None and cached["Hash"] == queueHash:
return S_OK(cached["CE"])

# First use, or the queue definition changed: drop any stale CE
# (releasing its connection) and build a fresh one.
self.drop(queueKey)
result = self._ceFactory.getCE(ceType=ceType, ceName=ceName, ceParametersDict=ceParametersDict)
if not result["OK"]:
return result
self._cache[queueKey] = {"Hash": queueHash, "CE": result["Value"]}
return S_OK(result["Value"])

def drop(self, queueKey):
"""Remove a cached CE and release its resources (e.g. close connections)."""
entry = self._cache.pop(queueKey, None)
if entry is None:
return
ce = entry["CE"]
# Only connection-based CEs (e.g. SSHComputingElement) define close()
if hasattr(ce, "close"):
try:
ce.close()
except Exception as e:
self.log.warn("Failed to close CE", f"{queueKey} : {str(e)}")


def matchQueue(jobJDL, queueDict, fullMatch=False):
"""
Match the job description to the queue definition
Expand Down
Loading
Loading