Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 106 additions & 106 deletions flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions flink-python/pyflink/fn_execution/flink_fn_execution_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,22 @@ class AsyncOptions(_message.Message):
def __init__(self, max_concurrent_operations: _Optional[int] = ..., timeout_ms: _Optional[int] = ..., retry_enabled: bool = ..., retry_max_attempts: _Optional[int] = ..., retry_delay_ms: _Optional[int] = ...) -> None: ...

class UserDefinedFunctions(_message.Message):
__slots__ = ("udfs", "metric_enabled", "windows", "profile_enabled", "job_parameters", "async_options")
__slots__ = ("udfs", "metric_enabled", "windows", "profile_enabled", "job_parameters", "async_options", "runtime_context")
UDFS_FIELD_NUMBER: _ClassVar[int]
METRIC_ENABLED_FIELD_NUMBER: _ClassVar[int]
WINDOWS_FIELD_NUMBER: _ClassVar[int]
PROFILE_ENABLED_FIELD_NUMBER: _ClassVar[int]
JOB_PARAMETERS_FIELD_NUMBER: _ClassVar[int]
ASYNC_OPTIONS_FIELD_NUMBER: _ClassVar[int]
RUNTIME_CONTEXT_FIELD_NUMBER: _ClassVar[int]
udfs: _containers.RepeatedCompositeFieldContainer[UserDefinedFunction]
metric_enabled: bool
windows: _containers.RepeatedCompositeFieldContainer[OverWindow]
profile_enabled: bool
job_parameters: _containers.RepeatedCompositeFieldContainer[JobParameter]
async_options: AsyncOptions
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedFunction, _Mapping]]] = ..., metric_enabled: bool = ..., windows: _Optional[_Iterable[_Union[OverWindow, _Mapping]]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., async_options: _Optional[_Union[AsyncOptions, _Mapping]] = ...) -> None: ...
runtime_context: UserDefinedDataStreamFunction.RuntimeContext
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedFunction, _Mapping]]] = ..., metric_enabled: bool = ..., windows: _Optional[_Iterable[_Union[OverWindow, _Mapping]]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., async_options: _Optional[_Union[AsyncOptions, _Mapping]] = ..., runtime_context: _Optional[_Union[UserDefinedDataStreamFunction.RuntimeContext, _Mapping]] = ...) -> None: ...

class OverWindow(_message.Message):
__slots__ = ("window_type", "lower_boundary", "upper_boundary")
Expand Down Expand Up @@ -195,7 +197,7 @@ class GroupWindow(_message.Message):
def __init__(self, window_type: _Optional[_Union[GroupWindow.WindowType, str]] = ..., is_time_window: bool = ..., window_slide: _Optional[int] = ..., window_size: _Optional[int] = ..., window_gap: _Optional[int] = ..., is_row_time: bool = ..., time_field_index: _Optional[int] = ..., allowedLateness: _Optional[int] = ..., namedProperties: _Optional[_Iterable[_Union[GroupWindow.WindowProperty, str]]] = ..., shift_timezone: _Optional[str] = ...) -> None: ...

class UserDefinedAggregateFunctions(_message.Message):
__slots__ = ("udfs", "metric_enabled", "grouping", "generate_update_before", "key_type", "index_of_count_star", "state_cleaning_enabled", "state_cache_size", "map_state_read_cache_size", "map_state_write_cache_size", "count_star_inserted", "group_window", "profile_enabled", "job_parameters")
__slots__ = ("udfs", "metric_enabled", "grouping", "generate_update_before", "key_type", "index_of_count_star", "state_cleaning_enabled", "state_cache_size", "map_state_read_cache_size", "map_state_write_cache_size", "count_star_inserted", "group_window", "profile_enabled", "job_parameters", "runtime_context")
UDFS_FIELD_NUMBER: _ClassVar[int]
METRIC_ENABLED_FIELD_NUMBER: _ClassVar[int]
GROUPING_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -210,6 +212,7 @@ class UserDefinedAggregateFunctions(_message.Message):
GROUP_WINDOW_FIELD_NUMBER: _ClassVar[int]
PROFILE_ENABLED_FIELD_NUMBER: _ClassVar[int]
JOB_PARAMETERS_FIELD_NUMBER: _ClassVar[int]
RUNTIME_CONTEXT_FIELD_NUMBER: _ClassVar[int]
udfs: _containers.RepeatedCompositeFieldContainer[UserDefinedAggregateFunction]
metric_enabled: bool
grouping: _containers.RepeatedScalarFieldContainer[int]
Expand All @@ -224,7 +227,8 @@ class UserDefinedAggregateFunctions(_message.Message):
group_window: GroupWindow
profile_enabled: bool
job_parameters: _containers.RepeatedCompositeFieldContainer[JobParameter]
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedAggregateFunction, _Mapping]]] = ..., metric_enabled: bool = ..., grouping: _Optional[_Iterable[int]] = ..., generate_update_before: bool = ..., key_type: _Optional[_Union[Schema.FieldType, _Mapping]] = ..., index_of_count_star: _Optional[int] = ..., state_cleaning_enabled: bool = ..., state_cache_size: _Optional[int] = ..., map_state_read_cache_size: _Optional[int] = ..., map_state_write_cache_size: _Optional[int] = ..., count_star_inserted: bool = ..., group_window: _Optional[_Union[GroupWindow, _Mapping]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ...) -> None: ...
runtime_context: UserDefinedDataStreamFunction.RuntimeContext
def __init__(self, udfs: _Optional[_Iterable[_Union[UserDefinedAggregateFunction, _Mapping]]] = ..., metric_enabled: bool = ..., grouping: _Optional[_Iterable[int]] = ..., generate_update_before: bool = ..., key_type: _Optional[_Union[Schema.FieldType, _Mapping]] = ..., index_of_count_star: _Optional[int] = ..., state_cleaning_enabled: bool = ..., state_cache_size: _Optional[int] = ..., map_state_read_cache_size: _Optional[int] = ..., map_state_write_cache_size: _Optional[int] = ..., count_star_inserted: bool = ..., group_window: _Optional[_Union[GroupWindow, _Mapping]] = ..., profile_enabled: bool = ..., job_parameters: _Optional[_Iterable[_Union[JobParameter, _Mapping]]] = ..., runtime_context: _Optional[_Union[UserDefinedDataStreamFunction.RuntimeContext, _Mapping]] = ...) -> None: ...

class Schema(_message.Message):
__slots__ = ("fields",)
Expand Down
25 changes: 25 additions & 0 deletions flink-python/pyflink/fn_execution/metrics/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,31 @@ def test_metric_not_enabled(self):
with self.assertRaises(RuntimeError):
fc.get_metric_group()

def test_function_context_runtime_info(self):
    """FunctionContext should expose exactly the runtime task info it was built with."""
    context = FunctionContext(
        None, {},
        task_name='MyTask',
        task_name_with_subtasks='MyTask (1/4)',
        number_of_parallel_subtasks=4,
        max_number_of_parallel_subtasks=128,
        index_of_this_subtask=0,
        attempt_number=2)
    # Check each getter against the value supplied to the constructor.
    for expected, getter in [
            ('MyTask', context.get_task_name),
            ('MyTask (1/4)', context.get_task_name_with_subtasks),
            (4, context.get_number_of_parallel_subtasks),
            (128, context.get_max_number_of_parallel_subtasks),
            (0, context.get_index_of_this_subtask),
            (2, context.get_attempt_number)]:
        self.assertEqual(expected, getter())

def test_function_context_runtime_info_defaults(self):
    """Every runtime-info getter defaults to None when no info is supplied."""
    context = FunctionContext(None, {})
    for getter in (
            context.get_task_name,
            context.get_task_name_with_subtasks,
            context.get_number_of_parallel_subtasks,
            context.get_max_number_of_parallel_subtasks,
            context.get_index_of_this_subtask,
            context.get_attempt_number):
        self.assertIsNone(getter())

def test_get_metric_name(self):
new_group = MetricTests.base_metric_group.add_group('my_group')
self.assertEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def __init__(self, serialized_fn):

# Job parameters
self._job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
if serialized_fn.HasField('runtime_context'):
rc = serialized_fn.runtime_context
self._runtime_context = {
'task_name': rc.task_name,
'task_name_with_subtasks': rc.task_name_with_subtasks,
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
'index_of_this_subtask': rc.index_of_this_subtask,
'attempt_number': rc.attempt_number,
}
else:
self._runtime_context = {}

def set_output_processor(self, output_processor):
"""Set the output processor for emitting results.
Expand All @@ -86,8 +98,9 @@ def open(self):
# Open user defined functions
for user_defined_func in self.user_defined_funcs:
if hasattr(user_defined_func, 'open'):
user_defined_func.open(
FunctionContext(self.base_metric_group, self._job_parameters))
user_defined_func.open(FunctionContext(
self.base_metric_group, self._job_parameters,
**self._runtime_context))

# Start emitter thread to collect async results
self._emitter = Emitter(self._mark_exception, self._output_processor, self._queue)
Expand Down
32 changes: 30 additions & 2 deletions flink-python/pyflink/fn_execution/table/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def __init__(self, serialized_fn):
self.base_metric_group = None
self.func, self.user_defined_funcs = self.generate_func(serialized_fn)
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
if serialized_fn.HasField('runtime_context'):
rc = serialized_fn.runtime_context
self.runtime_context = {
'task_name': rc.task_name,
'task_name_with_subtasks': rc.task_name_with_subtasks,
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
'index_of_this_subtask': rc.index_of_this_subtask,
'attempt_number': rc.attempt_number,
}
else:
self.runtime_context = {}

def finish(self):
self._update_gauge(self.base_metric_group)
Expand All @@ -104,7 +116,9 @@ def process_element(self, value):
def open(self):
for user_defined_func in self.user_defined_funcs:
if hasattr(user_defined_func, 'open'):
user_defined_func.open(FunctionContext(self.base_metric_group, self.job_parameters))
user_defined_func.open(FunctionContext(
self.base_metric_group, self.job_parameters,
**self.runtime_context))

def close(self):
for user_defined_func in self.user_defined_funcs:
Expand Down Expand Up @@ -326,11 +340,25 @@ def __init__(self, serialized_fn, keyed_state_backend):
self.state_cleaning_enabled = serialized_fn.state_cleaning_enabled
self.data_view_specs = extract_data_view_specs(serialized_fn.udfs)
self.job_parameters = {p.key: p.value for p in serialized_fn.job_parameters}
if serialized_fn.HasField('runtime_context'):
rc = serialized_fn.runtime_context
self.runtime_context = {
'task_name': rc.task_name,
'task_name_with_subtasks': rc.task_name_with_subtasks,
'number_of_parallel_subtasks': rc.number_of_parallel_subtasks,
'max_number_of_parallel_subtasks': rc.max_number_of_parallel_subtasks,
'index_of_this_subtask': rc.index_of_this_subtask,
'attempt_number': rc.attempt_number,
}
else:
self.runtime_context = {}
super(AbstractStreamGroupAggregateOperation, self).__init__(
serialized_fn, keyed_state_backend)

def open(self):
self.group_agg_function.open(FunctionContext(self.base_metric_group, self.job_parameters))
self.group_agg_function.open(FunctionContext(
self.base_metric_group, self.job_parameters,
**self.runtime_context))

def close(self):
self.group_agg_function.close()
Expand Down
6 changes: 6 additions & 0 deletions flink-python/pyflink/proto/flink-fn-execution.proto
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ message UserDefinedFunctions {
bool profile_enabled = 4;
repeated JobParameter job_parameters = 5;
AsyncOptions async_options = 6;
// The runtime context of the user-defined functions, providing task info such as
// task_name, parallelism, subtask_index, etc.
UserDefinedDataStreamFunction.RuntimeContext runtime_context = 7;
}

// Used to describe the info of over window in pandas batch over window aggregation
Expand Down Expand Up @@ -200,6 +203,9 @@ message UserDefinedAggregateFunctions {

bool profile_enabled = 13;
repeated JobParameter job_parameters = 14;
// The runtime context of the user-defined aggregate functions, providing task info such as
// task_name, parallelism, subtask_index, etc.
UserDefinedDataStreamFunction.RuntimeContext runtime_context = 15;
}

// A representation of the data schema.
Expand Down
39 changes: 39 additions & 0 deletions flink-python/pyflink/table/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,34 @@ def test_open(self):
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[1, 1]", "+I[2, 4]", "+I[3, 3]"])

def test_function_context_runtime_info(self):
    """End-to-end check that runtime task info is propagated from the Java
    operator into the Python FunctionContext seen by a scalar UDF.

    RuntimeInfoFunc captures task_name, parallelism and subtask index in
    open() and emits them as a comma-separated string from eval().
    """
    runtime_info_func = udf(RuntimeInfoFunc(), result_type=DataTypes.STRING())

    sink_table = generate_random_table_name()
    sink_table_ddl = f"""
CREATE TABLE {sink_table}(a STRING) WITH ('connector'='test-sink')
"""
    self.t_env.execute_sql(sink_table_ddl)

    t = self.t_env.from_elements([(1,)], ['a'])
    t.select(runtime_info_func(t.a)).execute_insert(sink_table).wait()
    actual = source_sink_utils.results()
    result = actual[0]
    # The result should contain task_name, number_of_parallel_subtasks, and
    # index_of_this_subtask info, verifying that FunctionContext runtime info
    # is properly propagated from Java to Python.
    self.assertTrue(result.startswith("+I["))
    # Extract the value between +I[ and ]
    value = result[3:-1]
    # NOTE(review): splitting on "," assumes the task name itself contains no
    # commas — true for the auto-generated names in this test env; confirm if
    # the planner's naming scheme changes.
    parts = value.split(",")
    self.assertEqual(len(parts), 3)
    # task_name should be non-empty
    self.assertTrue(len(parts[0].strip()) > 0)
    # number_of_parallel_subtasks should be a positive integer
    self.assertTrue(int(parts[1].strip()) > 0)
    # index_of_this_subtask should be a non-negative integer
    self.assertTrue(int(parts[2].strip()) >= 0)

def test_udf_without_arguments(self):
one = udf(lambda: 1, result_type=DataTypes.BIGINT(), deterministic=True)
two = udf(lambda: 2, result_type=DataTypes.BIGINT(), deterministic=False)
Expand Down Expand Up @@ -1147,6 +1175,17 @@ def eval(self, i):
return i - self.subtracted_value


class RuntimeInfoFunc(ScalarFunction):
    """Scalar function that snapshots runtime task info in open() and
    returns it from eval() as "task_name,parallelism,subtask_index"."""

    def open(self, function_context: FunctionContext):
        # Capture the info once at open time; eval() only reads the cache.
        self._info = (
            function_context.get_task_name(),
            function_context.get_number_of_parallel_subtasks(),
            function_context.get_index_of_this_subtask())

    def eval(self, i):
        return "%s,%d,%d" % self._info


class CallablePlus(object):

def __call__(self, col):
Expand Down
64 changes: 62 additions & 2 deletions flink-python/pyflink/table/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,21 @@ class FunctionContext(object):
"""
Used to obtain global runtime information about the context in which the
user-defined function is executed. The information includes the metric group,
and global job parameters, etc.
global job parameters, and runtime task information such as task name, parallelism, etc.
"""

def __init__(self, base_metric_group, job_parameters):
def __init__(self, base_metric_group, job_parameters,
             task_name=None, task_name_with_subtasks=None,
             number_of_parallel_subtasks=None, max_number_of_parallel_subtasks=None,
             index_of_this_subtask=None, attempt_number=None):
    # All runtime-info parameters default to None so that callers which do
    # not have task information available (pre-existing code paths) keep
    # working unchanged; the corresponding getters then return None.
    self._base_metric_group = base_metric_group
    self._job_parameters = job_parameters
    # Runtime task info, populated from the serialized function's
    # runtime_context proto when present.
    self._task_name = task_name
    self._task_name_with_subtasks = task_name_with_subtasks
    self._number_of_parallel_subtasks = number_of_parallel_subtasks
    self._max_number_of_parallel_subtasks = max_number_of_parallel_subtasks
    self._index_of_this_subtask = index_of_this_subtask
    self._attempt_number = attempt_number

def get_metric_group(self) -> MetricGroup:
"""
Expand All @@ -66,6 +75,57 @@ def get_job_parameter(self, key: str, default_value: str) -> str:
"""
return self._job_parameters[key] if key in self._job_parameters else default_value

def get_task_name(self) -> str:
    """
    Returns the name of the task in which the UDF runs, as assigned during plan construction.

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._task_name

def get_task_name_with_subtasks(self) -> str:
    """
    Returns the name of the task, appended with the subtask indicator, such as "MyTask (3/6)",
    where 3 would be (:func:`get_index_of_this_subtask` + 1), and 6 would be
    :func:`get_number_of_parallel_subtasks`.

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._task_name_with_subtasks

def get_number_of_parallel_subtasks(self) -> int:
    """
    Gets the parallelism with which the parallel task runs.

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._number_of_parallel_subtasks

def get_max_number_of_parallel_subtasks(self) -> int:
    """
    Gets the number of max-parallelism with which the parallel task runs.

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._max_number_of_parallel_subtasks

def get_index_of_this_subtask(self) -> int:
    """
    Gets the number of this parallel subtask. The numbering starts from 0 and goes up to
    parallelism-1 (parallelism as returned by :func:`get_number_of_parallel_subtasks`).

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._index_of_this_subtask

def get_attempt_number(self) -> int:
    """
    Gets the attempt number of this parallel subtask. First attempt is numbered 0.

    Returns None if no runtime info was supplied to this context.

    .. versionadded:: 2.2.0
    """
    return self._attempt_number


@PublicEvolving()
class UserDefinedFunction(abc.ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ public static FlinkFnApi.UserDefinedFunctions createUserDefinedFunctionsProto(
.setValue(entry.getValue())
.build())
.collect(Collectors.toList()));
builder.setRuntimeContext(
FlinkFnApi.UserDefinedDataStreamFunction.RuntimeContext.newBuilder()
.setTaskName(runtimeContext.getTaskInfo().getTaskName())
.setTaskNameWithSubtasks(
runtimeContext.getTaskInfo().getTaskNameWithSubtasks())
.setNumberOfParallelSubtasks(
runtimeContext.getTaskInfo().getNumberOfParallelSubtasks())
.setMaxNumberOfParallelSubtasks(
runtimeContext.getTaskInfo().getMaxNumberOfParallelSubtasks())
.setIndexOfThisSubtask(runtimeContext.getTaskInfo().getIndexOfThisSubtask())
.setAttemptNumber(runtimeContext.getTaskInfo().getAttemptNumber())
.build());
return builder.build();
}

Expand Down
Loading