Skip to content

Commit 95331b5

Browse files
committed
Improve runtime profiling storage resilience
1 parent 5466132 commit 95331b5

4 files changed

Lines changed: 34 additions & 25 deletions

File tree

src/_pytask/journal.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,25 @@ def append(self, payload: msgspec.Struct) -> None:
2828
journal_file.write(msgspec.json.encode(payload) + b"\n")
2929

3030
def read(self) -> list[T]:
31-
"""Read entries, clearing the journal on decode errors."""
31+
"""Read entries, keeping valid entries on decode errors."""
3232
if not self.path.exists():
3333
return []
3434

3535
entries: list[T] = []
36-
for line in self.path.read_bytes().splitlines():
37-
if not line.strip():
36+
data = self.path.read_bytes()
37+
offset = 0
38+
for line in data.splitlines(keepends=True):
39+
stripped = line.strip()
40+
if not stripped:
41+
offset += len(line)
3842
continue
3943
try:
40-
entries.append(msgspec.json.decode(line, type=self.type_))
44+
entries.append(msgspec.json.decode(stripped, type=self.type_))
4145
except msgspec.DecodeError:
42-
self.delete()
43-
return []
46+
with self.path.open("rb+") as journal_file:
47+
journal_file.truncate(offset)
48+
return entries
49+
offset += len(line)
4450
return entries
4551

4652
def delete(self) -> None:

src/_pytask/profile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def pytask_unconfigure(self, session: Session) -> None:
9999
self.runtime_state.flush()
100100

101101

102+
@dataclass
102103
class DurationNameSpace:
103104
"""A namespace for adding durations to the profile."""
104105

src/_pytask/runtime_store.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
from _pytask.node_protocols import PTask
1717

18-
CURRENT_RUNTIME_VERSION = "1"
19-
2018

2119
class _RuntimeEntry(msgspec.Struct):
2220
id: str
@@ -25,12 +23,10 @@ class _RuntimeEntry(msgspec.Struct):
2523

2624

2725
class _RuntimeFile(msgspec.Struct, forbid_unknown_fields=False):
28-
runtime_version: str = msgspec.field(name="runtime-version")
2926
task: list[_RuntimeEntry] = msgspec.field(default_factory=list)
3027

3128

32-
class _RuntimeJournalEntry(msgspec.Struct):
33-
runtime_version: str = msgspec.field(name="runtime-version")
29+
class _RuntimeJournalEntry(msgspec.Struct, forbid_unknown_fields=False):
3430
id: str
3531
date: float
3632
duration: float
@@ -56,10 +52,6 @@ def _read_runtimes(path: Path) -> _RuntimeFile | None:
5652
except msgspec.DecodeError:
5753
path.unlink()
5854
return None
59-
60-
if data.runtime_version != CURRENT_RUNTIME_VERSION:
61-
path.unlink()
62-
return None
6355
return data
6456

6557

@@ -73,12 +65,7 @@ def _write_runtimes(path: Path, runtimes: _RuntimeFile) -> None:
7365
def _read_journal(
7466
journal: JsonlJournal[_RuntimeJournalEntry],
7567
) -> list[_RuntimeJournalEntry]:
76-
entries = journal.read()
77-
for entry in entries:
78-
if entry.runtime_version != CURRENT_RUNTIME_VERSION:
79-
journal.delete()
80-
return []
81-
return entries
68+
return journal.read()
8269

8370

8471
def _apply_journal(
@@ -92,7 +79,6 @@ def _apply_journal(
9279
id=entry.id, date=entry.date, duration=entry.duration
9380
)
9481
return _RuntimeFile(
95-
runtime_version=CURRENT_RUNTIME_VERSION,
9682
task=list(index.values()),
9783
)
9884

@@ -116,7 +102,6 @@ def from_root(cls, root: Path) -> RuntimeState:
116102
journal_entries = _read_journal(journal)
117103
if existing is None:
118104
runtimes = _RuntimeFile(
119-
runtime_version=CURRENT_RUNTIME_VERSION,
120105
task=[],
121106
)
122107
runtimes = _apply_journal(runtimes, journal_entries)
@@ -137,12 +122,10 @@ def update_task(self, task: PTask, start: float, end: float) -> None:
137122
entry = _RuntimeEntry(id=task_id, date=start, duration=end - start)
138123
self._index[entry.id] = entry
139124
self.runtimes = _RuntimeFile(
140-
runtime_version=CURRENT_RUNTIME_VERSION,
141125
task=list(self._index.values()),
142126
)
143127
self._rebuild_index()
144128
journal_entry = _RuntimeJournalEntry(
145-
runtime_version=CURRENT_RUNTIME_VERSION,
146129
id=entry.id,
147130
date=entry.date,
148131
duration=entry.duration,

tests/test_runtime_store.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,22 @@ def test_runtime_state_flushes_journal(tmp_path):
5858

5959
reloaded = RuntimeState.from_root(tmp_path)
6060
assert reloaded.get_duration(task) == pytest.approx(3.5)
61+
62+
63+
def test_runtime_state_recovers_from_corrupt_journal(tmp_path):
64+
tmp_path.joinpath(".pytask").mkdir()
65+
task_a = DummyTask(name="task_a")
66+
task_b = DummyTask(name="task_b")
67+
68+
state = RuntimeState.from_root(tmp_path)
69+
state.update_task(task_a, 1.0, 3.0)
70+
state.update_task(task_b, 2.0, 6.0)
71+
72+
journal_path = tmp_path / ".pytask" / "runtimes.journal"
73+
with journal_path.open("ab") as journal_file:
74+
journal_file.write(b'{"id": "corrupt"')
75+
76+
recovered = RuntimeState.from_root(tmp_path)
77+
assert recovered.get_duration(task_a) == pytest.approx(2.0)
78+
assert recovered.get_duration(task_b) == pytest.approx(4.0)
79+
assert b'"corrupt"' not in journal_path.read_bytes()

0 commit comments

Comments
 (0)