Skip to content

Commit a47216a

Browse files
authored
Merge pull request #447 from unity-sds/438-run-ogc-process
438 run ogc process
2 parents e34ff09 + 7e9d815 commit a47216a

7 files changed

Lines changed: 453 additions & 2 deletions

File tree

airflow/dags/run_ogc_process.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
"""
2+
DAG with custom SPSOGCOperator that subclasses KubernetesPodOperator
3+
for OGC process execution with SPS-specific functionality.
4+
"""
5+
6+
import json
import logging
import re
from datetime import datetime
from typing import Optional

import requests
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.models.param import Param
from airflow.operators.python import PythonOperator, get_current_context
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.secret import Secret as AirflowK8sSecret
from airflow.utils.trigger_rule import TriggerRule
from kubernetes.client import models as k8s

from unity_sps_utils import POD_LABEL, POD_NAMESPACE, get_affinity
21+
22+
# Base URL of the OGC Processes API; queried at DAG-parse time (see fetch_ogc_processes).
PROCESSES_ENDPOINT = "https://api.dit.maap-project.org/api/ogc/processes"
23+
24+
25+
def fetch_ogc_processes():
    """Fetch available processes from the OGC API and create mapping.

    Queries PROCESSES_ENDPOINT and builds:
      * a mapping of "<id>:<version>" display names to the numerical process
        IDs extracted from each process's "self" link, and
      * the list of display names used to populate the DAG's dropdown Param.

    Returns:
        tuple: (process_mapping, dropdown_options). On any failure a
        single-entry fallback is returned so DAG parsing never breaks because
        of a flaky external service.
    """
    # Shared fallback — previously duplicated in both exception handlers.
    fallback = ({"example-process:1.0": 1}, ["example-process:1.0"])

    try:
        response = requests.get(PROCESSES_ENDPOINT, timeout=30)
        response.raise_for_status()

        processes_data = response.json()
        process_mapping = {}
        dropdown_options = []

        for process in processes_data.get("processes", []):
            process_id = process.get("id")
            process_version = process.get("version")

            # The numerical ID is only exposed in the "self" link,
            # e.g. a href like ".../ogc/processes/7" -> 7.
            numerical_id = None
            for link in process.get("links", []):
                if link.get("rel") == "self":
                    href = link.get("href", "")
                    match = re.search(r"/processes/(\d+)$", href)
                    if match:
                        numerical_id = int(match.group(1))
                        break

            if process_id and numerical_id:
                display_name = f"{process_id}:{process_version}" if process_version else process_id
                dropdown_options.append(display_name)
                process_mapping[display_name] = numerical_id

        return process_mapping, dropdown_options

    except requests.RequestException as e:
        # Network/HTTP-level failure; lazy %-args avoid building the message
        # unless the record is actually emitted.
        logging.error("Failed to fetch processes: %s", e)
        return fallback
    except Exception as e:
        # Unexpected payload shape (bad JSON, wrong types, ...).
        logging.error("Error processing OGC processes: %s", e)
        return fallback
64+
65+
66+
# Constants
K8S_SECRET_NAME = "sps-app-credentials"  # k8s Secret holding the MAAP proxy-granting ticket
DOCKER_IMAGE = "jplmdps/ogc-job-runner:latest"  # image containing the submit/monitor entrypoint
# NOTE: evaluated at DAG-parse time on every scheduler parse; falls back to a
# stub mapping when the OGC API is unreachable (see fetch_ogc_processes).
PROCESS_MAPPING, DROPDOWN_OPTIONS = fetch_ogc_processes()

# SPS-specific secrets
# Injects the MAAP_PGT credential from the k8s Secret into the pod as an env var.
secret_env_vars = [
    AirflowK8sSecret(
        deploy_type="env",
        deploy_target="MAAP_PGT",
        secret=K8S_SECRET_NAME,
        key="MAAP_PGT",
    )
]
80+
81+
82+
class SPSOGCOperator(KubernetesPodOperator):
    """
    Custom operator for SPS OGC process execution that subclasses KubernetesPodOperator.

    This operator encapsulates all SPS-specific configuration (image, secrets,
    node affinity, security context) and provides a clean interface for OGC
    process submission and monitoring. The container entrypoint performs the
    actual HTTP calls, driven entirely by the environment variables built here.
    """

    def __init__(
        self,
        operation_type: str,
        selected_process: Optional[str] = None,
        job_inputs: Optional[str] = None,
        job_queue: Optional[str] = None,
        job_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the SPSOGCOperator.

        Args:
            operation_type: Either "submit" or "monitor".
            selected_process: Process selection for submit operations; may be
                a Jinja template that is only resolved at runtime.
            job_inputs: JSON string of job inputs for submit operations.
            job_queue: Queue name for submit operations.
            job_id: Job ID for monitor operations; may be a Jinja template.

        Raises:
            ValueError: If operation_type is not "submit" or "monitor", or if
                a required argument for the chosen operation is missing.
        """
        self.operation_type = operation_type
        self.selected_process = selected_process
        self.job_inputs = job_inputs
        self.job_queue = job_queue
        self.job_id = job_id

        # SPS-specific defaults; setdefault() lets callers override any of them.
        kwargs.setdefault("namespace", POD_NAMESPACE)
        kwargs.setdefault("image", DOCKER_IMAGE)
        kwargs.setdefault("service_account_name", "airflow-worker")
        kwargs.setdefault("secrets", secret_env_vars)
        kwargs.setdefault("in_cluster", True)
        kwargs.setdefault("get_logs", True)
        kwargs.setdefault("startup_timeout_seconds", 600)
        kwargs.setdefault("container_security_context", {"privileged": True})
        kwargs.setdefault("container_logs", True)
        kwargs.setdefault("labels", {"pod": POD_LABEL})
        # Ask Karpenter not to disrupt the node while the job pod is running.
        kwargs.setdefault("annotations", {"karpenter.sh/do-not-disrupt": "true"})
        kwargs.setdefault(
            "affinity",
            get_affinity(
                capacity_type=["spot"],
                anti_affinity_label=POD_LABEL,
            ),
        )
        # Keep the finished pod around for post-mortem debugging.
        kwargs.setdefault("on_finish_action", "keep_pod")
        kwargs.setdefault("is_delete_operator_pod", False)

        # Build operation-specific environment variables.
        if operation_type == "submit":
            kwargs["env_vars"] = self._build_submit_env_vars()
            # setdefault (was an unconditional assignment) so an explicitly
            # provided pod name is respected rather than silently overwritten.
            kwargs.setdefault("name", f"ogc-submit-pod-{kwargs.get('task_id', 'unknown')}")
            kwargs.setdefault("do_xcom_push", True)  # Submit tasks need to return job ID
        elif operation_type == "monitor":
            kwargs["env_vars"] = self._build_monitor_env_vars()
            kwargs.setdefault("name", f"ogc-monitor-pod-{kwargs.get('task_id', 'unknown')}")
        else:
            raise ValueError(f"Invalid operation_type: {operation_type}. Must be 'submit' or 'monitor'")

        super().__init__(**kwargs)

    def _build_submit_env_vars(self):
        """Build environment variables for job submission.

        The entrypoint script substitutes the {process_id} placeholder in
        SUBMIT_JOB_URL with the PROCESS_ID value.
        """
        # Resolve numerical process ID from selected process.
        numerical_process_id = self._resolve_process_id()

        return [
            k8s.V1EnvVar(
                name="SUBMIT_JOB_URL",
                value="https://api.dit.maap-project.org/api/ogc/processes/{process_id}/execution",
            ),
            k8s.V1EnvVar(name="PROCESS_ID", value=str(numerical_process_id)),
            k8s.V1EnvVar(name="JOB_INPUTS", value=self.job_inputs or "{}"),
            k8s.V1EnvVar(name="QUEUE", value=self.job_queue or "maap-dps-worker-cardamom"),
            k8s.V1EnvVar(name="SUBMIT_JOB", value="true"),
        ]

    def _build_monitor_env_vars(self):
        """Build environment variables for job monitoring.

        Raises:
            ValueError: If job_id was not provided — a None value would
                otherwise produce an invalid V1EnvVar and fail much later,
                during pod creation.
        """
        if not self.job_id:
            raise ValueError("job_id is required for monitor operations")

        return [
            k8s.V1EnvVar(
                name="MONITOR_JOB_URL",
                value="https://api.dit.maap-project.org/api/ogc/jobs/{job_id}",
            ),
            k8s.V1EnvVar(name="JOB_ID", value=self.job_id),
            k8s.V1EnvVar(name="SUBMIT_JOB", value="false"),
        ]

    def _resolve_process_id(self):
        """Resolve the selected process to a numerical process ID.

        Returns either the numerical ID from PROCESS_MAPPING, a Jinja template
        (when the selection itself is templated and cannot be resolved at
        __init__ time), or 1 as a last-resort default.
        """
        if not self.selected_process:
            raise ValueError("selected_process is required for submit operations")

        # Handle templated values - they won't be resolved yet during __init__;
        # return a template that Airflow resolves at runtime via the Setup task's XCom.
        if "{{" in str(self.selected_process):
            return "{{ ti.xcom_pull(task_ids='Setup', key='return_value')['numerical_process_id'] }}"

        # Direct lookup for non-templated values.
        numerical_id = PROCESS_MAPPING.get(self.selected_process)
        if numerical_id is None:
            self.log.warning(f"Process '{self.selected_process}' not found in mapping, defaulting to ID 1")
            return 1

        return numerical_id

    def execute(self, context):
        """Execute the operator with additional SPS-specific logging."""
        self.log.info(f"Starting SPS OGC {self.operation_type} operation")

        if self.operation_type == "submit":
            self.log.info(f"Selected process: {self.selected_process}")
            self.log.info(f"Job queue: {self.job_queue}")
            self.log.info(f"Job inputs: {self.job_inputs}")
        elif self.operation_type == "monitor":
            self.log.info(f"Monitoring job ID: {self.job_id}")

        # Delegate the actual pod lifecycle to KubernetesPodOperator.
        result = super().execute(context)

        self.log.info(f"SPS OGC {self.operation_type} operation completed")
        return result
211+
212+
213+
dag_default_args = {
    "owner": "unity-sps",
    # Runs are independent of each other; no cross-run ordering.
    "depends_on_past": False,
    # Epoch start date: with schedule=None this only satisfies Airflow's
    # requirement that a start_date exist.
    "start_date": datetime.utcfromtimestamp(0),
}

# --- DAG Definition ---

dag = DAG(
    dag_id="run_ogc_process",
    description="Submits a job to an OGC process and monitors (using custom SPSOGCOperator)",
    dag_display_name="Run an OGC Process (Custom Operator from KubernetesPodOperator)",
    tags=["ogc", "job", "custom-operator"],
    is_paused_upon_creation=False,
    catchup=False,
    schedule=None,  # manually triggered only
    max_active_runs=10,
    default_args=dag_default_args,
    params={
        # Dropdown populated at parse time from the OGC API (see fetch_ogc_processes).
        "selected_process": Param(
            default=DROPDOWN_OPTIONS[0] if DROPDOWN_OPTIONS else "Error loading dropdown",
            enum=DROPDOWN_OPTIONS,
            title="Process Selection",
            description="Select a process to execute.",
        ),
        "queue": Param(
            "maap-dps-worker-cardamom",
            type="string",
            title="Queue",
            description="The MAAP queue to submit the job to",
        ),
        "job_inputs": Param(
            "{}",
            type="string",
            title="Job Inputs",
            description="A JSON string representing the inputs payload for the job.",
        ),
    },
)
252+
253+
# --- Task Definitions ---
254+
255+
256+
def setup(ti=None, **context):
    """Log DAG parameters and resolve the selected process to its numerical ID.

    The returned dict is pushed to XCom under ``return_value`` so downstream
    tasks (templated PROCESS_ID resolution) can pull the numerical ID.

    Returns:
        dict: {"numerical_process_id": <int>}; defaults to 1 when the selected
        process is not present in PROCESS_MAPPING.
    """
    logging.info("Starting OGC job submission and monitoring DAG (Custom Operator Version).")
    logging.info("Available processes: %s", len(DROPDOWN_OPTIONS))
    logging.info("Process mapping: %s", json.dumps(PROCESS_MAPPING, indent=2))

    # Airflow already injects the runtime context as keyword arguments, so the
    # previous get_current_context() call (which shadowed `context` and caused
    # the params to be logged twice) was redundant.
    logging.info("DAG Run parameters: %s", json.dumps(context["params"], sort_keys=True, indent=4))

    selected_process = context["params"].get("selected_process")
    if selected_process in PROCESS_MAPPING:
        numerical_id = PROCESS_MAPPING[selected_process]
        logging.info("Selected process '%s' maps to numerical ID: %s", selected_process, numerical_id)
        return {"numerical_process_id": numerical_id}

    logging.warning("Selected process '%s' not found in mapping", selected_process)
    return {"numerical_process_id": 1}
275+
276+
277+
# First task: logs parameters and pushes the resolved numerical process ID to XCom.
setup_task = PythonOperator(task_id="Setup", python_callable=setup, dag=dag)
278+
279+
# Submits the job in a pod; the pod writes {"job_id": ...} to XCom for the monitor task.
# All arguments are Jinja templates resolved from the DAG-run params at runtime.
submit_job_task = SPSOGCOperator(
    task_id="submit_job_task",
    operation_type="submit",
    selected_process="{{ params.selected_process }}",
    job_inputs="{{ params.job_inputs }}",
    job_queue="{{ params.queue }}",
    dag=dag,
)
287+
288+
# Polls the OGC job until it succeeds/fails; job_id comes from the submit pod's XCom.
monitor_job_task = SPSOGCOperator(
    task_id="monitor_job_task",
    operation_type="monitor",
    job_id="{{ ti.xcom_pull(task_ids='submit_job_task', key='return_value')['job_id'] }}",
    dag=dag,
)
294+
295+
296+
def cleanup(**context):
    """A placeholder cleanup task: logs the submit/monitor XCom results, if any."""
    logging.info("Cleanup executed.")

    task_instance = context["ti"]
    # Log final results if available.
    final_results = {
        "Job submission result": task_instance.xcom_pull(task_ids="submit_job_task", key="return_value"),
        "Job monitoring result": task_instance.xcom_pull(task_ids="monitor_job_task", key="return_value"),
    }
    for label, payload in final_results.items():
        if payload:
            logging.info(f"{label}: {payload}")
308+
309+
310+
# ALL_DONE so cleanup runs (and logs whatever is available) even if submit or monitor failed.
cleanup_task = PythonOperator(
    task_id="Cleanup", python_callable=cleanup, dag=dag, trigger_rule=TriggerRule.ALL_DONE
)

# Linear pipeline: Setup -> submit -> monitor -> Cleanup.
chain(setup_task, submit_job_task, monitor_job_task, cleanup_task)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
FROM alpine:3.18

# curl for the OGC API calls, jq for JSON construction/parsing in the entrypoint.
RUN apk add --no-cache curl jq

COPY run_ogc_process_entrypoint.sh /usr/share/ogc/run_ogc_process_entrypoint.sh
WORKDIR /usr/share/ogc
RUN chmod +x /usr/share/ogc/run_ogc_process_entrypoint.sh
ENTRYPOINT ["/usr/share/ogc/run_ogc_process_entrypoint.sh"]
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/bin/sh
#
# Entrypoint for the OGC job runner image, driven by environment variables
# set by the SPSOGCOperator:
#   SUBMIT_JOB       "true"/"True" to submit a job, "false"/"False" to monitor
#   SUBMIT_JOB_URL   submit URL template containing a {process_id} placeholder
#   MONITOR_JOB_URL  monitor URL template containing a {job_id} placeholder
#   PROCESS_ID       numerical OGC process ID (submit)
#   JOB_INPUTS       JSON inputs payload (submit)
#   QUEUE            MAAP queue name (submit)
#   JOB_ID           OGC job ID (monitor)
#   MAAP_PGT         MAAP proxy-granting ticket (injected from a k8s Secret)

set -e

if [ "$SUBMIT_JOB" = "true" ] || [ "$SUBMIT_JOB" = "True" ]; then
    echo "Submitting job"

    # Substitute the numerical process ID into the URL template.
    SUBMIT_JOB_URL=$(echo "$SUBMIT_JOB_URL" | sed "s/{process_id}/$PROCESS_ID/")
    SUBMIT_JOB_ARGUMENTS=$(jq -n \
        --arg queue "$QUEUE" \
        --argjson inputs "$JOB_INPUTS" \
        '{queue: $queue, inputs: $inputs}')

    echo "Submitting the job to ${SUBMIT_JOB_URL}"

    response=$(curl --location "${SUBMIT_JOB_URL}" \
        --header "proxy-ticket: ${MAAP_PGT}" \
        --header "Content-Type: application/json" \
        --data "${SUBMIT_JOB_ARGUMENTS}")

    echo "API Response: $response"
    job_id=$(echo "$response" | jq -r .id)

    if [ "$job_id" = "null" ] || [ -z "$job_id" ]; then
        echo "Failed to get jobID from response."
        exit 1
    fi

    echo "Job submitted successfully. Job ID: ${job_id}"

    # Write the job_id to the XCom return file for the next task
    mkdir -p /airflow/xcom/
    printf '{"job_id": "%s"}' "$job_id" > /airflow/xcom/return.json
elif [ "$SUBMIT_JOB" = "false" ] || [ "$SUBMIT_JOB" = "False" ]; then
    echo "Monitoring job status"

    MONITOR_JOB_URL=$(echo "$MONITOR_JOB_URL" | sed "s/{job_id}/$JOB_ID/")

    TIMEOUT=3600
    POLL_INTERVAL=30
    # Renamed from SECONDS: in bash, SECONDS is a magic auto-incrementing
    # variable, so the manual increment below would double-count if this
    # script were ever run under bash instead of sh.
    ELAPSED=0

    while [ "$ELAPSED" -lt "$TIMEOUT" ]; do
        echo "Checking status..."
        response=$(curl --location "${MONITOR_JOB_URL}" \
            --header "proxy-ticket: ${MAAP_PGT}" \
            --header "Content-Type: application/json")

        status=$(echo "$response" | jq -r .status)

        echo "Current status is: $status"

        if [ "$status" = "successful" ]; then
            echo "Job completed successfully!"
            exit 0
        elif [ "$status" = "failed" ]; then
            echo "Job failed!"
            echo "Error details: $(echo "$response" | jq .)"
            exit 1
        fi

        sleep "$POLL_INTERVAL"
        ELAPSED=$((ELAPSED + POLL_INTERVAL))
    done

    echo "Job monitoring timed out after $TIMEOUT seconds."
    exit 1
else
    # BUG FIX: previously this branch fell through with exit code 0, so a
    # misconfigured SUBMIT_JOB value made the pod report success.
    echo "SUBMIT_JOB variable must be specified and set to true or false"
    exit 1
fi

0 commit comments

Comments
 (0)