diff --git a/runpod/api/ctl_commands.py b/runpod/api/ctl_commands.py index 73355cb0..adfe7ff8 100644 --- a/runpod/api/ctl_commands.py +++ b/runpod/api/ctl_commands.py @@ -88,7 +88,7 @@ def get_pod(pod_id: str): def create_pod( name: str, - image_name: str, + image_name: Optional[str] = "", gpu_type_id: Optional[str] = None, cloud_type: str = "ALL", support_public_ip: bool = True, @@ -141,6 +141,10 @@ def create_pod( >>> pod_id = runpod.create_pod("test", "runpod/stack", instance_id="cpu3c-2-4") """ # Input Validation + + if not image_name and not template_id: + raise ValueError("Either image_name or template_id must be provided") + if gpu_type_id is not None: get_gpu(gpu_type_id) # Check if GPU exists, will raise ValueError if not. if cloud_type not in ["ALL", "COMMUNITY", "SECURE"]: diff --git a/tests/test_api/test_ctl_commands.py b/tests/test_api/test_ctl_commands.py index a24da44f..2b2e4301 100644 --- a/tests/test_api/test_ctl_commands.py +++ b/tests/test_api/test_ctl_commands.py @@ -136,6 +136,18 @@ def test_create_pod(self): "cloud_type must be one of ALL, COMMUNITY or SECURE", ) + with self.assertRaises(ValueError) as context: + pod = ctl_commands.create_pod( + name="POD_NAME", + gpu_type_id="NVIDIA A100 80GB PCIe", + network_volume_id="NETWORK_VOLUME_ID", + ) + + self.assertEqual( + str(context.exception), + "Either image_name or template_id must be provided", + ) + def test_stop_pod(self): """ Test stop_pod