Skip to content

Commit 081d553

Browse files
committed
fix(serverless): inject runtime vars on update to survive template env overwrite
Extract _inject_runtime_template_vars() from _do_deploy so both initial deploy and update() paths inject RUNPOD_API_KEY and FLASH_MODULE_PATH into template.env. Without this, runtime vars set during _do_deploy were silently dropped when update() overwrote the template env on config drift. Also preserve explicit template.env entries when env dict is empty on both sides.
1 parent 8c46727 commit 081d553

2 files changed

Lines changed: 229 additions & 32 deletions

File tree

src/runpod_flash/core/resources/serverless.py

Lines changed: 51 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -654,6 +654,41 @@ def _inject_template_env(self, key: str, value: str) -> None:
654654
if key not in existing_keys:
655655
self.template.env.append(KeyValuePair(key=key, value=value))
656656

657+
def _inject_runtime_template_vars(self) -> None:
    """Add runtime-only env vars to template.env; self.env is never assigned.

    Queue-based (QB) endpoints for which ``_check_makes_remote_calls()`` is
    truthy get RUNPOD_API_KEY, taken from ``get_api_key()``. A warning is
    logged instead when no key is available.

    Load-balanced (LB) endpoints get FLASH_MODULE_PATH, taken from
    ``_get_module_path()`` when it returns a truthy value.

    A variable the caller already put in ``self.env`` is left alone. Shared
    by _do_deploy (initial deploy) and update (env changes) so the injected
    values are present whenever the template env is sent.
    """
    user_env = self.env or {}

    if self.type == ServerlessType.QB:
        # Only endpoints that call other endpoints need credentials.
        if not self._check_makes_remote_calls():
            return
        # An explicit user-supplied key always wins.
        if "RUNPOD_API_KEY" in user_env:
            return

        from runpod_flash.core.credentials import get_api_key

        resolved_key = get_api_key()
        if not resolved_key:
            log.warning(
                f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. "
                f"Remote calls to other endpoints will fail."
            )
            return

        self._inject_template_env("RUNPOD_API_KEY", resolved_key)
        log.debug(
            f"{self.name}: Injected RUNPOD_API_KEY for remote calls "
            f"(makes_remote_calls=True)"
        )
        return

    if self.type == ServerlessType.LB:
        module_path = self._get_module_path()
        if not module_path or "FLASH_MODULE_PATH" in user_env:
            return
        self._inject_template_env("FLASH_MODULE_PATH", module_path)
        log.debug(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}")
691+
657692
async def _do_deploy(self) -> "DeployableResource":
658693
"""
659694
Deploys the serverless resource using the provided configuration.
@@ -669,37 +704,7 @@ async def _do_deploy(self) -> "DeployableResource":
669704
log.debug(f"{self} exists")
670705
return self
671706

672-
# Inject API key for queue-based endpoints that make remote calls.
673-
# Injected into template.env (not self.env) to avoid false config drift.
674-
if self.type == ServerlessType.QB:
675-
makes_remote_calls = self._check_makes_remote_calls()
676-
677-
if makes_remote_calls:
678-
env_dict = self.env or {}
679-
if "RUNPOD_API_KEY" not in env_dict:
680-
from runpod_flash.core.credentials import get_api_key
681-
682-
api_key = get_api_key()
683-
if api_key:
684-
self._inject_template_env("RUNPOD_API_KEY", api_key)
685-
log.debug(
686-
f"{self.name}: Injected RUNPOD_API_KEY for remote calls "
687-
f"(makes_remote_calls=True)"
688-
)
689-
else:
690-
log.warning(
691-
f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. "
692-
f"Remote calls to other endpoints will fail."
693-
)
694-
695-
# Inject module path for load-balanced endpoints.
696-
# Injected into template.env (not self.env) to avoid false config drift.
697-
elif self.type == ServerlessType.LB:
698-
env_dict = self.env or {}
699-
module_path = self._get_module_path()
700-
if module_path and "FLASH_MODULE_PATH" not in env_dict:
701-
self._inject_template_env("FLASH_MODULE_PATH", module_path)
702-
log.debug(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}")
707+
self._inject_runtime_template_vars()
703708

704709
# Ensure network volume is deployed first
705710
await self._ensure_network_volume_deployed()
@@ -773,11 +778,25 @@ async def update(self, new_config: "ServerlessResource") -> "ServerlessResource"
773778
# hasn't changed. This lets the platform keep vars it
774779
# injected (e.g. PORT, PORT_HEALTH on LB endpoints)
775780
# and avoids a spurious rolling release.
781+
#
782+
# Also check template.env: if env is empty but the
783+
# caller provided explicit template env entries, those
784+
# must not be silently dropped.
776785
env_unchanged = self.env == new_config.env
786+
has_explicit_template_env = (
787+
not new_config.env and new_config.template.env is not None
788+
)
789+
skip_env = env_unchanged and not has_explicit_template_env
790+
791+
if not skip_env:
792+
# Inject runtime vars (RUNPOD_API_KEY, FLASH_MODULE_PATH)
793+
# so they survive the template env overwrite.
794+
new_config._inject_runtime_template_vars()
795+
777796
template_payload = self._build_template_update_payload(
778797
new_config.template,
779798
resolved_template_id,
780-
skip_env=env_unchanged,
799+
skip_env=skip_env,
781800
)
782801
await client.update_template(template_payload)
783802
log.debug(

tests/unit/resources/test_serverless.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,3 +1693,181 @@ async def test_update_includes_env_when_changed(self):
16931693
assert mock_client.update_template.called
16941694
template_payload = mock_client.update_template.call_args.args[0]
16951695
assert "env" in template_payload
1696+
1697+
@pytest.mark.asyncio
async def test_update_injects_runtime_vars_when_env_changed(self):
    """A changed env makes update() send RUNPOD_API_KEY in the template env.

    update() overwrites the template env when env differs, so a var that
    was only injected during _do_deploy would be lost unless it is
    injected again on this path.
    """
    deployed = ServerlessEndpoint(
        name="update-inject-test",
        imageName="test:latest",
        env={"LOG_LEVEL": "INFO"},
        flashboot=False,
    )
    deployed.id = "ep-inject"
    deployed.templateId = "tpl-inject"

    desired = ServerlessEndpoint(
        name="update-inject-test",
        imageName="test:latest",
        env={"LOG_LEVEL": "DEBUG"},
        flashboot=False,
    )

    saved_endpoint = {
        "id": "ep-inject",
        "name": "update-inject-test",
        "templateId": "tpl-inject",
        "gpuIds": "AMPERE_48",
        "allowedCudaVersions": "",
    }
    mock_client = AsyncMock()
    mock_client.save_endpoint = AsyncMock(return_value=saved_endpoint)
    mock_client.update_template = AsyncMock(return_value={})

    # Build every patcher up front so a single flat ``with`` can enter them.
    gql_patch = patch("runpod_flash.core.resources.serverless.RunpodGraphQLClient")
    volume_patch = patch.object(
        ServerlessResource,
        "_ensure_network_volume_deployed",
        new=AsyncMock(),
    )
    remote_patch = patch.object(
        ServerlessResource,
        "_check_makes_remote_calls",
        return_value=True,
    )
    env_patch = patch.dict(os.environ, {"RUNPOD_API_KEY": "inject-key"})

    with gql_patch as mock_client_class, volume_patch, remote_patch, env_patch:
        # The client is used as an async context manager by update().
        mock_client_class.return_value.__aenter__.return_value = mock_client
        mock_client_class.return_value.__aexit__.return_value = None

        await deployed.update(desired)

    sent_payload = mock_client.update_template.call_args.args[0]
    key_entries = [
        entry
        for entry in sent_payload.get("env", [])
        if entry["key"] == "RUNPOD_API_KEY"
    ]
    assert len(key_entries) == 1
    assert key_entries[0]["value"] == "inject-key"
1756+
1757+
@pytest.mark.asyncio
async def test_update_skips_runtime_injection_when_env_unchanged(self):
    """An unchanged env keeps the ``env`` key out of the template payload.

    With skip_env=True the template env is omitted from the update
    entirely, so whatever runtime vars the platform already holds are
    left as they are.
    """
    shared_env = {"LOG_LEVEL": "INFO"}

    deployed = ServerlessEndpoint(
        name="update-no-inject",
        imageName="test:latest",
        env=shared_env,
        flashboot=False,
    )
    deployed.id = "ep-no-inject"
    deployed.templateId = "tpl-no-inject"

    desired = ServerlessEndpoint(
        name="update-no-inject",
        imageName="test:latest",
        env=shared_env,
        flashboot=False,
    )

    saved_endpoint = {
        "id": "ep-no-inject",
        "name": "update-no-inject",
        "templateId": "tpl-no-inject",
        "gpuIds": "AMPERE_48",
        "allowedCudaVersions": "",
    }
    mock_client = AsyncMock()
    mock_client.save_endpoint = AsyncMock(return_value=saved_endpoint)
    mock_client.update_template = AsyncMock(return_value={})

    # Build every patcher up front so a single flat ``with`` can enter them.
    gql_patch = patch("runpod_flash.core.resources.serverless.RunpodGraphQLClient")
    volume_patch = patch.object(
        ServerlessResource,
        "_ensure_network_volume_deployed",
        new=AsyncMock(),
    )
    remote_patch = patch.object(
        ServerlessResource,
        "_check_makes_remote_calls",
        return_value=True,
    )
    env_patch = patch.dict(os.environ, {"RUNPOD_API_KEY": "inject-key"})

    with gql_patch as mock_client_class, volume_patch, remote_patch, env_patch:
        # The client is used as an async context manager by update().
        mock_client_class.return_value.__aenter__.return_value = mock_client
        mock_client_class.return_value.__aexit__.return_value = None

        await deployed.update(desired)

    # env should be omitted from template payload (skip_env=True)
    sent_payload = mock_client.update_template.call_args.args[0]
    assert "env" not in sent_payload
1815+
1816+
@pytest.mark.asyncio
async def test_update_includes_env_for_explicit_template_env(self):
    """Explicit template.env entries are sent even when env is empty on both sides.

    self.env == new_config.env holds here (both are empty), yet the
    EXPLICIT_VAR entry supplied on the template must still appear in the
    update payload rather than being dropped.
    """
    deployed = ServerlessEndpoint(
        name="update-tpl-env",
        imageName="test:latest",
        env={},
        flashboot=False,
    )
    deployed.id = "ep-tpl-env"
    deployed.templateId = "tpl-tpl-env"

    explicit_template = PodTemplate(
        name="explicit-tpl",
        imageName="test:latest",
        env=[KeyValuePair(key="EXPLICIT_VAR", value="explicit_val")],
    )
    desired = ServerlessEndpoint(
        name="update-tpl-env",
        imageName="test:latest",
        env={},
        flashboot=False,
        template=explicit_template,
    )

    saved_endpoint = {
        "id": "ep-tpl-env",
        "name": "update-tpl-env",
        "templateId": "tpl-tpl-env",
        "gpuIds": "AMPERE_48",
        "allowedCudaVersions": "",
    }
    mock_client = AsyncMock()
    mock_client.save_endpoint = AsyncMock(return_value=saved_endpoint)
    mock_client.update_template = AsyncMock(return_value={})

    # Build every patcher up front so a single flat ``with`` can enter them.
    gql_patch = patch("runpod_flash.core.resources.serverless.RunpodGraphQLClient")
    volume_patch = patch.object(
        ServerlessResource,
        "_ensure_network_volume_deployed",
        new=AsyncMock(),
    )

    with gql_patch as mock_client_class, volume_patch:
        # The client is used as an async context manager by update().
        mock_client_class.return_value.__aenter__.return_value = mock_client
        mock_client_class.return_value.__aexit__.return_value = None

        await deployed.update(desired)

    sent_payload = mock_client.update_template.call_args.args[0]
    assert "env" in sent_payload
    explicit_entries = [
        entry for entry in sent_payload["env"] if entry["key"] == "EXPLICIT_VAR"
    ]
    assert len(explicit_entries) == 1

0 commit comments

Comments
 (0)