Skip to content

Commit 11bbcb8

Browse files
authored
Add support for chained units via OnSuccess=
* Add support for chained units * Address pip audit warnings
1 parent 95f9c20 commit 11bbcb8

6 files changed

Lines changed: 422 additions & 42 deletions

File tree

src/sysd_example.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22

3-
from sysdi import TimedUnit, UnitManager
3+
from sysdi import ServiceUnit, TimedUnit, UnitManager
44
from sysdi.contrib import cronitor
55

66

@@ -53,9 +53,30 @@ class Starship(TimedUnit):
5353
),
5454
)
5555

56+
57+
# Service-only (no timer) unit for chaining
58+
@dataclass
59+
class SvcStarship(ServiceUnit):
60+
exec_bin: str = '/bin/starship'
61+
62+
63+
# Chain: A runs on a schedule; on success triggers B; on success triggers C
64+
um_chain = UnitManager(unit_prefix='utm-chain-')
65+
alpha = Starship(
66+
'Diagnostics Head',
67+
'diagnostics run',
68+
start_delay='30s',
69+
run_every='15m',
70+
)
71+
beta = SvcStarship('Diagnostics Beta', 'beta stage')
72+
gamma = SvcStarship('Diagnostics Gamma', 'gamma stage')
73+
um_chain.chain('Diagnostics Chain', alpha, beta, gamma)
74+
75+
5676
# Call this in a cli command (or something) to:
5777
# - Write units to disk
5878
# - Reload systemd daemon
5979
# - Enable timer units
6080
# - Enable login linger: which indicates timers should run even when the user is logged out
6181
# um.sync(linger='enable')
82+
# um_chain.sync(linger=None)

src/sysdi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .core import ExecWrap as ExecWrap
2+
from .core import ServiceUnit as ServiceUnit
23
from .core import TimedUnit as TimedUnit
34
from .core import UnitManager as UnitManager
45
from .core import WebPing as WebPing

src/sysdi/core.py

Lines changed: 223 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterable, Sequence
12
from dataclasses import dataclass
23
import logging
34
import os
@@ -84,9 +85,10 @@ def sync(self, *, linger: str | None, install_dpath: str | os.PathLike | None =
8485
linger_disable()
8586

8687
def unit_names(self):
87-
return [u.unit_name('service') for u in self.units] + [
88-
u.unit_name('timer') for u in self.units
89-
]
88+
names: list[str] = []
89+
for u in self.units:
90+
names.extend(u.managed_unit_names())
91+
return names
9092

9193
def stale(self):
9294
managed_names = set(self.unit_names())
@@ -111,12 +113,53 @@ def remove_all(self):
111113
"""Remove all unit files, services, and timers that match the prefix."""
112114
self.remove_stale()
113115
for unit in self.units:
114-
# Timer first to avoid systemd warning about timer being able to start service.
115-
self.remove_unit(unit.unit_name('timer'))
116-
self.remove_unit(unit.unit_name('service'))
116+
for name in unit.managed_unit_names():
117+
self.remove_unit(name)
117118

118119
daemon_reload()
119120

121+
def chain(self, chain_name: str, *units):
122+
if not units:
123+
raise ValueError('chain must include at least one unit')
124+
if len({id(u) for u in units}) != len(units):
125+
raise ValueError('chain units must be unique instances')
126+
127+
self.register(*units)
128+
129+
def _append_success(unit: object, next_unit: object):
130+
try:
131+
curr = unit.on_success # type: ignore[attr-defined]
132+
except AttributeError:
133+
curr = None
134+
items: list[object]
135+
if curr is None:
136+
items = []
137+
elif isinstance(curr, str | bytes):
138+
items = [curr]
139+
else:
140+
items = list(curr) # type: ignore[arg-type]
141+
items.append(next_unit)
142+
unit.on_success = items # type: ignore[attr-defined]
143+
144+
for i in range(len(units) - 1):
145+
_append_success(units[i], units[i + 1])
146+
147+
# Create a target that wants the first trigger (timer or service)
148+
first = units[0]
149+
wants: list[str] = []
150+
try:
151+
wants.append(first.unit_name('timer')) # type: ignore[arg-type]
152+
except AssertionError:
153+
wants.append(first.unit_name('service')) # type: ignore[arg-type]
154+
155+
tgt = TargetUnit(
156+
description=f'Chain: {chain_name}',
157+
unit_basename=slugify(chain_name),
158+
wants=wants,
159+
)
160+
tgt.unit_prefix = self.unit_prefix
161+
self.units.append(tgt)
162+
120163
def remove_stale(self):
121164
"""
122165
Remove any unit files, services, or timers that match the prefix but aren't being
@@ -266,6 +309,10 @@ class TimedUnit:
266309
# Exec Pre/Post support
267310
exec_wrap: ExecWrap | None = None
268311

312+
# Chain/Dependency support (Unit options)
313+
on_success: str | object | Sequence[str | object] | None = None
314+
on_failure: str | object | Sequence[str | object] | None = None
315+
269316
# Other Unit Config
270317
service_extra: list[str] | None = None
271318
timer_extra: list[str] | None = None
@@ -333,6 +380,39 @@ def option(self, lines, opt_name):
333380

334381
lines.append(f'{opt_name}={value}')
335382

383+
def _normalize_refs(self, refs: str | object | Sequence[str | object] | None) -> list[str]:
384+
if refs is None:
385+
return []
386+
if isinstance(refs, str | bytes):
387+
items: Iterable[str | object] = [refs]
388+
else:
389+
items = refs # type: ignore[assignment]
390+
names: list[str] = []
391+
for r in items:
392+
if isinstance(r, str | bytes):
393+
names.append(r)
394+
else:
395+
try:
396+
rpfx = getattr(r, 'unit_prefix', None)
397+
spfx = getattr(self, 'unit_prefix', None)
398+
if rpfx is None and spfx is not None and hasattr(r, 'unit_basename'):
399+
names.append(f'{spfx}{r.unit_basename}.service') # type: ignore[attr-defined]
400+
else:
401+
names.append(r.unit_name('service')) # type: ignore[attr-defined]
402+
except Exception as e: # pragma: no cover - defensive
403+
raise TypeError('Invalid unit reference for OnSuccess/OnFailure') from e
404+
return names
405+
406+
def _unit_dependency_lines(self) -> list[str]:
407+
lines: list[str] = []
408+
succ = self._normalize_refs(self.on_success)
409+
fail = self._normalize_refs(self.on_failure)
410+
if succ:
411+
lines.append('OnSuccess=' + ' '.join(succ))
412+
if fail:
413+
lines.append('OnFailure=' + ' '.join(fail))
414+
return lines
415+
336416
def timer(self):
337417
lines = []
338418
lines.extend(
@@ -374,6 +454,9 @@ def service(self):
374454
f'Description={self.description}',
375455
),
376456
)
457+
# Add chain dependencies if configured
458+
lines.extend(self._unit_dependency_lines())
459+
377460
if self.retry_max_tries and self.retry_interval_seconds:
378461
# limit interval must be set to more than (tries * interval) to contain the burst
379462
limit_interval = (self.retry_max_tries * self.retry_interval_seconds) + 15
@@ -433,3 +516,137 @@ def install(self, install_dpath):
433516
def unit_name(self, type_):
434517
assert type_ in ('service', 'timer')
435518
return f'{self.unit_prefix}{self.unit_basename}.{type_}'
519+
520+
def managed_unit_names(self) -> list[str]:
521+
return [self.unit_name('timer'), self.unit_name('service')]
522+
523+
524+
@dataclass
525+
class ServiceUnit:
526+
description: str
527+
exec_args: str = ''
528+
exec_bin: str = ''
529+
530+
service_type: str = 'oneshot'
531+
532+
retry_interval_seconds: int | None = None
533+
retry_max_tries: int | None = None
534+
535+
exec_wrap: ExecWrap | None = None
536+
537+
on_success: str | object | Sequence[str | object] | None = None
538+
on_failure: str | object | Sequence[str | object] | None = None
539+
540+
service_extra: list[str] | None = None
541+
542+
unit_basename: str | None = None
543+
unit_prefix: str | None = None
544+
545+
def __post_init__(self):
546+
if not self.exec_bin:
547+
raise ValueError('exec_bin must be set')
548+
self.unit_basename = self.unit_basename or slugify(self.description)
549+
550+
@property
551+
def exec_start(self):
552+
return f'{self.exec_bin} {self.exec_args}'.strip()
553+
554+
def _normalize_refs(self, refs: str | object | Sequence[str | object] | None) -> list[str]:
555+
if refs is None:
556+
return []
557+
if isinstance(refs, str | bytes):
558+
items: Iterable[str | object] = [refs]
559+
else:
560+
items = refs # type: ignore[assignment]
561+
names: list[str] = []
562+
for r in items:
563+
if isinstance(r, str | bytes):
564+
names.append(r)
565+
else:
566+
try:
567+
rpfx = getattr(r, 'unit_prefix', None)
568+
spfx = getattr(self, 'unit_prefix', None)
569+
if rpfx is None and spfx is not None and hasattr(r, 'unit_basename'):
570+
names.append(f'{spfx}{r.unit_basename}.service') # type: ignore[attr-defined]
571+
else:
572+
names.append(r.unit_name('service')) # type: ignore[attr-defined]
573+
except Exception as e: # pragma: no cover
574+
raise TypeError('Invalid unit reference for OnSuccess/OnFailure') from e
575+
return names
576+
577+
def _unit_dependency_lines(self) -> list[str]:
578+
lines: list[str] = []
579+
succ = self._normalize_refs(self.on_success)
580+
fail = self._normalize_refs(self.on_failure)
581+
if succ:
582+
lines.append('OnSuccess=' + ' '.join(succ))
583+
if fail:
584+
lines.append('OnFailure=' + ' '.join(fail))
585+
return lines
586+
587+
def service(self):
588+
lines: list[str] = []
589+
lines.extend(('[Unit]', f'Description={self.description}'))
590+
lines.extend(self._unit_dependency_lines())
591+
592+
if self.retry_max_tries and self.retry_interval_seconds:
593+
limit_interval = (self.retry_max_tries * self.retry_interval_seconds) + 15
594+
lines.extend(
595+
(
596+
f'StartLimitInterval={limit_interval}',
597+
f'StartLimitBurst={self.retry_max_tries}',
598+
),
599+
)
600+
601+
lines.extend(('', '[Service]', f'Type={self.service_type}'))
602+
if self.retry_interval_seconds:
603+
lines.extend(('Restart=on-failure', f'RestartSec={self.retry_interval_seconds}'))
604+
605+
lines.append(f'ExecStart={self.exec_start}')
606+
607+
if self.exec_wrap:
608+
lines.append(f'ExecStartPre={self.exec_wrap.pre()}')
609+
lines.append(f'ExecStopPost={self.exec_wrap.post()}')
610+
611+
lines.extend(self.service_extra or ())
612+
return '\n'.join(lines) + '\n'
613+
614+
def install(self, install_dpath: Path):
615+
install_dpath.mkdir(parents=True, exist_ok=True)
616+
service_fname = self.unit_name('service')
617+
service_fpath = install_dpath.joinpath(service_fname)
618+
service_fpath.write_text(self.service())
619+
log.info(f'(Re)installed {service_fname}')
620+
daemon_reload()
621+
622+
def unit_name(self, type_: str):
623+
assert type_ == 'service'
624+
return f'{self.unit_prefix}{self.unit_basename}.{type_}'
625+
626+
def managed_unit_names(self) -> list[str]:
627+
return [self.unit_name('service')]
628+
629+
630+
@dataclass
631+
class TargetUnit:
632+
description: str
633+
unit_basename: str
634+
wants: list[str] | None = None
635+
unit_prefix: str | None = None
636+
637+
def install(self, install_dpath: Path):
638+
install_dpath.mkdir(parents=True, exist_ok=True)
639+
fname = self.unit_name()
640+
fpath = install_dpath.joinpath(fname)
641+
lines = ['[Unit]', f'Description={self.description}']
642+
if self.wants:
643+
lines.append('Wants=' + ' '.join(self.wants))
644+
fpath.write_text('\n'.join(lines) + '\n')
645+
log.info(f'(Re)installed {fname}')
646+
daemon_reload()
647+
648+
def unit_name(self) -> str:
649+
return f'{self.unit_prefix}{self.unit_basename}.target'
650+
651+
def managed_unit_names(self) -> list[str]:
652+
return [self.unit_name()]

src/sysdi_tests/test_integration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_ok(self, um: UnitManager, tmp_path: Path):
4444
'Check exec wrap',
4545
exec_bin='/usr/bin/true',
4646
exec_wrap=FileWrap(tmp_path),
47+
on_active_sec='1000s',
4748
),
4849
)
4950
um.sync(linger=None)
@@ -59,6 +60,7 @@ def test_fail(self, um: UnitManager, tmp_path: Path):
5960
'Check exec wrap',
6061
exec_bin='/usr/bin/false',
6162
exec_wrap=FileWrap(tmp_path),
63+
on_active_sec='1000s',
6264
),
6365
)
6466
um.sync(linger=None)

0 commit comments

Comments
 (0)