Skip to content

Commit 85b370d

Browse files
committed
feat(cli): support explicit CDI device names via --gpu
Explicit CDI device IDs can now be passed via `--gpu`. A single CDI device: `--gpu=nvidia.com/gpu=all`. Multiple CDI devices (repeat the flag): `--gpu=nvidia.com/gpu=0 --gpu=nvidia.com/gpu=1`. `parse_gpu_flag` validates the input and rejects mixing `legacy`/`auto` with CDI device names, as well as specifying `legacy` or `auto` more than once. Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent adcb3bd commit 85b370d

File tree

3 files changed

+92
-15
lines changed

3 files changed

+92
-15
lines changed

architecture/gateway-single-node.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,9 @@ The `--gpu` flag on `gateway start` accepts an optional value that overrides the
324324
|---|---|
325325
| `--gpu` | Auto-select: CDI on Docker >= 28.2.0, `--gpus all` otherwise |
326326
| `--gpu=legacy` | Force `--gpus all` |
327+
| `--gpu=<cdi-device>` | Inject a specific CDI device (e.g. `nvidia.com/gpu=all`). May be repeated for multiple devices. Note: because the cluster container runs privileged, device-level isolation may not work as expected. |
328+
329+
Mixing `legacy` or auto-select with explicit CDI device names in the same invocation is an error.
327330

328331
The expected smoke test is a plain pod requesting `nvidia.com/gpu: 1` with `runtimeClassName: nvidia` and running `nvidia-smi`.
329332

crates/openshell-cli/src/main.rs

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -790,10 +790,19 @@ enum GatewayCommands {
790790
///
791791
/// An optional argument controls the injection mode:
792792
///
793-
/// --gpu Auto-select: CDI on Docker >= 28.2.0, legacy otherwise
794-
/// --gpu=legacy Force legacy nvidia DeviceRequest
795-
#[arg(long = "gpu", num_args = 0..=1, default_missing_value = "auto", value_name = "MODE")]
796-
gpu: Option<String>,
793+
/// --gpu Auto-select: CDI on Docker >= 28.2.0, legacy otherwise
794+
/// --gpu=legacy Force legacy nvidia DeviceRequest (specify once only)
795+
/// --gpu=<cdi-id> Use explicit CDI device name (repeatable)
796+
///
797+
/// Example CDI device names: `nvidia.com/gpu=all`, `nvidia.com/gpu=0`
798+
#[arg(
799+
long = "gpu",
800+
num_args = 0..=1,
801+
default_missing_value = "auto",
802+
action = clap::ArgAction::Append,
803+
value_name = "MODE",
804+
)]
805+
gpu: Vec<String>,
797806
},
798807

799808
/// Stop the gateway (preserves state).
@@ -1408,6 +1417,29 @@ enum ForwardCommands {
14081417
List,
14091418
}
14101419

1420+
/// Validate and normalise the raw values collected from `--gpu`.
1421+
///
1422+
/// | Input | Output |
1423+
/// |-------------------|---------------------------------|
1424+
/// | `[]` | `[]` — no GPU |
1425+
/// | `["auto"]` | `["auto"]` — resolve at deploy |
1426+
/// | `["legacy"]` | `["legacy"]` |
1427+
/// | `[cdi-ids…]` | `[cdi-ids…]` |
1428+
///
1429+
/// Returns an error when `legacy` or `auto` is mixed with other values, or
1430+
/// appears more than once.
///
/// Accepted input is returned unchanged and in the order given. Any value other
/// than `auto` or `legacy` is treated as a CDI device name; this function does
/// not check its syntax or reject repeated names.
///
/// # Errors
///
/// Fails when `values` holds more than one entry and at least one of them is
/// `auto` or `legacy`.
1431+
fn parse_gpu_flag(values: &[String]) -> Result<Vec<String>> {
1432+
match values {
1433+
// No values collected from `--gpu`: nothing to request.
[] => Ok(vec![]),
1434+
// Exactly one keyword (`auto` or `legacy`): pass it through untouched.
[v] if v == "auto" || v == "legacy" => Ok(values.to_vec()),
1435+
// One or more values, none of them a keyword: treat all as CDI device names.
ids if ids.iter().all(|v| v != "auto" && v != "legacy") => Ok(ids.to_vec()),
1436+
// Remaining case: a keyword appears next to another value (keyword or CDI name).
_ => Err(miette::miette!(
1437+
"--gpu=legacy and --gpu=auto can only be specified once \
1438+
and cannot be mixed with CDI device names"
1439+
)),
1440+
}
1441+
}
1442+
14111443
#[tokio::main]
14121444
async fn main() -> Result<()> {
14131445
// Install the rustls crypto provider before completion runs — completers may
@@ -1456,16 +1488,7 @@ async fn main() -> Result<()> {
14561488
registry_token,
14571489
gpu,
14581490
} => {
1459-
let gpu = match gpu.as_deref() {
1460-
None => vec![],
1461-
Some("auto") => vec!["auto".to_string()],
1462-
Some("legacy") => vec!["legacy".to_string()],
1463-
Some(other) => {
1464-
return Err(miette::miette!(
1465-
"unknown --gpu value: {other:?}; expected `legacy`"
1466-
));
1467-
}
1468-
};
1491+
let gpu = parse_gpu_flag(&gpu)?;
14691492
run::gateway_admin_deploy(
14701493
&name,
14711494
remote.as_deref(),
@@ -2818,4 +2841,55 @@ mod tests {
28182841
other => panic!("expected SshProxy, got: {other:?}"),
28192842
}
28202843
}
2844+
2845+
// --- parse_gpu_flag ---
2846+
2847+
// An empty slice of `--gpu` values must produce an empty result, not an error.
#[test]
2848+
fn parse_gpu_empty_returns_empty() {
2849+
assert_eq!(parse_gpu_flag(&[]).unwrap(), Vec::<String>::new());
2850+
}
2851+
2852+
// A lone `auto` value is accepted and returned unchanged.
#[test]
2853+
fn parse_gpu_auto_accepted() {
2854+
assert_eq!(parse_gpu_flag(&["auto".to_string()]).unwrap(), vec!["auto"]);
2855+
}
2856+
2857+
// A lone `legacy` value is accepted and returned unchanged.
#[test]
2858+
fn parse_gpu_legacy_accepted() {
2859+
assert_eq!(
2860+
parse_gpu_flag(&["legacy".to_string()]).unwrap(),
2861+
vec!["legacy"]
2862+
);
2863+
}
2864+
2865+
// CDI device names are accepted both singly and as a list, and the list keeps
// its input order.
#[test]
2866+
fn parse_gpu_cdi_device_ids_accepted() {
2867+
// Single CDI device name.
assert_eq!(
2868+
parse_gpu_flag(&["nvidia.com/gpu=all".to_string()]).unwrap(),
2869+
vec!["nvidia.com/gpu=all"],
2870+
);
2871+
// Two CDI device names given together.
assert_eq!(
2872+
parse_gpu_flag(&[
2873+
"nvidia.com/gpu=0".to_string(),
2874+
"nvidia.com/gpu=1".to_string()
2875+
])
2876+
.unwrap(),
2877+
vec!["nvidia.com/gpu=0", "nvidia.com/gpu=1"],
2878+
);
2879+
}
2880+
2881+
// `legacy` together with a CDI device name must be rejected.
#[test]
2882+
fn parse_gpu_legacy_mixed_with_cdi_is_error() {
2883+
assert!(parse_gpu_flag(&["legacy".to_string(), "nvidia.com/gpu=all".to_string()]).is_err());
2884+
}
2885+
2886+
// `auto` together with a CDI device name must be rejected.
#[test]
2887+
fn parse_gpu_auto_mixed_with_cdi_is_error() {
2888+
assert!(parse_gpu_flag(&["auto".to_string(), "nvidia.com/gpu=all".to_string()]).is_err());
2889+
}
2890+
2891+
// Repeating `legacy` must be rejected even though no CDI device name is present.
#[test]
2892+
fn parse_gpu_double_legacy_is_error() {
2893+
assert!(parse_gpu_flag(&["legacy".to_string(), "legacy".to_string()]).is_err());
2894+
}
28212895
}

docs/sandboxes/manage-gateways.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ $ openshell gateway info --name my-remote-cluster
168168

169169
| Flag | Purpose |
170170
|---|---|
171-
| `--gpu` | Enable NVIDIA GPU passthrough. Requires NVIDIA drivers and the Container Toolkit on the host. Accepts an optional value: omit for auto-select (CDI on Docker >= 28.2.0, `--gpus all` otherwise), or `--gpu=legacy` to force `--gpus all`. |
171+
| `--gpu` | Enable NVIDIA GPU passthrough. Requires NVIDIA drivers and the Container Toolkit on the host. Accepts an optional value: omit for auto-select (CDI on Docker >= 28.2.0, `--gpus all` otherwise), `--gpu=legacy` to force `--gpus all`, or `--gpu=<cdi-device>` to inject a specific CDI device (e.g. `nvidia.com/gpu=all`). May be repeated for multiple CDI devices; `legacy` and auto-select cannot be repeated or combined with CDI device names. |
172172
| `--plaintext` | Listen on HTTP instead of mTLS. Use behind a TLS-terminating reverse proxy. |
173173
| `--disable-gateway-auth` | Skip mTLS client certificate checks. Use when a reverse proxy cannot forward client certs. |
174174
| `--registry-username` | Username for registry authentication. Defaults to `__token__` when `--registry-token` is set. Only needed for private registries. Also configurable with `OPENSHELL_REGISTRY_USERNAME`. |

0 commit comments

Comments
 (0)