1+ from collections .abc import Iterable , Sequence
12from dataclasses import dataclass
23import logging
34import os
@@ -84,9 +85,10 @@ def sync(self, *, linger: str | None, install_dpath: str | os.PathLike | None =
8485 linger_disable ()
8586
8687 def unit_names (self ):
87- return [u .unit_name ('service' ) for u in self .units ] + [
88- u .unit_name ('timer' ) for u in self .units
89- ]
88+ names : list [str ] = []
89+ for u in self .units :
90+ names .extend (u .managed_unit_names ())
91+ return names
9092
9193 def stale (self ):
9294 managed_names = set (self .unit_names ())
@@ -111,12 +113,53 @@ def remove_all(self):
111113 """Remove all unit files, services, and timers that match the prefix."""
112114 self .remove_stale ()
113115 for unit in self .units :
114- # Timer first to avoid systemd warning about timer being able to start service.
115- self .remove_unit (unit .unit_name ('timer' ))
116- self .remove_unit (unit .unit_name ('service' ))
116+ for name in unit .managed_unit_names ():
117+ self .remove_unit (name )
117118
118119 daemon_reload ()
119120
121+ def chain (self , chain_name : str , * units ):
122+ if not units :
123+ raise ValueError ('chain must include at least one unit' )
124+ if len ({id (u ) for u in units }) != len (units ):
125+ raise ValueError ('chain units must be unique instances' )
126+
127+ self .register (* units )
128+
129+ def _append_success (unit : object , next_unit : object ):
130+ try :
131+ curr = unit .on_success # type: ignore[attr-defined]
132+ except AttributeError :
133+ curr = None
134+ items : list [object ]
135+ if curr is None :
136+ items = []
137+ elif isinstance (curr , str | bytes ):
138+ items = [curr ]
139+ else :
140+ items = list (curr ) # type: ignore[arg-type]
141+ items .append (next_unit )
142+ unit .on_success = items # type: ignore[attr-defined]
143+
144+ for i in range (len (units ) - 1 ):
145+ _append_success (units [i ], units [i + 1 ])
146+
147+ # Create a target that wants the first trigger (timer or service)
148+ first = units [0 ]
149+ wants : list [str ] = []
150+ try :
151+ wants .append (first .unit_name ('timer' )) # type: ignore[arg-type]
152+ except AssertionError :
153+ wants .append (first .unit_name ('service' )) # type: ignore[arg-type]
154+
155+ tgt = TargetUnit (
156+ description = f'Chain: { chain_name } ' ,
157+ unit_basename = slugify (chain_name ),
158+ wants = wants ,
159+ )
160+ tgt .unit_prefix = self .unit_prefix
161+ self .units .append (tgt )
162+
120163 def remove_stale (self ):
121164 """
122165 Remove any unit files, services, or timers that match the prefix but aren't being
@@ -266,6 +309,10 @@ class TimedUnit:
266309 # Exec Pre/Post support
267310 exec_wrap : ExecWrap | None = None
268311
312+ # Chain/Dependency support (Unit options)
313+ on_success : str | object | Sequence [str | object ] | None = None
314+ on_failure : str | object | Sequence [str | object ] | None = None
315+
269316 # Other Unit Config
270317 service_extra : list [str ] | None = None
271318 timer_extra : list [str ] | None = None
@@ -333,6 +380,39 @@ def option(self, lines, opt_name):
333380
334381 lines .append (f'{ opt_name } ={ value } ' )
335382
383+ def _normalize_refs (self , refs : str | object | Sequence [str | object ] | None ) -> list [str ]:
384+ if refs is None :
385+ return []
386+ if isinstance (refs , str | bytes ):
387+ items : Iterable [str | object ] = [refs ]
388+ else :
389+ items = refs # type: ignore[assignment]
390+ names : list [str ] = []
391+ for r in items :
392+ if isinstance (r , str | bytes ):
393+ names .append (r )
394+ else :
395+ try :
396+ rpfx = getattr (r , 'unit_prefix' , None )
397+ spfx = getattr (self , 'unit_prefix' , None )
398+ if rpfx is None and spfx is not None and hasattr (r , 'unit_basename' ):
399+ names .append (f'{ spfx } { r .unit_basename } .service' ) # type: ignore[attr-defined]
400+ else :
401+ names .append (r .unit_name ('service' )) # type: ignore[attr-defined]
402+ except Exception as e : # pragma: no cover - defensive
403+ raise TypeError ('Invalid unit reference for OnSuccess/OnFailure' ) from e
404+ return names
405+
406+ def _unit_dependency_lines (self ) -> list [str ]:
407+ lines : list [str ] = []
408+ succ = self ._normalize_refs (self .on_success )
409+ fail = self ._normalize_refs (self .on_failure )
410+ if succ :
411+ lines .append ('OnSuccess=' + ' ' .join (succ ))
412+ if fail :
413+ lines .append ('OnFailure=' + ' ' .join (fail ))
414+ return lines
415+
336416 def timer (self ):
337417 lines = []
338418 lines .extend (
@@ -374,6 +454,9 @@ def service(self):
374454 f'Description={ self .description } ' ,
375455 ),
376456 )
457+ # Add chain dependencies if configured
458+ lines .extend (self ._unit_dependency_lines ())
459+
377460 if self .retry_max_tries and self .retry_interval_seconds :
378461 # limit interval must be set to more than (tries * interval) to contain the burst
379462 limit_interval = (self .retry_max_tries * self .retry_interval_seconds ) + 15
@@ -433,3 +516,137 @@ def install(self, install_dpath):
433516 def unit_name (self , type_ ):
434517 assert type_ in ('service' , 'timer' )
435518 return f'{ self .unit_prefix } { self .unit_basename } .{ type_ } '
519+
520+ def managed_unit_names (self ) -> list [str ]:
521+ return [self .unit_name ('timer' ), self .unit_name ('service' )]
522+
523+
524+ @dataclass
525+ class ServiceUnit :
526+ description : str
527+ exec_args : str = ''
528+ exec_bin : str = ''
529+
530+ service_type : str = 'oneshot'
531+
532+ retry_interval_seconds : int | None = None
533+ retry_max_tries : int | None = None
534+
535+ exec_wrap : ExecWrap | None = None
536+
537+ on_success : str | object | Sequence [str | object ] | None = None
538+ on_failure : str | object | Sequence [str | object ] | None = None
539+
540+ service_extra : list [str ] | None = None
541+
542+ unit_basename : str | None = None
543+ unit_prefix : str | None = None
544+
545+ def __post_init__ (self ):
546+ if not self .exec_bin :
547+ raise ValueError ('exec_bin must be set' )
548+ self .unit_basename = self .unit_basename or slugify (self .description )
549+
550+ @property
551+ def exec_start (self ):
552+ return f'{ self .exec_bin } { self .exec_args } ' .strip ()
553+
554+ def _normalize_refs (self , refs : str | object | Sequence [str | object ] | None ) -> list [str ]:
555+ if refs is None :
556+ return []
557+ if isinstance (refs , str | bytes ):
558+ items : Iterable [str | object ] = [refs ]
559+ else :
560+ items = refs # type: ignore[assignment]
561+ names : list [str ] = []
562+ for r in items :
563+ if isinstance (r , str | bytes ):
564+ names .append (r )
565+ else :
566+ try :
567+ rpfx = getattr (r , 'unit_prefix' , None )
568+ spfx = getattr (self , 'unit_prefix' , None )
569+ if rpfx is None and spfx is not None and hasattr (r , 'unit_basename' ):
570+ names .append (f'{ spfx } { r .unit_basename } .service' ) # type: ignore[attr-defined]
571+ else :
572+ names .append (r .unit_name ('service' )) # type: ignore[attr-defined]
573+ except Exception as e : # pragma: no cover
574+ raise TypeError ('Invalid unit reference for OnSuccess/OnFailure' ) from e
575+ return names
576+
577+ def _unit_dependency_lines (self ) -> list [str ]:
578+ lines : list [str ] = []
579+ succ = self ._normalize_refs (self .on_success )
580+ fail = self ._normalize_refs (self .on_failure )
581+ if succ :
582+ lines .append ('OnSuccess=' + ' ' .join (succ ))
583+ if fail :
584+ lines .append ('OnFailure=' + ' ' .join (fail ))
585+ return lines
586+
587+ def service (self ):
588+ lines : list [str ] = []
589+ lines .extend (('[Unit]' , f'Description={ self .description } ' ))
590+ lines .extend (self ._unit_dependency_lines ())
591+
592+ if self .retry_max_tries and self .retry_interval_seconds :
593+ limit_interval = (self .retry_max_tries * self .retry_interval_seconds ) + 15
594+ lines .extend (
595+ (
596+ f'StartLimitInterval={ limit_interval } ' ,
597+ f'StartLimitBurst={ self .retry_max_tries } ' ,
598+ ),
599+ )
600+
601+ lines .extend (('' , '[Service]' , f'Type={ self .service_type } ' ))
602+ if self .retry_interval_seconds :
603+ lines .extend (('Restart=on-failure' , f'RestartSec={ self .retry_interval_seconds } ' ))
604+
605+ lines .append (f'ExecStart={ self .exec_start } ' )
606+
607+ if self .exec_wrap :
608+ lines .append (f'ExecStartPre={ self .exec_wrap .pre ()} ' )
609+ lines .append (f'ExecStopPost={ self .exec_wrap .post ()} ' )
610+
611+ lines .extend (self .service_extra or ())
612+ return '\n ' .join (lines ) + '\n '
613+
614+ def install (self , install_dpath : Path ):
615+ install_dpath .mkdir (parents = True , exist_ok = True )
616+ service_fname = self .unit_name ('service' )
617+ service_fpath = install_dpath .joinpath (service_fname )
618+ service_fpath .write_text (self .service ())
619+ log .info (f'(Re)installed { service_fname } ' )
620+ daemon_reload ()
621+
622+ def unit_name (self , type_ : str ):
623+ assert type_ == 'service'
624+ return f'{ self .unit_prefix } { self .unit_basename } .{ type_ } '
625+
626+ def managed_unit_names (self ) -> list [str ]:
627+ return [self .unit_name ('service' )]
628+
629+
630+ @dataclass
631+ class TargetUnit :
632+ description : str
633+ unit_basename : str
634+ wants : list [str ] | None = None
635+ unit_prefix : str | None = None
636+
637+ def install (self , install_dpath : Path ):
638+ install_dpath .mkdir (parents = True , exist_ok = True )
639+ fname = self .unit_name ()
640+ fpath = install_dpath .joinpath (fname )
641+ lines = ['[Unit]' , f'Description={ self .description } ' ]
642+ if self .wants :
643+ lines .append ('Wants=' + ' ' .join (self .wants ))
644+ fpath .write_text ('\n ' .join (lines ) + '\n ' )
645+ log .info (f'(Re)installed { fname } ' )
646+ daemon_reload ()
647+
648+ def unit_name (self ) -> str :
649+ return f'{ self .unit_prefix } { self .unit_basename } .target'
650+
651+ def managed_unit_names (self ) -> list [str ]:
652+ return [self .unit_name ()]
0 commit comments