
feat(sagemaker-ai): Add SageMaker AI plugin with model customization and HyperPod skills#109

Open
grraman wants to merge 2 commits into awslabs:main from grraman:feature/sagemaker-ai

Conversation


@grraman grraman commented Mar 27, 2026

DO NOT MERGE [This is a PR for early feedback]

Adds a new sagemaker-ai plugin that equips AI coding agents with skills for SageMaker model customization (fine-tuning, evaluation, deployment) and HyperPod cluster operations.

What's included

Plugin infrastructure:

  • Plugin manifest (.claude-plugin/plugin.json) and MCP server config (.mcp.json)

Model customization skills (end-to-end workflow):

  • planning — Discovers user intent and generates a structured plan that orchestrates other skills
  • use-case-specification — Captures business problem, stakeholders, and success criteria per the AWS Responsible AI Lens
  • finetuning-setup — Guides technique selection (SFT, DPO, RLVR) and base model choice
  • dataset-evaluation — Validates dataset format and quality with an auto-detection script
  • dataset-transformation — Converts datasets to the required format for the selected technique/model
  • finetuning — Generates Jupyter notebooks for SageMaker Serverless Model Customization training jobs
  • model-evaluation — Generates notebooks for LLM-as-a-Judge evaluation with built-in and custom metrics
  • model-deployment — Generates notebooks to deploy fine-tuned models to SageMaker endpoints or Bedrock (Nova and OSS pathways)
  • directory-management — Manages project directory structure for generated artifacts

HyperPod cluster operations skills:

  • hyperpod-ssm — Remote command execution and file transfer on HyperPod nodes via SSM
  • hyperpod-issue-report — Collects diagnostic reports across cluster nodes and uploads to S3
  • hyperpod-version-checker — Detects and compares software component versions (NVIDIA drivers, CUDA, NCCL, EFA, Neuron SDK, etc.) across nodes

Acknowledgement

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of the project license.


@github-advanced-security github-advanced-security bot left a comment


Semgrep OSS found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@theagenticguy theagenticguy added the do-not-merge Do not merge the pull request label Mar 28, 2026

@theagenticguy theagenticguy left a comment


Code Review Summary

Strong plugin with real depth — 13 skills covering the full SageMaker model customization lifecycle plus HyperPod cluster ops. The planning skill with routing constraints is a thoughtful orchestration pattern. A few items need attention before this is merge-ready.

Blockers (4)

  1. Python indentation errors in notebook templates — The matplotlib plotting code in dpo_example.md, sft_example.md, and rlvr_example.md has broken indentation inside for loops. Lines after for idx, metric in enumerate(metrics): are not indented. Every generated notebook will fail with IndentationError.

  2. Script path mismatch — dataset-evaluation/SKILL.md line 54 references python src/format_detector.py but the actual file is at scripts/format_detector.py.

  3. Private API access — deploy-oss-bedrock.py line 43 sets builder._bedrock_client directly. This bypasses the public API and will break silently when the SDK renames or restructures the private attribute.

  4. Semgrep findings — 20+ Semgrep findings need triage (fix or justify).

Convention Deviations

  • Python scripts: All 6 existing plugins use bash scripts exclusively. This PR introduces 10 Python scripts (1,430 lines for hyperpod_issue_report.py alone). This needs an explicit maintainer decision on whether the repo accepts Python as a plugin scripting language. If yes, consider PEP 723 inline metadata + uv run instead of requirements.txt — scripts become self-contained with no separate dependency file:

    # /// script
    # requires-python = ">=3.8"
    # dependencies = ["boto3>=1.26.0", "pexpect>=4.8.0"]
    # ///

    Then: uv run scripts/hyperpod_issue_report.py --cluster ... (no pip install step).

  • requirements.txt: First in the repo. See PEP 723 suggestion above.

  • fs_write references: notebook_writing_guide.md and notebook_structure.md reference fs_write (Amazon Q Developer tool). Claude Code uses Write. If this plugin targets Claude Code, these need updating.

Risks

  • ssm-exec.sh line 68: sed-based JSON escaping fallback (when jq absent) is a command injection surface. Consider making jq a hard prerequisite or using a Python one-liner fallback.
  • deploy-oss-bedrock.py Cell 4: infinite while True polling loop with no timeout guard. If the job hangs in a non-terminal state, the notebook cell blocks forever.
  • SKILL.md descriptions are 3-4x longer than any existing plugin's — may affect skill menu UX and triggering accuracy.

What's Good

  • Well-structured skill orchestration with explicit routing constraints (skill-routing-constraints.md)
  • Production-grade HyperPod diagnostic tooling (SSM scripts, issue report collector, version checker)
  • Consistent notebook generation pattern across all skills
  • Proper EULA/license gates — never auto-accepts
  • Real validation scripts with structured error reporting
  • set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN) used consistently for telemetry

cc @justintlewis @scottschreckengaust — the Python scripts convention question needs a maintainer call.

metrics = ["loss_per_batch", "rewards/chosen", "rewards/rejected", "rewards/margins", "acc_per_batch"]
fig, axes = plt.subplots(1, len(metrics), figsize=(4 * len(metrics), 3))
for idx, metric in enumerate(metrics):
history = client.get_metric_history(run_id, metric)

Bug: Indentation error — The lines inside this for loop are not indented. This will produce IndentationError when the notebook runs.

for idx, metric in enumerate(metrics):
    history = client.get_metric_history(run_id, metric)  # needs 4-space indent
    axes[idx].plot(...)  # needs 4-space indent
    ...

Same issue exists in sft_example.md and rlvr_example.md.

```bash
# With the file path argument identified in workflow step 1
python src/format_detector.py local_path/to/dataset
```

Bug: Wrong path — This references python src/format_detector.py but the actual file is at scripts/format_detector.py.

-python src/format_detector.py local_path/to/dataset
+python scripts/format_detector.py local_path/to/dataset

training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME, region=REGION)
builder = BedrockModelBuilder(model=training_job)
builder._bedrock_client = boto3.client("bedrock", region_name=REGION)


Fragile: Private API access — Setting builder._bedrock_client directly bypasses the public API. This will break silently when the SDK renames or restructures the private attribute.

Is there a public parameter on BedrockModelBuilder to pass a region-specific client? If not, this should be filed as a feature request against the SageMaker SDK, and the workaround documented with a comment explaining why it's necessary.

#!/usr/bin/env python3
"""
HyperPod Issue Report Collector


Convention question for maintainers (@justintlewis @scottschreckengaust):

All 6 existing plugins in this repo use bash scripts exclusively. This PR introduces 10 Python files including this 1,430-line script with external dependencies (pexpect, boto3).

Is the repo ready to accept Python as a scripting language for plugins? If so, consider:

  1. PEP 723 inline metadata + uv run instead of requirements.txt — makes scripts self-contained with no separate dependency file:

    # /// script
    # requires-python = ">=3.8"
    # dependencies = ["boto3>=1.26.0", "pexpect>=4.8.0"]
    # ///

    Then: uv run scripts/hyperpod_issue_report.py --cluster ... (no pip install step)

  2. Updating the contributing guidelines to document Python as an accepted language

  3. Adding Python linting to CI (the repo already has dprint for markdown/JSON)

@@ -0,0 +1,3 @@
boto3>=1.26.0
botocore>=1.29.0

First requirements.txt in the repo. Consider PEP 723 inline script metadata instead — keeps the dependency declaration inside the script itself, so there's no separate file to fall out of sync:

# /// script
# requires-python = ">=3.8"
# dependencies = [
#   "boto3>=1.26.0",
#   "botocore>=1.29.0",
#   "pexpect>=4.8.0",
# ]
# ///

Then the SKILL.md can instruct: uv run scripts/hyperpod_issue_report.py --cluster ... (no pip install step needed).

## The Solution: Use fs_write with JSON Structure

**ALWAYS use the `fs_write` tool with `command: create` to write notebooks.**


Wrong tool name — fs_write with command: create is an Amazon Q Developer tool. Claude Code (which this plugin targets via .claude-plugin/plugin.json) uses a Write tool. This reference will confuse the agent at runtime.

Same issue in notebook_structure.md.

bedrock = boto3.client("bedrock", region_name=REGION)

while True:
resp = bedrock.get_model_import_job(jobIdentifier=job_arn)

No timeout on polling loop — This while True loop has no max-iterations or timeout guard. If the Bedrock import job hangs in a non-terminal state, the notebook cell blocks forever.

Suggestion: add a timeout similar to the retry pattern in Cell 5:

for _ in range(120):  # 60 min max (120 × 30 s)
    resp = bedrock.get_model_import_job(jobIdentifier=job_arn)
    if resp["status"] != "InProgress":
        break
    time.sleep(30)
else:
    print("Import did not complete within 60 minutes.")

else
local escaped
escaped=$(printf '%s' "$cmd" | sed 's/\\/\\\\/g; s/"/\\"/g; s/\t/\\t/g')
printf '{"command":["%s"]}\n' "$escaped"

Fragile JSON escaping — The sed-based fallback when jq is not available does manual string escaping. This doesn't handle newlines, backslashes-before-quotes, or unicode correctly — historically a source of injection bugs.

Consider either:

  1. Making jq a hard prerequisite (it's available on all HyperPod AMIs)
  2. Using Python as fallback: python3 -c "import json,sys; print(json.dumps({'command': [sys.argv[1]]}))" "$cmd"

;;
upload)
ENCODED=$(b64_encode "$LOCAL_PATH")
# Compress large files to stay within SSM command limits (~64KB)

Shell injection in ssm-exec.sh upload mode — REMOTE_PATH is user-controlled and interpolated unquoted into the command string sent to remote nodes


Current code (ssm-exec.sh:~line 60-62):
json_cmd "bash -c 'echo ${ENCODED} | base64 -d | gunzip > ${REMOTE_PATH}'" > "$TMPFILE"

Fix — quote the variable and validate the path:

  # Add input validation near the top, after argument parsing:
  if [[ "$MODE" == "upload" || "$MODE" == "read" ]]; then
    if [[ "$REMOTE_PATH" =~ [^a-zA-Z0-9_./-] ]]; then
      echo "Error: REMOTE_PATH contains invalid characters" >&2 && exit 1
    fi
  fi

  # Then quote the variable in the command string:
  json_cmd "bash -c 'echo ${ENCODED} | base64 -d | gunzip > \"${REMOTE_PATH}\"'" > "$TMPFILE"
  # ... and similarly for the non-gzip path:
  json_cmd "bash -c 'echo ${ENCODED} | base64 -d > \"${REMOTE_PATH}\"'" > "$TMPFILE"

json_cmd "bash -c 'echo ${ENCODED} | base64 -d > ${REMOTE_PATH}'" > "$TMPFILE"
fi
;;
read)

Shell injection in ssm-exec.sh read mode — REMOTE_PATH single-quoted inside double-quoted string; a path containing ' breaks out of quoting


Current code:
json_cmd "cat '${REMOTE_PATH}'" > "$TMPFILE"

  A path containing a single quote (') breaks out of the quoting. Fix — escape single quotes and validate:
  SAFE_PATH=$(printf '%s' "$REMOTE_PATH" | sed "s/'/'\\\\''/g")
  json_cmd "cat '${SAFE_PATH}'" > "$TMPFILE"

custom_prompt = "PEXPECT_READY# "

try:
ssm_command = f"aws ssm start-session --target {ssm_target}"

pexpect.spawn with string argument — f"aws ssm start-session --target {ssm_target}" passes through shell. Should use list form


Current code (hyperpod_issue_report.py:~line 778):
ssm_command = f"aws ssm start-session --target {ssm_target}"
child = pexpect.spawn(ssm_command, encoding='utf-8')

When pexpect.spawn receives a string, it invokes /bin/sh -c, so a crafted ssm_target can inject shell commands. Fix — use list form:

  ssm_command_args = ["aws", "ssm", "start-session", "--target", ssm_target]
  child = pexpect.spawn(ssm_command_args[0], ssm_command_args[1:], encoding='utf-8')

  Additionally, validate ssm_target matches the expected pattern:
  import re
  if not re.match(r'^sagemaker-cluster:[a-zA-Z0-9_-]+$', ssm_target):
      raise ValueError(f"Invalid SSM target format: {ssm_target}")

"""
HyperPod Issue Report Collector

Collects diagnostic logs and configurations from multiple HyperPod nodes.

Unescaped user commands in generated collector script — --command values interpolated into echo "Running: {cmd}" without escaping ", $, backticks
Location: hyperpod_issue_report.py (generate_collector_script)


Fix — escape commands before embedding, or write them to a separate file and source it:

  import shlex

  def _escape_for_bash_heredoc(cmd: str) -> str:
      """Escape a command string for safe embedding in a bash script."""
      # Use single-quoted string which prevents all shell expansion
      return "'" + cmd.replace("'", "'\\''") + "'"

  # When building the collector script:
  for cmd in user_commands:
      safe_cmd = _escape_for_bash_heredoc(cmd)
      script_lines.append(f"  eval {safe_cmd}")

--query 'ClusterArn' --output text)
CLUSTER_ID=$(echo "$ARN" | cut -d'/' -f2)

echo "{\"cluster_id\":\"${CLUSTER_ID}\",\"cluster_arn\":\"${ARN}\",\"cluster_name\":\"${CLUSTER}\",\"region\":\"${REGION}\"}"

JSON via string interpolation in get-cluster-info.sh — user-supplied cluster name injected into JSON string without escaping. list-nodes.sh already uses jq correctly


@krokoko krokoko Mar 30, 2026


Current code (line 20):
echo "{\"cluster_id\":\"${CLUSTER_ID}\",\"cluster_arn\":\"${ARN}\",\"cluster_name\":\"${CLUSTER}\",\"region\":\"${REGION}\"}"

If any value contains " or backslash, this produces malformed JSON. Fix — use jq:

  jq -n \
    --arg id "$CLUSTER_ID" \
    --arg arn "$ARN" \
    --arg name "$CLUSTER" \
    --arg region "$REGION" \
    '{cluster_id: $id, cluster_arn: $arn, cluster_name: $name, region: $region}'

The script already expects jq to be available (it's used by list-nodes.sh and ssm-exec.sh in the same plugin), so this is safe.

HyperPod Issue Report Collector

Collects diagnostic logs and configurations from multiple HyperPod nodes.
Supports both HyperPod EKS and HyperPod Slurm clusters.

EKS log collector downloaded and executed without integrity verification — downloads from GitHub main branch and runs as root on every node with no checksum
Location: hyperpod_issue_report.py (generate_collector_script)


Fix — pin to a specific commit and verify a checksum:

  EKS_LOG_COLLECTOR_URL = "https://raw.githubusercontent.com/awslabs/amazon-eks-ami/<COMMIT_SHA>/log-collector-script/linux/eks-log-collector.sh"
  EKS_LOG_COLLECTOR_SHA256 = "<expected_sha256_hash>"

  # After download:
  import hashlib
  actual_hash = hashlib.sha256(downloaded_content).hexdigest()
  if actual_hash != EKS_LOG_COLLECTOR_SHA256:
      raise RuntimeError(
          f"EKS log collector integrity check failed. "
          f"Expected {EKS_LOG_COLLECTOR_SHA256}, got {actual_hash}"
      )

print(f"Warning: Could not get private IP for {instance_id}: {e}")
return None

def get_cluster_nodes(self) -> List[Dict]:

get_cluster_nodes() swallows all exceptions, returns [] — user sees "No nodes found" instead of the real error (wrong IAM, bad cluster name, network failure)
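A sketch of the pattern being asked for, assuming the boto3 call is factored behind a small seam (the `fetch_nodes` parameter is hypothetical, used here only so the behavior is testable without AWS access):

```python
import logging

logger = logging.getLogger("hyperpod_issue_report")

def get_cluster_nodes(fetch_nodes):
    """Surface real failures instead of returning an empty list.

    fetch_nodes stands in for the boto3 list_cluster_nodes pagination loop.
    """
    try:
        return fetch_nodes()
    except Exception as exc:
        # Log context, then re-raise: the caller should see the real error
        # (wrong IAM, bad cluster name, network failure), not "No nodes found".
        logger.error("Failed to list cluster nodes: %s", exc)
        raise
```

An empty cluster still returns `[]` normally; only genuine failures propagate.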

return parts[-1]
return None

def get_slurm_node_name(self, instance_id: str) -> Optional[str]:

get_slurm_node_name() silently swallows exceptions — only logs in --debug mode; all Slurm node resolution can fail silently

@@ -0,0 +1,250 @@
"""
Provide your custom reward function code below. Learn about the available libraries and templates that you can use

Lambda reward function aborts entire batch on single-sample error — inner except returns immediately with {"error":...} (no statusCode), discarding all remaining samples
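A minimal sketch of per-sample isolation, assuming the handler receives a batch under a `samples` key (the `score` function here is a hypothetical placeholder, not the plugin's actual reward logic):

```python
def score(sample):
    # Hypothetical reward: length of the model completion (placeholder metric).
    return float(len(sample["completion"]))

def handler(event, context=None):
    """Score each sample independently so one bad record cannot abort the batch."""
    results = []
    for sample in event.get("samples", []):
        try:
            results.append({"reward": score(sample)})
        except Exception as exc:
            # Record the per-sample failure and keep scoring the rest.
            results.append({"reward": None, "error": str(exc)})
    return {"statusCode": 200, "results": results}
```

The batch always returns a statusCode and one result per input, with failures marked inline.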

@@ -0,0 +1,43 @@
import boto3
import json
import sys

get_model_names.py has zero error handling — AWS API calls with no try/except; crashes with raw traceback

import json
import sys

if len(sys.argv) < 3:

get_recipes.py has zero error handling — same issue as get_model_names

@@ -0,0 +1,146 @@
#!/usr/bin/env python3

import os

transformation_tools.py has zero error handling — get_execution_role() fails unhelpfully if not on SageMaker


time.sleep(30)

# Cell 5: Test Inference

deploy-oss-bedrock.py retry loop only catches ModelNotReadyException in cell 5 — throttling, access denied, or JSON errors crash on first attempt
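One way to widen the retry without also swallowing fatal errors is a helper that takes the retriable exception types explicitly (a sketch; the real cell would pass ModelNotReadyException plus the relevant throttling error classes and a nonzero delay):

```python
import time

def call_with_retry(fn, retriable, attempts=5, delay=0.0):
    """Retry only the exception types in `retriable`; anything else
    (e.g. access denied, malformed JSON) surfaces immediately."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retriable:
            if attempt == attempts:
                raise  # retries exhausted; surface the last error
            time.sleep(delay)
```

This keeps the cell's intent (wait for the model to become ready) while letting permission and input errors fail fast.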

result = detect_format("s3://my-bucket/data.jsonl")
if result.is_valid:
print(f"Format: {result.format_type}")
"""

format_detector.py main block misses boto3 and UnicodeDecodeError — non-UTF-8 files crash with unhandled exception
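A defensive read sketch for the main block (function name and messages are illustrative, not the script's actual API):

```python
def read_dataset_text(path):
    """Read a local dataset file, turning decoding and I/O failures
    into actionable messages instead of raw tracebacks."""
    try:
        with open(path, encoding="utf-8") as fh:
            return fh.read()
    except UnicodeDecodeError:
        raise SystemExit(
            f"Error: {path} is not valid UTF-8; re-encode the dataset and retry."
        )
    except OSError as exc:
        raise SystemExit(f"Error: could not read {path}: {exc}")
```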

print("-" * 60)
print(f"\nReport collection completed!")
print(f"Instance reports uploaded to: s3://{self.s3_bucket}/{self.report_s3_key}/instances/")
print(f"Summary: s3://{self.s3_bucket}/{self.report_s3_key}/summary.json")

save_summary() failure printed but output still claims success — "Summary: s3://..." displayed even when upload failed

print(f"Warning: Error verifying kubectl config: {e}")
return False

def collect_kubectl_node_info(self):

collect_kubectl_node_info() swallows all exceptions — entire kubectl collection silently dropped from report

#!/usr/bin/env bash
# List all HyperPod cluster nodes with instance group info (handles pagination)
# Usage: ./list-nodes.sh CLUSTER_NAME [--region REGION] [--instance-group GROUP] [--instance-id ID]
# Output: JSON array of nodes with InstanceId, InstanceGroupName, InstanceStatus, etc.

list-nodes.sh requires jq but doesn't check — set -e causes immediate exit with no friendly message
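A guard near the top of the script would fail fast with a friendly message (sketch; the `require_cmd` name is illustrative):

```shell
# Fail fast with a clear message when a required tool is missing.
require_cmd() {
  command -v "$1" >/dev/null 2>&1 || {
    echo "Error: '$1' is required but not installed." >&2
    return 1
  }
}
```

Calling `require_cmd jq || exit 1` before the first jq use turns the silent `set -e` exit into an actionable error.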

# Read: ./ssm-exec.sh --target TARGET --read REMOTE_PATH [--region REGION]
#
# Target format: sagemaker-cluster:<CLUSTER_ID>_<GROUP_NAME>-<INSTANCE_ID>
# Build target from parts: use --cluster-id, --group, --instance-id instead of --target

ssm-exec.sh upload mode doesn't validate file existence — attempts to base64-encode a nonexistent file
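A pre-flight check before the base64 step would catch this (sketch; the function name is illustrative):

```shell
# Refuse to encode a local file that does not exist or is not readable.
check_local_file() {
  if [ ! -f "$1" ] || [ ! -r "$1" ]; then
    echo "Error: local file not found or unreadable: $1" >&2
    return 1
  fi
}
```

In upload mode the script could call `check_local_file "$LOCAL_PATH" || exit 1` before invoking b64_encode.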


krokoko commented Mar 30, 2026

SDK version mismatch across docs vs code — reference docs say >=3.0.0 (some say >=3.6.0), all scripts install >=3.7.0. Troubleshooting advice to install >=3.0.0 won't fix the problem

@@ -0,0 +1,138 @@
# Deploy OSS Merged LoRA to Bedrock CMI

deploy-oss-bedrock.md says "Use Converse API" but script uses invoke_model — deploy-nova-bedrock.py correctly uses converse(), but the OSS variant does not
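For reference, the Converse request shape differs from invoke_model's model-specific body. A sketch that only builds the request kwargs, so the structure is inspectable without an AWS call (whether the imported model's ARN is accepted as modelId here should be confirmed against the Bedrock CMI docs):

```python
def build_converse_request(model_id, prompt, max_tokens=256):
    """Build kwargs for bedrock_runtime.converse(); request shape only,
    no AWS call, so the message structure is easy to inspect."""
    return {
        "modelId": model_id,
        "messages": [
            {"role": "user", "content": [{"text": prompt}]},
        ],
        "inferenceConfig": {"maxTokens": max_tokens},
    }

# Usage (sketch):
#   rt = boto3.client("bedrock-runtime", region_name=REGION)
#   resp = rt.converse(**build_converse_request(model_arn, "Hello"))
```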
