Skip to content

Commit 948dca5

Browse files
committed
Refactor patch loading and error handling methods
1 parent 35cd102 commit 948dca5

1 file changed

Lines changed: 11 additions & 14 deletions

File tree

custom_components/patch/__init__.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,15 @@ async def _read(self, path: Path | URL) -> str:
159159
response.raise_for_status()
160160
return await response.text()
161161

162-
async def init(self) -> bool | list[tuple[str, Exception]]:
163-
"""Get the content of the files."""
162+
async def load(self) -> list[tuple[str, Exception]]:
163+
"""Read the content of the files and return any errors."""
164164
self._destination, self._base, self._patch = await asyncio.gather(
165165
self._read(self.config[CONF_DESTINATION]),
166166
self._read(self.config[CONF_BASE]),
167167
self._read(self.config[CONF_PATCH]),
168168
return_exceptions=True,
169169
)
170-
if errors := [
170+
return [
171171
(str(name), content)
172172
for name, content in zip(
173173
[
@@ -179,9 +179,7 @@ async def init(self) -> bool | list[tuple[str, Exception]]:
179179
strict=True,
180180
)
181181
if isinstance(content, Exception)
182-
]:
183-
return errors
184-
return self._check()
182+
]
185183

186184
def _is_base(self) -> bool:
187185
"""Check if the destination is identical to the base file."""
@@ -191,7 +189,7 @@ def _is_patched(self) -> bool:
191189
"""Check if the destination is identical to the patch file."""
192190
return self._destination == self._patch
193191

194-
def _check(self) -> bool:
192+
def check(self) -> bool:
195193
"""Check if patch is needed and then if it's as base."""
196194
if not self._is_patched() and not self._is_base():
197195
LOGGER.error(
@@ -248,18 +246,17 @@ async def apply_after_migration(self, _: datetime.datetime | None = None) -> Non
248246

249247
async def init(self) -> bool:
250248
"""Initialize all patches."""
251-
results = await asyncio.gather(
252-
*(patch.init() for patch in self._patches), return_exceptions=True
253-
)
254249
if errors := [
255-
error for result in results if isinstance(result, list) for error in result
250+
error
251+
for result in (
252+
await asyncio.gather(*(patch.load() for patch in self._patches))
253+
)
254+
for error in result
256255
]:
257256
self._error(errors)
258257
return False
259258
if base_mismatch := [
260-
patch.config
261-
for index, patch in enumerate(self._patches)
262-
if not results[index]
259+
patch.config for patch in self._patches if not patch.check()
263260
]:
264261
self._repair(base_mismatch)
265262
return False

0 commit comments

Comments
 (0)