From eda1e24606266f5003722f9f213cf57e1d69eaed Mon Sep 17 00:00:00 2001 From: Srivyshnavi-K Date: Wed, 24 Dec 2025 11:38:53 +0530 Subject: [PATCH 1/3] IE-532: updated scripts, tests to exit in case of failures --- migrate_connections.py | 4 ++-- migrate_pipelines.py | 4 ++-- migrate_policies.py | 4 ++-- migrate_roles.py | 8 ++++---- migrate_service_accounts.py | 4 ++-- tests/test_migrate_connections.py | 6 ++++-- tests/test_migrate_pipelines.py | 6 ++++-- tests/test_migrate_policies.py | 6 ++++-- tests/test_migrate_roles.py | 12 ++++++++---- tests/test_migrate_service_accounts.py | 6 ++++-- tests/test_update_policy_main.py | 6 ++++-- tests/test_update_role.py | 6 ++++-- tests/test_update_service_account.py | 6 ++++-- update_policy.py | 4 ++-- update_role.py | 4 ++-- update_service_account.py | 4 ++-- update_vault_schema.py | 4 ++-- 17 files changed, 56 insertions(+), 38 deletions(-) diff --git a/migrate_connections.py b/migrate_connections.py index 443391c..0356ad0 100644 --- a/migrate_connections.py +++ b/migrate_connections.py @@ -128,10 +128,10 @@ def main(connection_ids=None): print( f"-- migrate_connections HTTP error: {http_err.response.content.decode()} --" ) - raise http_err + exit(1) except Exception as err: print(f"-- migrate_connections other error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/migrate_pipelines.py b/migrate_pipelines.py index 2604262..2e30cf8 100644 --- a/migrate_pipelines.py +++ b/migrate_pipelines.py @@ -219,10 +219,10 @@ def main(pipeline_id: str) -> None: print( f"-- migrate_pipelines HTTP error: {http_err.response.content.decode()} --" ) - raise http_err + exit(1) except Exception as err: print(f"-- migrate_pipelines other error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/migrate_policies.py b/migrate_policies.py index 547ca5b..7e739d0 100644 --- a/migrate_policies.py +++ b/migrate_policies.py @@ -111,10 +111,10 @@ def main(policy_ids=None): return policies_created except requests.exceptions.HTTPError as http_err: print(f'-- migrate_policies HTTP error: {http_err.response.content.decode()} --') - raise http_err + exit(1) except Exception as err: print(f"-- migrate_policies error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/migrate_roles.py b/migrate_roles.py index 0346e87..e17cb33 100644 --- a/migrate_roles.py +++ b/migrate_roles.py @@ -151,7 +151,7 @@ def main(role_ids=None): roles_created.append({"ID" : role_response["roles"][0]["ID"]}) else: print("-- Role does not exist --") - if(should_create_role): + if should_create_role: role_payload = transform_role_payload(role_info) print(f"-- Creating role: {role_name} --") new_role = create_role(role_payload) @@ -160,7 +160,7 @@ def main(role_ids=None): role_policies = get_role_policies(role_id) policy_ids = [policy["ID"] for policy in role_policies["policies"]] no_of_policies = len(policy_ids) - if(no_of_policies == 0): + if no_of_policies == 0: print('-- No policies found for the given role --') else: print(f"-- Working on policies migration. No. of policies found for given role: {no_of_policies} --") @@ -173,10 +173,10 @@ def main(role_ids=None): except requests.exceptions.HTTPError as http_err: print(f'-- Role creation failed for {role_name if role_name else ""}, ID: {role_id}. --') print(f'-- migrate_roles HTTP error: {http_err.response.content.decode()} --') - raise http_err + exit(1) except Exception as err: print(f'-- migrate_roles error: {err} --') - raise err + exit(1) if __name__ == "__main__": main() diff --git a/migrate_service_accounts.py b/migrate_service_accounts.py index 814a4b1..22bf928 100644 --- a/migrate_service_accounts.py +++ b/migrate_service_accounts.py @@ -124,10 +124,10 @@ def main(service_accounts_ids=None): print( f"-- migrate_service_accounts HTTP error: {http_err.response.content.decode()} --" ) - raise http_err + exit(1) except Exception as err: print(f"-- migrate_service_accounts other error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/tests/test_migrate_connections.py b/tests/test_migrate_connections.py index 0095a97..43a90fb 100644 --- a/tests/test_migrate_connections.py +++ b/tests/test_migrate_connections.py @@ -104,8 +104,9 @@ def raise_err(*args, **kwargs): monkeypatch.setattr(mc, "SOURCE_ENV_URL", "https://s") mock_get.side_effect = raise_err - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: mc.main() + assert excinfo.value.code == 1 def test_main_migrate_all_with_source_calls_list(monkeypatch): @@ -132,8 +133,9 @@ def boom(_): raise Exception("boom") monkeypatch.setattr(mc, "get_connection", boom) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: mc.main() + assert excinfo.value.code == 1 def test_run_as_script_config_file(monkeypatch): diff --git a/tests/test_migrate_pipelines.py b/tests/test_migrate_pipelines.py index 6628aca..68026d5 100644 --- a/tests/test_migrate_pipelines.py +++ b/tests/test_migrate_pipelines.py @@ -361,8 +361,9 @@ def raise_http_error(_): monkeypatch.setattr(module, "get_pipeline", raise_http_error) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: module.main("pipeline-http-error") + assert excinfo.value.code == 1 stdout = capsys.readouterr().out assert "HTTP error" in stdout @@ -378,8 +379,9 @@ def boom(_payload): monkeypatch.setattr(module, "create_pipeline", boom) - with pytest.raises(RuntimeError): + with pytest.raises(SystemExit) as excinfo: module.main("pipeline-other-error") + assert excinfo.value.code == 1 stdout = capsys.readouterr().out assert "other error" in stdout diff --git a/tests/test_migrate_policies.py b/tests/test_migrate_policies.py index ec6c147..ad90250 100644 --- a/tests/test_migrate_policies.py +++ b/tests/test_migrate_policies.py @@ -110,8 +110,9 @@ def raise_err(_): raise err monkeypatch.setattr(mp, "get_policy", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: mp.main(policy_ids=["p1"]) + assert excinfo.value.code == 1 def test_main_generic_exception(monkeypatch): @@ -121,8 +122,9 @@ def test_main_generic_exception(monkeypatch): "transform_policy_payload", lambda _: (_ for _ in ()).throw(Exception("boom")), ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: mp.main(policy_ids=["p1"]) + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/tests/test_migrate_roles.py b/tests/test_migrate_roles.py index 0dff650..d5d2e45 100644 --- a/tests/test_migrate_roles.py +++ b/tests/test_migrate_roles.py @@ -175,8 +175,9 @@ def raise_err(*args, **kwargs): monkeypatch.setattr(mr, "TARGET_ENV_URL", "https://t") monkeypatch.setattr(mr, "get_role", lambda _id: role_resp.json()) monkeypatch.setattr(mr, "get_system_role", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: mr.main() + assert excinfo.value.code == 1 @patch("migrate_roles.requests.post") @@ -186,8 +187,9 @@ def test_migrate_all_missing_source_prints(mock_get, mock_post, monkeypatch): monkeypatch.setattr(mr, "SOURCE_VAULT_ID", None, raising=False) monkeypatch.setattr(mr, "ROLE_IDS", "[]", raising=False) # Function prints a message then later attempts to iterate None; assert it errors predictably - with pytest.raises(TypeError): + with pytest.raises(SystemExit) as excinfo: mr.main() + assert excinfo.value.code == 1 @patch("migrate_roles.requests.post") @@ -251,8 +253,9 @@ def raise_err(*args, **kwargs): mock_post.side_effect = raise_err - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: mr.main() + assert excinfo.value.code == 1 @patch("migrate_roles.requests.post") @@ -279,8 +282,9 @@ def test_generic_exception_after_role_name(mock_get, mock_post, monkeypatch): monkeypatch.setattr( mr, "get_role_policies", lambda _id: (_ for _ in ()).throw(Exception("oops")) ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: mr.main() + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/tests/test_migrate_service_accounts.py b/tests/test_migrate_service_accounts.py index 36589f3..306003d 100644 --- a/tests/test_migrate_service_accounts.py +++ b/tests/test_migrate_service_accounts.py @@ -107,8 +107,9 @@ def raise_err(_): monkeypatch.setattr(msa, "SERVICE_ACCOUNT_IDS", "['sa1']", raising=False) monkeypatch.setattr(msa, "get_service_account", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: msa.main() + assert excinfo.value.code == 1 def test_generic_exception_branch(monkeypatch): @@ -132,8 +133,9 @@ def test_generic_exception_branch(monkeypatch): "create_service_account", lambda x: (_ for _ in ()).throw(Exception("boom")), ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: msa.main() + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/tests/test_update_policy_main.py b/tests/test_update_policy_main.py index 916a8c2..9b414af 100644 --- a/tests/test_update_policy_main.py +++ b/tests/test_update_policy_main.py @@ -63,8 +63,9 @@ def raise_err(_): monkeypatch.setattr(up, "TARGET_POLICY_ID", "t1", raising=False) monkeypatch.setattr(up, "get_source_policy", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: up.main() + assert excinfo.value.code == 1 def test_main_missing_inputs(monkeypatch): @@ -95,8 +96,9 @@ def test_main_generic_exception(monkeypatch): "transform_policy_payload", lambda s, t: (_ for _ in ()).throw(Exception("boom")), ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: up.main() + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/tests/test_update_role.py b/tests/test_update_role.py index 4b7ce61..0ca4367 100644 --- a/tests/test_update_role.py +++ b/tests/test_update_role.py @@ -92,8 +92,9 @@ def raise_err(_): raise err monkeypatch.setattr(ur, "update_role", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: ur.main() + assert excinfo.value.code == 1 def test_main_update_metadata_missing_inputs(monkeypatch): @@ -132,8 +133,9 @@ def test_main_generic_exception(monkeypatch): monkeypatch.setattr( ur, "update_role", lambda payload: (_ for _ in ()).throw(Exception("boom")) ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: ur.main() + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/tests/test_update_service_account.py b/tests/test_update_service_account.py index 1a7379a..ed952a7 100644 --- a/tests/test_update_service_account.py +++ b/tests/test_update_service_account.py @@ -112,8 +112,9 @@ def raise_err(_): raise err monkeypatch.setattr(usa, "update_service_account", raise_err) - with pytest.raises(requests.exceptions.HTTPError): + with pytest.raises(SystemExit) as excinfo: usa.main() + assert excinfo.value.code == 1 def test_main_update_metadata_missing_inputs(monkeypatch): @@ -164,8 +165,9 @@ def test_main_generic_exception(monkeypatch): "update_service_account", lambda payload: (_ for _ in ()).throw(Exception("boom")), ) - with pytest.raises(Exception): + with pytest.raises(SystemExit) as excinfo: usa.main() + assert excinfo.value.code == 1 def test_run_as_script(monkeypatch): diff --git a/update_policy.py b/update_policy.py index 1ea543f..e98a6b6 100644 --- a/update_policy.py +++ b/update_policy.py @@ -151,10 +151,10 @@ def main(): print("-- Please provide valid input. Missing input paramaters. --") except requests.exceptions.HTTPError as http_err: print(f"-- update_policy HTTP error: {http_err.response.content.decode()} --") - raise http_err + exit(1) except Exception as err: print(f"-- update_policy error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/update_role.py b/update_role.py index 4f2642b..357555d 100644 --- a/update_role.py +++ b/update_role.py @@ -103,10 +103,10 @@ def main(): print(f"-- Role {TARGET_ROLE_ID} updated successfully. --") except requests.exceptions.HTTPError as http_err: print(f"-- update_role HTTP error: {http_err.response.content.decode()} --") - raise http_err + exit(1) except Exception as err: print(f"-- update_role error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/update_service_account.py b/update_service_account.py index b79eb5e..ec30e23 100644 --- a/update_service_account.py +++ b/update_service_account.py @@ -128,10 +128,10 @@ def main(): print( f"-- update_service_account HTTP error: {http_err.response.content.decode()} --" ) - raise http_err + exit(1) except Exception as err: print(f"-- update_service_account error: {err} --") - raise err + exit(1) if __name__ == "__main__": diff --git a/update_vault_schema.py b/update_vault_schema.py index 0d1d0f7..b92ca16 100644 --- a/update_vault_schema.py +++ b/update_vault_schema.py @@ -22,8 +22,8 @@ "Content-Type": "application/json", } -def get_vault_details(vaultID: str): - response = requests.get(f"{SOURCE_ENV_URL}/v1/vaults/{vaultID}", headers=SOURCE_ACCOUNT_HEADERS) +def get_vault_details(vault_id: str): + response = requests.get(f"{SOURCE_ENV_URL}/v1/vaults/{vault_id}", headers=SOURCE_ACCOUNT_HEADERS) response.raise_for_status() return response.json() From 34d57e64233ddc7e4c317f2eb6e55144dc8c16bb Mon Sep 17 00:00:00 2001 From: Srivyshnavi-K Date: Wed, 24 Dec 2025 11:39:37 +0530 Subject: [PATCH 2/3] IE-532: updated coverage to be 98% --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index 9aea2d3..4c9b079 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] -addopts = -q --cov=. --cov-report=term-missing --cov-fail-under=95 +addopts = -q --cov=. --cov-report=term-missing --cov-fail-under=98 testpaths = tests From 728e72e17bf8d8d98f12ba2218521bec327fecb8 Mon Sep 17 00:00:00 2001 From: Srivyshnavi-K Date: Wed, 24 Dec 2025 11:41:10 +0530 Subject: [PATCH 3/3] IE-532: added update connection script, tests and workflow --- .github/workflows/update_connection.yml | 85 ++++++++++++ tests/test_update_connection.py | 174 ++++++++++++++++++++++++ update_connection.py | 84 ++++++++++++ 3 files changed, 343 insertions(+) create mode 100644 .github/workflows/update_connection.yml create mode 100644 tests/test_update_connection.py create mode 100644 update_connection.py diff --git a/.github/workflows/update_connection.yml b/.github/workflows/update_connection.yml new file mode 100644 index 0000000..2a1f4b0 --- /dev/null +++ b/.github/workflows/update_connection.yml @@ -0,0 +1,85 @@ +name: update_connection + +on: + workflow_dispatch: + inputs: + env_url: + description: "Select source and target env's" + type: choice + default: "Source: SANDBOX, Target: PRODUCTION" + options: + - "Source: SANDBOX, Target: PRODUCTION" + - "Source: SANDBOX, Target: SANDBOX" + - "Source: PRODUCTION, Target: PRODUCTION" + - "Source: PRODUCTION, Target: SANDBOX" + source_connection_id: + description: "Source Connection ID" + required: false + target_connection_id: + description: "Target Connection ID" + required: true + source_account_access_token: + description: "Access token of the Source Account" + required: false + target_account_access_token: + description: "Access token of the Target Account" + required: true + source_account_id: + description: "Source Account ID. If not provided, will use the repository variable" + required: false + target_account_id: + description: "Target Account ID. If not provided, will use the repository variable" + required: false + + +jobs: + execute-update-connection-script: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install dependencies + run: pip install requests + + - name: Parse and map environment URLs + id: map_envs + shell: bash + run: | + input="${{ github.event.inputs.env_url }}" + + source_name=$(echo "$input" | sed -n 's/Source: \([^,]*\),.*/\1/p' | xargs) + target_name=$(echo "$input" | sed -n 's/.*Target: \(.*\)/\1/p' | xargs) + + get_env_url() { + case "$1" in + SANDBOX) echo "https://manage.skyflowapis-preview.com" ;; + PRODUCTION) echo "https://manage.skyflowapis.com" ;; + *) echo "Invalid environment: $1" >&2; exit 1 ;; + esac + } + + # Resolve URLs + source_url=$(get_env_url "$source_name") + target_url=$(get_env_url "$target_name") + + echo "source_url=$source_url" >> $GITHUB_OUTPUT + echo "target_url=$target_url" >> $GITHUB_OUTPUT + + - name: Run Python script + env: + SOURCE_CONNECTION_ID: ${{ github.event.inputs.source_connection_id }} + TARGET_CONNECTION_ID: ${{ github.event.inputs.target_connection_id }} + SOURCE_ACCOUNT_AUTH: ${{ github.event.inputs.source_account_access_token }} + TARGET_ACCOUNT_AUTH: ${{ github.event.inputs.target_account_access_token }} + SOURCE_ACCOUNT_ID: ${{ github.event.inputs.source_account_id != '' && github.event.inputs.source_account_id || vars.SOURCE_ACCOUNT_ID }} + TARGET_ACCOUNT_ID: ${{ github.event.inputs.target_account_id != '' && github.event.inputs.target_account_id || vars.TARGET_ACCOUNT_ID }} + SOURCE_ENV_URL: ${{ steps.map_envs.outputs.source_url }} + TARGET_ENV_URL: ${{ steps.map_envs.outputs.target_url }} + run: python3 update_connection.py diff --git a/tests/test_update_connection.py b/tests/test_update_connection.py new file mode 100644 index 0000000..d7459f6 --- /dev/null +++ b/tests/test_update_connection.py @@ -0,0 +1,174 @@ +import runpy +from unittest.mock import MagicMock, patch + +import pytest +import requests + +import update_connection as uc + + +def _build_resp(payload): + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = payload + return resp + + +@patch("update_connection.requests.get") +def test_get_connection_fetches_resource(mock_get): + mock_get.return_value = _build_resp({"ID": "123"}) + out = uc.get_connection("123", "https://src", {"h": "v"}) + + assert out["ID"] == "123" + mock_get.assert_called_once_with( + "https://src/v1/gateway/inboundRoutes/123", headers={"h": "v"} + ) + + +@patch("update_connection.requests.put") +def test_update_connection_put(mock_put): + uc.TARGET_ENV_URL = "https://target.test.com" + uc.TARGET_ACCOUNT_HEADERS = {"h": "t"} + mock_put.return_value = _build_resp({"ID": "target"}) + payload = {"ID": "target", "name": "foo", "mode": "INGRESS"} + out = uc.update_connection("target", payload) + + assert out["ID"] == "target" + mock_put.assert_called_once_with( + "https://target.test.com/v1/gateway/inboundRoutes/target", + json=payload, + headers={"h": "t"}, + ) + + +def test_transform_connection_payload_strips_fields(): + source = { + "ID": "source", + "vaultID": "sourceVault", + "BasicAudit": {"foo": "bar"}, + "routes": [{"name": "r1", "invocationURL": "https://invoke"}], + } + target = {"ID": "target", "vaultID": "targetVault"} + + result = uc.transform_connection_payload(source, target) + assert result["ID"] == "target" + assert result["vaultID"] == "targetVault" + assert "BasicAudit" not in result + assert "invocationURL" not in result["routes"][0] + # ensure source dict left untouched + assert "BasicAudit" in source + assert "invocationURL" in source["routes"][0] + + +@patch("update_connection.requests.put") +@patch("update_connection.requests.get") +def test_main_happy_path(mock_get, mock_put, monkeypatch): + monkeypatch.setattr(uc, "SOURCE_CONNECTION_ID", "s1", raising=False) + monkeypatch.setattr(uc, "TARGET_CONNECTION_ID", "t1", raising=False) + monkeypatch.setattr(uc, "SOURCE_ENV_URL", "https://source", raising=False) + monkeypatch.setattr(uc, "TARGET_ENV_URL", "https://target", raising=False) + monkeypatch.setattr(uc, "SOURCE_ACCOUNT_HEADERS", {"h": "s"}, raising=False) + monkeypatch.setattr(uc, "TARGET_ACCOUNT_HEADERS", {"h": "t"}, raising=False) + + mock_get.side_effect = [ + _build_resp( + { + "ID": "s1", + "vaultID": "v1", + "mode": "INGRESS", + "BasicAudit": {}, + "routes": [{"invocationURL": "https://invoke"}], + } + ), + _build_resp({"ID": "t1", "vaultID": "v2", "routes": [{}]}), + ] + mock_put.return_value = _build_resp({"ID": "t1"}) + + uc.main() + + _, kwargs = mock_put.call_args + assert kwargs["json"]["ID"] == "t1" + assert "BasicAudit" not in kwargs["json"] + assert kwargs["json"]["vaultID"] == "v2" + + +def test_main_missing_ids(monkeypatch, capsys): + monkeypatch.setattr(uc, "SOURCE_CONNECTION_ID", None, raising=False) + monkeypatch.setattr(uc, "TARGET_CONNECTION_ID", None, raising=False) + uc.main() + captured = capsys.readouterr().out + assert "Missing connection IDs" in captured + + +def test_main_http_error(monkeypatch): + class Resp: + content = b"fail" + + err = requests.exceptions.HTTPError(response=Resp()) + + monkeypatch.setattr(uc, "SOURCE_CONNECTION_ID", "s1", raising=False) + monkeypatch.setattr(uc, "TARGET_CONNECTION_ID", "t1", raising=False) + monkeypatch.setattr(uc, "SOURCE_ENV_URL", "https://source", raising=False) + monkeypatch.setattr(uc, "TARGET_ENV_URL", "https://target", raising=False) + monkeypatch.setattr(uc, "SOURCE_ACCOUNT_HEADERS", {}, raising=False) + monkeypatch.setattr(uc, "TARGET_ACCOUNT_HEADERS", {}, raising=False) + monkeypatch.setattr( + uc, + "get_connection", + lambda *_args, **_kwargs: {"ID": "x", "vaultID": "y", "routes": []}, + ) + monkeypatch.setattr( + uc, + "update_connection", + lambda *_args, **_kwargs: (_ for _ in ()).throw(err), + raising=False, + ) + + with pytest.raises(SystemExit) as excinfo: + uc.main() + assert excinfo.value.code == 1 + + +def test_main_generic_exception(monkeypatch): + monkeypatch.setattr(uc, "SOURCE_CONNECTION_ID", "s1", raising=False) + monkeypatch.setattr(uc, "TARGET_CONNECTION_ID", "t1", raising=False) + monkeypatch.setattr(uc, "SOURCE_ENV_URL", "https://source", raising=False) + monkeypatch.setattr(uc, "TARGET_ENV_URL", "https://target", raising=False) + monkeypatch.setattr(uc, "SOURCE_ACCOUNT_HEADERS", {}, raising=False) + monkeypatch.setattr(uc, "TARGET_ACCOUNT_HEADERS", {}, raising=False) + monkeypatch.setattr( + uc, + "get_connection", + lambda *_args, **_kwargs: {"ID": "x", "vaultID": "y", "routes": []}, + ) + monkeypatch.setattr( + uc, + "transform_connection_payload", + lambda *_: (_ for _ in ()).throw(Exception("boom")), + ) + + with pytest.raises(SystemExit) as excinfo: + uc.main() + assert excinfo.value.code == 1 + + +@patch("update_connection.requests.put") +@patch("update_connection.requests.get") +def test_run_as_script(mock_get, mock_put, monkeypatch): + monkeypatch.setenv("SOURCE_CONNECTION_ID", "s1") + monkeypatch.setenv("TARGET_CONNECTION_ID", "t1") + monkeypatch.setenv("SOURCE_ENV_URL", "https://source") + monkeypatch.setenv("TARGET_ENV_URL", "https://target") + monkeypatch.setenv("SOURCE_ACCOUNT_ID", "src") + monkeypatch.setenv("TARGET_ACCOUNT_ID", "tgt") + monkeypatch.setenv("SOURCE_ACCOUNT_AUTH", "sa") + monkeypatch.setenv("TARGET_ACCOUNT_AUTH", "ta") + + mock_get.side_effect = [ + _build_resp({"ID": "s1", "vaultID": "v1", "mode": "INGRESS", "routes": []}), + _build_resp({"ID": "t1", "vaultID": "v2", "routes": []}), + ] + mock_put.return_value = _build_resp({"ID": "t1"}) + + runpy.run_module("update_connection", run_name="__main__") + assert mock_put.called diff --git a/update_connection.py b/update_connection.py new file mode 100644 index 0000000..e09fb41 --- /dev/null +++ b/update_connection.py @@ -0,0 +1,84 @@ +import copy +import os +import requests + +SOURCE_CONNECTION_ID = os.getenv("SOURCE_CONNECTION_ID") +TARGET_CONNECTION_ID = os.getenv("TARGET_CONNECTION_ID") +SOURCE_ACCOUNT_ID = os.getenv("SOURCE_ACCOUNT_ID") +TARGET_ACCOUNT_ID = os.getenv("TARGET_ACCOUNT_ID") +SOURCE_ACCOUNT_AUTH = os.getenv("SOURCE_ACCOUNT_AUTH") +TARGET_ACCOUNT_AUTH = os.getenv("TARGET_ACCOUNT_AUTH") +SOURCE_ENV_URL = os.getenv("SOURCE_ENV_URL") +TARGET_ENV_URL = os.getenv("TARGET_ENV_URL") + +SOURCE_ACCOUNT_HEADERS = { + "X-SKYFLOW-ACCOUNT-ID": SOURCE_ACCOUNT_ID, + "Authorization": f"Bearer {SOURCE_ACCOUNT_AUTH}", + "Content-Type": "application/json", +} + +TARGET_ACCOUNT_HEADERS = { + "X-SKYFLOW-ACCOUNT-ID": TARGET_ACCOUNT_ID, + "Authorization": f"Bearer {TARGET_ACCOUNT_AUTH}", + "Content-Type": "application/json", +} + +def get_connection(connection_id: str, env_url: str, headers: dict) -> dict: + response = requests.get( + f"{env_url}/v1/gateway/inboundRoutes/{connection_id}", headers=headers + ) + response.raise_for_status() + return response.json() + +def update_connection(connection_id: str, connection_payload: dict): + mode = 'inboundRoutes' if connection_payload["mode"] == "INGRESS" else 'outboundRoutes' + response = requests.put( + f"{TARGET_ENV_URL}/v1/gateway/{mode}/{connection_id}", + json=connection_payload, + headers=TARGET_ACCOUNT_HEADERS, + ) + response.raise_for_status() + return response.json() + +def transform_connection_payload(source_connection: dict, target_connection: dict): + transformed_connection = copy.deepcopy(source_connection) + transformed_connection["ID"] = target_connection["ID"] + transformed_connection["vaultID"] = target_connection["vaultID"] + transformed_connection.pop("BasicAudit", None) + + for route in transformed_connection.get("routes", []): + route.pop("invocationURL", None) + + return transformed_connection + +def main(): + try: + if not SOURCE_CONNECTION_ID or not TARGET_CONNECTION_ID: + print("-- Please provide valid input. Missing connection IDs --") + return + + print(f"-- Fetching source connection details:{SOURCE_CONNECTION_ID} --") + source_connection = get_connection( + SOURCE_CONNECTION_ID, SOURCE_ENV_URL, SOURCE_ACCOUNT_HEADERS + ) + print(f"-- Fetching target connection details:{TARGET_CONNECTION_ID} --") + target_connection = get_connection( + TARGET_CONNECTION_ID, TARGET_ENV_URL, TARGET_ACCOUNT_HEADERS + ) + print("-- Working on updating connection in target account --") + connection_payload = transform_connection_payload( + source_connection, target_connection + ) + update_response = update_connection(TARGET_CONNECTION_ID, connection_payload) + print( + f"-- Connection updated successfully. Source CONNECTION_ID: {SOURCE_CONNECTION_ID}. Target CONNECTION_ID: {update_response['ID']} --" + ) + except requests.exceptions.HTTPError as http_err: + print(f"-- update_connection HTTP error: {http_err.response.content.decode()} --") + exit(1) + except Exception as err: + print(f"-- update_connection error: {err} --") + exit(1) + +if __name__ == "__main__": + main()