Skip to content

Commit db2dc1e

Browse files
committed
[FLINK-38933][python] Expose runtime context information in Python UDX FunctionContext
This aligns Python UDX with Java UDX by exposing runtime context information through FunctionContext.

Added the following getter methods to FunctionContext:
- get_task_name()
- get_task_name_with_subtasks()
- get_number_of_parallel_subtasks()
- get_max_number_of_parallel_subtasks()
- get_index_of_this_subtask()
- get_attempt_number()

The runtime context is propagated from Java operators via the protobuf protocol by adding a runtime_context field to the UserDefinedFunctions and UserDefinedAggregateFunctions messages.
1 parent 45a04a7 commit db2dc1e

11 files changed

Lines changed: 329 additions & 116 deletions

File tree

flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py

Lines changed: 106 additions & 106 deletions
Large diffs are not rendered by default.

flink-python/pyflink/fn_execution/flink_fn_execution_pb2.pyi

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,22 @@ class AsyncOptions(_message.Message):
7070
def __init__(self, max_concurrent_operations: _Optional[int] = ..., timeout_ms: _Optional[int] = ..., retry_enabled: bool = ..., retry_max_attempts: _Optional[int] = ..., retry_delay_ms: _Optional[int] = ...) -> None: ...
7171

7272
class UserDefinedFunctions(_message.Message):
73-
__slots__ = ("udfs", "metric_enabled", "windows", "profile_enabled", "job_parameters", "async_options")
73+
__slots__ = ("udfs", "metric_enabled", "windows", "profile_enabled", "job_parameters", "async_options", "runtime_context")
7474
UDFS_FIELD_NUMBER: _ClassVar[int]
7575
METRIC_ENABLED_FIELD_NUMBER: _ClassVar[int]
7676
WINDOWS_FIELD_NUMBER: _ClassVar[int]
7777
PROFILE_ENABLED_FIELD_NUMBER: _ClassVar[int]
7878
JOB_PARAMETERS_FIELD_NUMBER: _ClassVar[int]
7979
ASYNC_OPTIONS_FIELD_NUMBER: _ClassVar[int]
80+
RUNTIME_CONTEXT_FIELD_NUMBER: _ClassVar[int]
8081
udfs: _containers.RepeatedCompositeFieldContainer[UserDefinedFunction]
8182
metric_enabled: bool
8283
windows: _containers.RepeatedCompositeFieldContainer[OverWindow]
8384
profile_enabled: bool
8485
job_parameters: _containers.RepeatedCompositeFieldContainer[JobParameter]
8586
async_options: AsyncOptions
86-
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedFunction, _Mapping]]] = ..., metric_enabled: bool = ..., windows: _Optional[_Iterable[_Union[OverWindow, _Mapping]]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., async_options: _Optional[_Union[AsyncOptions, _Mapping]] = ...) -> None: ...
87+
runtime_context: UserDefinedDataStreamFunction.RuntimeContext
88+
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedFunction, _Mapping]]] = ..., metric_enabled: bool = ..., windows: _Optional[_Iterable[_Union[OverWindow, _Mapping]]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., async_options: _Optional[_Union[AsyncOptions, _Mapping]] = ..., runtime_context: _Optional[_Union[UserDefinedDataStreamFunction.RuntimeContext, _Mapping]] = ...) -> None: ...
8789

8890
class OverWindow(_message.Message):
8991
__slots__ = ("window_type", "lower_boundary", "upper_boundary")
@@ -195,7 +197,7 @@ class GroupWindow(_message.Message):
195197
def __init__(self, window_type: _Optional[_Union[GroupWindow.WindowType, str]] = ..., is_time_window: bool = ..., window_slide: _Optional[int] = ..., window_size: _Optional[int] = ..., window_gap: _Optional[int] = ..., is_row_time: bool = ..., time_field_index: _Optional[int] = ..., allowedLateness: _Optional[int] = ..., namedProperties: _Optional[_Iterable[_Union[GroupWindow.WindowProperty, str]]] = ..., shift_timezone: _Optional[str] = ...) -> None: ...
196198

197199
class UserDefinedAggregateFunctions(_message.Message):
198-
__slots__ = ("udfs", "metric_enabled", "grouping", "generate_update_before", "key_type", "index_of_count_star", "state_cleaning_enabled", "state_cache_size", "map_state_read_cache_size", "map_state_write_cache_size", "count_star_inserted", "group_window", "profile_enabled", "job_parameters")
200+
__slots__ = ("udfs", "metric_enabled", "grouping", "generate_update_before", "key_type", "index_of_count_star", "state_cleaning_enabled", "state_cache_size", "map_state_read_cache_size", "map_state_write_cache_size", "count_star_inserted", "group_window", "profile_enabled", "job_parameters", "runtime_context")
199201
UDFS_FIELD_NUMBER: _ClassVar[int]
200202
METRIC_ENABLED_FIELD_NUMBER: _ClassVar[int]
201203
GROUPING_FIELD_NUMBER: _ClassVar[int]
@@ -210,6 +212,7 @@ class UserDefinedAggregateFunctions(_message.Message):
210212
GROUP_WINDOW_FIELD_NUMBER: _ClassVar[int]
211213
PROFILE_ENABLED_FIELD_NUMBER: _ClassVar[int]
212214
JOB_PARAMETERS_FIELD_NUMBER: _ClassVar[int]
215+
RUNTIME_CONTEXT_FIELD_NUMBER: _ClassVar[int]
213216
udfs: _containers.RepeatedCompositeFieldContainer[UserDefinedAggregateFunction]
214217
metric_enabled: bool
215218
grouping: _containers.RepeatedScalarFieldContainer[int]
@@ -224,7 +227,8 @@ class UserDefinedAggregateFunctions(_message.Message):
224227
group_window: GroupWindow
225228
profile_enabled: bool
226229
job_parameters: _containers.RepeatedCompositeFieldContainer[JobParameter]
227-
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedAggregateFunction, _Mapping]]] = ..., metric_enabled: bool = ..., grouping: _Optional[_Iterable[int]] = ..., generate_update_before: bool = ..., key_type: _Optional[_Union[Schema.FieldType, _Mapping]] = ..., index_of_count_star: _Optional[int] = ..., state_cleaning_enabled: bool = ..., state_cache_size: _Optional[int] = ..., map_state_read_cache_size: _Optional[int] = ..., map_state_write_cache_size: _Optional[int] = ..., count_star_inserted: bool = ..., group_window: _Optional[_Union[GroupWindow, _Mapping]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ...) -> None: ...
230+
runtime_context: UserDefinedDataStreamFunction.RuntimeContext
231+
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedAggregateFunction, _Mapping]]] = ..., metric_enabled: bool = ..., grouping: _Optional[_Iterable[int]] = ..., generate_update_before: bool = ..., key_type: _Optional[_Union[Schema.FieldType, _Mapping]] = ..., index_of_count_star: _Optional[int] = ..., state_cleaning_enabled: bool = ..., state_cache_size: _Optional[int] = ..., map_state_read_cache_size: _Optional[int] = ..., map_state_write_cache_size: _Optional[int] = ..., count_star_inserted: bool = ..., group_window: _Optional[_Union[GroupWindow, _Mapping]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., runtime_context: _Optional[_Union[UserDefinedDataStreamFunction.RuntimeContext, _Mapping]] = ...) -> None: ...
228232

229233
class Schema(_message.Message):
230234
__slots__ = ("fields",)

flink-python/pyflink/fn_execution/metrics/tests/test_metric.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,31 @@ def test_metric_not_enabled(self):
5252
with self.assertRaises(RuntimeError):
5353
fc.get_metric_group()
5454

55+
def test_function_context_runtime_info(self):
56+
fc = FunctionContext(
57+
None, {},
58+
task_name='MyTask',
59+
task_name_with_subtasks='MyTask (1/4)',
60+
number_of_parallel_subtasks=4,
61+
max_number_of_parallel_subtasks=128,
62+
index_of_this_subtask=0,
63+
attempt_number=2)
64+
self.assertEqual('MyTask', fc.get_task_name())
65+
self.assertEqual('MyTask (1/4)', fc.get_task_name_with_subtasks())
66+
self.assertEqual(4, fc.get_number_of_parallel_subtasks())
67+
self.assertEqual(128, fc.get_max_number_of_parallel_subtasks())
68+
self.assertEqual(0, fc.get_index_of_this_subtask())
69+
self.assertEqual(2, fc.get_attempt_number())
70+
71+
def test_function_context_runtime_info_defaults(self):
72+
fc = FunctionContext(None, {})
73+
self.assertIsNone(fc.get_task_name())
74+
self.assertIsNone(fc.get_task_name_with_subtasks())
75+
self.assertIsNone(fc.get_number_of_parallel_subtasks())
76+
self.assertIsNone(fc.get_max_number_of_parallel_subtasks())
77+
self.assertIsNone(fc.get_index_of_this_subtask())
78+
self.assertIsNone(fc.get_attempt_number())
79+
5580
def test_get_metric_name(self):
5681
new_group = MetricTests.base_metric_group.add_group('my_group')
5782
self.assertEqual(

flink-python/pyflink/fn_execution/table/async_function/operations.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ def __init__(self, serialized_fn):
7474

7575
# Job parameters
7676
self._job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
77+
if serialized_fn.HasField('runtime_context'):
78+
rc = serialized_fn.runtime_context
79+
self._runtime_context = {
80+
'task_name': rc.task_name,
81+
'task_name_with_subtasks': rc.task_name_with_subtasks,
82+
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
83+
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
84+
'index_of_this_subtask': rc.index_of_this_subtask,
85+
'attempt_number': rc.attempt_number,
86+
}
87+
else:
88+
self._runtime_context = {}
7789

7890
def set_output_processor(self, output_processor):
7991
"""Set the output processor for emitting results.
@@ -86,8 +98,9 @@ def open(self):
8698
# Open user defined functions
8799
for user_defined_func in self.user_defined_funcs:
88100
if hasattr(user_defined_func, 'open'):
89-
user_defined_func.open(
90-
FunctionContext(self.base_metric_group, self._job_parameters))
101+
user_defined_func.open(FunctionContext(
102+
self.base_metric_group, self._job_parameters,
103+
**self._runtime_context))
91104

92105
# Start emitter thread to collect async results
93106
self._emitter = Emitter(self._mark_exception, self._output_processor, self._queue)

flink-python/pyflink/fn_execution/table/operations.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ def __init__(self, serialized_fn):
8585
self.base_metric_group = None
8686
self.func, self.user_defined_funcs = self.generate_func(serialized_fn)
8787
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
88+
if serialized_fn.HasField('runtime_context'):
89+
rc = serialized_fn.runtime_context
90+
self.runtime_context = {
91+
'task_name': rc.task_name,
92+
'task_name_with_subtasks': rc.task_name_with_subtasks,
93+
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
94+
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
95+
'index_of_this_subtask': rc.index_of_this_subtask,
96+
'attempt_number': rc.attempt_number,
97+
}
98+
else:
99+
self.runtime_context = {}
88100

89101
def finish(self):
90102
self._update_gauge(self.base_metric_group)
@@ -104,7 +116,9 @@ def process_element(self, value):
104116
def open(self):
105117
for user_defined_func in self.user_defined_funcs:
106118
if hasattr(user_defined_func, 'open'):
107-
user_defined_func.open(FunctionContext(self.base_metric_group, self.job_parameters))
119+
user_defined_func.open(FunctionContext(
120+
self.base_metric_group, self.job_parameters,
121+
**self.runtime_context))
108122

109123
def close(self):
110124
for user_defined_func in self.user_defined_funcs:
@@ -326,11 +340,25 @@ def __init__(self, serialized_fn, keyed_state_backend):
326340
self.state_cleaning_enabled = serialized_fn.state_cleaning_enabled
327341
self.data_view_specs = extract_data_view_specs(serialized_fn.udfs)
328342
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
343+
if serialized_fn.HasField('runtime_context'):
344+
rc = serialized_fn.runtime_context
345+
self.runtime_context = {
346+
'task_name': rc.task_name,
347+
'task_name_with_subtasks': rc.task_name_with_subtasks,
348+
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
349+
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
350+
'index_of_this_subtask': rc.index_of_this_subtask,
351+
'attempt_number': rc.attempt_number,
352+
}
353+
else:
354+
self.runtime_context = {}
329355
super(AbstractStreamGroupAggregateOperation, self).__init__(
330356
serialized_fn, keyed_state_backend)
331357

332358
def open(self):
333-
self.group_agg_function.open(FunctionContext(self.base_metric_group, self.job_parameters))
359+
self.group_agg_function.open(FunctionContext(
360+
self.base_metric_group, self.job_parameters,
361+
**self.runtime_context))
334362

335363
def close(self):
336364
self.group_agg_function.close()

flink-python/pyflink/proto/flink-fn-execution.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ message UserDefinedFunctions {
8282
bool profile_enabled = 4;
8383
repeated JobParameter job_parameters = 5;
8484
AsyncOptions async_options = 6;
85+
// The runtime context of the user-defined functions, providing task info such as
86+
// task_name, parallelism, subtask_index, etc.
87+
UserDefinedDataStreamFunction.RuntimeContext runtime_context = 7;
8588
}
8689

8790
// Used to describe the info of over window in pandas batch over window aggregation
@@ -200,6 +203,9 @@ message UserDefinedAggregateFunctions {
200203

201204
bool profile_enabled = 13;
202205
repeated JobParameter job_parameters = 14;
206+
// The runtime context of the user-defined aggregate functions, providing task info such as
207+
// task_name, parallelism, subtask_index, etc.
208+
UserDefinedDataStreamFunction.RuntimeContext runtime_context = 15;
203209
}
204210

205211
// A representation of the data schema.

flink-python/pyflink/table/tests/test_udf.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,34 @@ def test_open(self):
266266
actual = source_sink_utils.results()
267267
self.assert_equals(actual, ["+I[1, 1]", "+I[2, 4]", "+I[3, 3]"])
268268

269+
def test_function_context_runtime_info(self):
270+
runtime_info_func = udf(RuntimeInfoFunc(), result_type=DataTypes.STRING())
271+
272+
sink_table = generate_random_table_name()
273+
sink_table_ddl = f"""
274+
CREATE TABLE {sink_table}(a STRING) WITH ('connector'='test-sink')
275+
"""
276+
self.t_env.execute_sql(sink_table_ddl)
277+
278+
t = self.t_env.from_elements([(1,)], ['a'])
279+
t.select(runtime_info_func(t.a)).execute_insert(sink_table).wait()
280+
actual = source_sink_utils.results()
281+
result = actual[0]
282+
# The result should contain task_name, number_of_parallel_subtasks, and
283+
# index_of_this_subtask info, verifying that FunctionContext runtime info
284+
# is properly propagated from Java to Python.
285+
self.assertTrue(result.startswith("+I["))
286+
# Extract the value between +I[ and ]
287+
value = result[3:-1]
288+
parts = value.split(",")
289+
self.assertEqual(len(parts), 3)
290+
# task_name should be non-empty
291+
self.assertTrue(len(parts[0].strip()) > 0)
292+
# number_of_parallel_subtasks should be a positive integer
293+
self.assertTrue(int(parts[1].strip()) > 0)
294+
# index_of_this_subtask should be a non-negative integer
295+
self.assertTrue(int(parts[2].strip()) >= 0)
296+
269297
def test_udf_without_arguments(self):
270298
one = udf(lambda: 1, result_type=DataTypes.BIGINT(), deterministic=True)
271299
two = udf(lambda: 2, result_type=DataTypes.BIGINT(), deterministic=False)
@@ -1147,6 +1175,17 @@ def eval(self, i):
11471175
return i - self.subtracted_value
11481176

11491177

1178+
class RuntimeInfoFunc(ScalarFunction):
1179+
1180+
def open(self, function_context: FunctionContext):
1181+
self.task_name = function_context.get_task_name()
1182+
self.num_parallel = function_context.get_number_of_parallel_subtasks()
1183+
self.subtask_index = function_context.get_index_of_this_subtask()
1184+
1185+
def eval(self, i):
1186+
return "%s,%d,%d" % (self.task_name, self.num_parallel, self.subtask_index)
1187+
1188+
11501189
class CallablePlus(object):
11511190

11521191
def __call__(self, col):

flink-python/pyflink/table/udf.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,21 @@ class FunctionContext(object):
3636
"""
3737
Used to obtain global runtime information about the context in which the
3838
user-defined function is executed. The information includes the metric group,
39-
and global job parameters, etc.
39+
global job parameters, and runtime task information such as task name, parallelism, etc.
4040
"""
4141

42-
def __init__(self, base_metric_group, job_parameters):
42+
def __init__(self, base_metric_group, job_parameters,
43+
task_name=None, task_name_with_subtasks=None,
44+
number_of_parallel_subtasks=None, max_number_of_parallel_subtasks=None,
45+
index_of_this_subtask=None, attempt_number=None):
4346
self._base_metric_group = base_metric_group
4447
self._job_parameters = job_parameters
48+
self._task_name = task_name
49+
self._task_name_with_subtasks = task_name_with_subtasks
50+
self._number_of_parallel_subtasks = number_of_parallel_subtasks
51+
self._max_number_of_parallel_subtasks = max_number_of_parallel_subtasks
52+
self._index_of_this_subtask = index_of_this_subtask
53+
self._attempt_number = attempt_number
4554

4655
def get_metric_group(self) -> MetricGroup:
4756
"""
@@ -66,6 +75,57 @@ def get_job_parameter(self, key: str, default_value: str) -> str:
6675
"""
6776
return self._job_parameters[key] if key in self._job_parameters else default_value
6877

78+
def get_task_name(self) -> str:
79+
"""
80+
Returns the name of the task in which the UDF runs, as assigned during plan construction.
81+
82+
.. versionadded:: 2.2.0
83+
"""
84+
return self._task_name
85+
86+
def get_task_name_with_subtasks(self) -> str:
87+
"""
88+
Returns the name of the task, appended with the subtask indicator, such as "MyTask (3/6)",
89+
where 3 would be (:func:`get_index_of_this_subtask` + 1), and 6 would be
90+
:func:`get_number_of_parallel_subtasks`.
91+
92+
.. versionadded:: 2.2.0
93+
"""
94+
return self._task_name_with_subtasks
95+
96+
def get_number_of_parallel_subtasks(self) -> int:
97+
"""
98+
Gets the parallelism with which the parallel task runs.
99+
100+
.. versionadded:: 2.2.0
101+
"""
102+
return self._number_of_parallel_subtasks
103+
104+
def get_max_number_of_parallel_subtasks(self) -> int:
105+
"""
106+
Gets the number of max-parallelism with which the parallel task runs.
107+
108+
.. versionadded:: 2.2.0
109+
"""
110+
return self._max_number_of_parallel_subtasks
111+
112+
def get_index_of_this_subtask(self) -> int:
113+
"""
114+
Gets the number of this parallel subtask. The numbering starts from 0 and goes up to
115+
parallelism-1 (parallelism as returned by :func:`get_number_of_parallel_subtasks`).
116+
117+
.. versionadded:: 2.2.0
118+
"""
119+
return self._index_of_this_subtask
120+
121+
def get_attempt_number(self) -> int:
122+
"""
123+
Gets the attempt number of this parallel subtask. First attempt is numbered 0.
124+
125+
.. versionadded:: 2.2.0
126+
"""
127+
return self._attempt_number
128+
69129

70130
@PublicEvolving()
71131
class UserDefinedFunction(abc.ABC):

flink-python/src/main/java/org/apache/flink/python/util/ProtoUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ public static FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto(
154154
.setValue(entry.getValue())
155155
.build())
156156
.collect(Collectors.toList()));
157+
builder.setRuntimeContext(
158+
FlinkFnApi.UserDefinedDataStreamFunction.RuntimeContext.newBuilder()
159+
.setTaskName(runtimeContext.getTaskInfo().getTaskName())
160+
.setTaskNameWithSubtasks(
161+
runtimeContext.getTaskInfo().getTaskNameWithSubtasks())
162+
.setNumberOfParallelSubtasks(
163+
runtimeContext.getTaskInfo().getNumberOfParallelSubtasks())
164+
.setMaxNumberOfParallelSubtasks(
165+
runtimeContext.getTaskInfo().getMaxNumberOfParallelSubtasks())
166+
.setIndexOfThisSubtask(runtimeContext.getTaskInfo().getIndexOfThisSubtask())
167+
.setAttemptNumber(runtimeContext.getTaskInfo().getAttemptNumber())
168+
.build());
157169
return builder.build();
158170
}
159171

0 commit comments

Comments
 (0)