Skip to content

Commit ce960e5

Browse files
committed
Report base mismatch on integration setup
1 parent 9f5e6b9 commit ce960e5

2 files changed

Lines changed: 49 additions & 22 deletions

File tree

custom_components/patch/__init__.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@ async def async_reload(_: ServiceCall) -> None:
131131

132132
hass.services.async_register(DOMAIN, SERVICE_RELOAD, async_reload, vol.Schema({}))
133133

134-
event.async_track_point_in_time(
135-
hass,
136-
PatchManager(hass, config[DOMAIN]).run_after_migration,
137-
dt_util.now() + datetime.timedelta(seconds=config[DOMAIN][CONF_DELAY]),
138-
)
134+
patch_manager = PatchManager(hass, config[DOMAIN])
135+
if await patch_manager.init():
136+
event.async_track_point_in_time(
137+
hass,
138+
patch_manager.apply_after_migration,
139+
dt_util.now() + datetime.timedelta(seconds=config[DOMAIN][CONF_DELAY]),
140+
)
139141

140142
return True
141143

@@ -157,13 +159,14 @@ async def _read(self, path: Path | URL) -> str:
157159
response.raise_for_status()
158160
return await response.text()
159161

160-
async def init(self) -> None:
162+
async def init(self) -> bool:
161163
"""Get the content of the files."""
162164
self._destination, self._base, self._patch = await asyncio.gather(
163165
self._read(self.config[CONF_DESTINATION]),
164166
self._read(self.config[CONF_BASE]),
165167
self._read(self.config[CONF_PATCH]),
166168
)
169+
return self._check()
167170

168171
def _is_base(self) -> bool:
169172
"""Check if the destination is identical to the base file."""
@@ -173,7 +176,7 @@ def _is_patched(self) -> bool:
173176
"""Check if the destination is identical to the patch file."""
174177
return self._destination == self._patch
175178

176-
def check(self) -> bool:
179+
def _check(self) -> bool:
177180
"""Check if patch is needed and then if it's as base."""
178181
if not self._is_patched() and not self._is_base():
179182
LOGGER.error(
@@ -216,43 +219,48 @@ def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
216219
self._patches = [Patch(hass, patch) for patch in config.get(CONF_FILES, [])]
217220

218221
@callback
219-
async def run_after_migration(self, _: datetime.datetime | None = None) -> None:
220-
"""Run if there is no migration in progress."""
222+
async def apply_after_migration(self, _: datetime.datetime | None = None) -> None:
223+
"""Apply patches if there is no DB migration in progress."""
221224
if recorder.async_migration_in_progress(self._hass):
222225
LOGGER.info("Recorder migration in progress. Checking again in a minute.")
223226
event.async_track_point_in_time(
224227
self._hass,
225-
self.run_after_migration,
228+
self.apply_after_migration,
226229
dt_util.now() + datetime.timedelta(minutes=1),
227230
)
228231
else:
229232
await self.run()
230233

231-
async def run(self) -> None:
232-
"""Execute."""
233-
if not self._patches:
234-
return
235-
236-
await asyncio.gather(*(patch.init() for patch in self._patches))
237-
234+
async def init(self) -> bool:
235+
"""Initialize all patches."""
236+
results = await asyncio.gather(*(patch.init() for patch in self._patches))
238237
if base_mismatch := [
239-
patch.config for patch in self._patches if not patch.check()
238+
patch.config
239+
for index, patch in enumerate(self._patches)
240+
if not results[index]
240241
]:
241242
self._repair(base_mismatch)
242-
return
243+
return False
244+
return True
243245

246+
async def _apply(self) -> None:
247+
"""Execute."""
244248
results = await asyncio.gather(*(patch.apply() for patch in self._patches))
245-
updates = [
249+
if updates := [
246250
patch.config for index, patch in enumerate(self._patches) if results[index]
247-
]
248-
if updates:
251+
]:
249252
self._applied(updates)
250253
if self._config[CONF_RESTART]:
251254
LOGGER.warning("Restarting HA core.")
252255
await self._hass.services.async_call(
253256
HA_DOMAIN, SERVICE_HOMEASSISTANT_RESTART
254257
)
255258

259+
async def run(self) -> None:
260+
"""Run the patching process."""
261+
if await self.init():
262+
await self._apply()
263+
256264
def _format_files(self, files: list[PatchType]) -> str:
257265
"""Format list of files for logging."""
258266
return f"- {'\n- '.join(f'`{file[CONF_DESTINATION]}`' for file in files)}"

tests/test_init.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,22 @@ def _delay_count(log: str) -> int:
441441
async_migration_in_progress_mock.return_value = False
442442
await async_next_day(hass, freezer)
443443
assert _delay_count(caplog.text) == i
444+
445+
446+
@pytest.mark.allowed_logs(["Destination file"])
447+
async def test_immediate_base_mismatch(hass: HomeAssistant) -> None:
448+
"""Test base mismatch is reported during integration setup."""
449+
repairs = async_capture_events(hass, ir.EVENT_REPAIRS_ISSUE_REGISTRY_UPDATED)
450+
await async_setup(
451+
hass,
452+
yaml.load(
453+
"""
454+
files:
455+
- destination: "{homeassistant}/__init__.py"
456+
base: "{homeassistant}/__main__.py"
457+
patch: "{homeassistant}/py.typed"
458+
""",
459+
Loader=yaml.SafeLoader,
460+
),
461+
)
462+
assert repairs[0].data["issue_id"].startswith("patch_file_base_mismatch")

0 commit comments

Comments
 (0)