diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java
index 80338a66dd3b6..b525e5f1e3e49 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java
@@ -362,7 +362,9 @@ public boolean writeOutNext() throws IOException {
                     compositeKeyBuilder.buildCompositeKeyUserKey(entry.getKey(), userKeySerializer);
             Object userValue = entry.getValue();
             valueOut.writeBoolean(userValue == null);
-            userValueSerializer.serialize(userValue, valueOut);
+            if (userValue != null) {
+                userValueSerializer.serialize(userValue, valueOut);
+            }
             currentValue = valueOut.getCopyOfBuffer();
 
             if (!mapEntries.hasNext()) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/MapStateNullValueCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/MapStateNullValueCheckpointingITCase.java
index f5237534c86a6..a9ba307080627 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/MapStateNullValueCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/MapStateNullValueCheckpointingITCase.java
@@ -114,6 +114,8 @@ void before() throws Exception {
 
         StatefulMapper.firstRunFuture = new CompletableFuture<>();
         StatefulMapper.secondRunFuture = new CompletableFuture<>();
+        NullUnsafeStatefulMapper.firstRunFuture = new CompletableFuture<>();
+        NullUnsafeStatefulMapper.secondRunFuture = new CompletableFuture<>();
     }
 
     @AfterEach
@@ -218,6 +220,149 @@ private void restoreAndVerify(String savepointPath) throws Exception {
         assertThat(restoredState).containsKey("null-key");
     }
 
+    /**
+     * Tests that MapState with null values works correctly with null-unsafe serializers (e.g.,
+     * IntSerializer) during checkpoint/savepoint and restore. This verifies the fix in {@link
+     * org.apache.flink.runtime.state.heap.HeapKeyValueStateIterator} which previously would NPE
+     * when serializing null values during savepoint.
+     */
+    @TestTemplate
+    void testMapStateWithNullUnsafeSerializerCheckpointingAndRestore() throws Exception {
+        final String savepointPath = runJobWithNullUnsafeSerializer();
+        assertThat(savepointPath).isNotEmpty();
+        restoreAndVerifyNullUnsafeSerializer(savepointPath);
+    }
+
+    /** Runs the first job, which stores a null map value, and returns the snapshot path. */
+    private String runJobWithNullUnsafeSerializer() throws Exception {
+        Configuration conf = new Configuration();
+        conf.set(
+                CheckpointingOptions.CHECKPOINTS_DIRECTORY,
+                TempDirUtils.newFolder(tmpFolder).toURI().toString());
+        conf.set(CheckpointingOptions.EXTERNALIZED_CHECKPOINT_RETENTION, RETAIN_ON_CANCELLATION);
+        conf.set(
+                CheckpointingOptions.SAVEPOINT_DIRECTORY,
+                TempDirUtils.newFolder(tmpFolder).toURI().toString());
+        conf.set(StateBackendOptions.STATE_BACKEND, stateBackend);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.setParallelism(1);
+
+        env.fromSource(createSource(), WatermarkStrategy.noWatermarks(), "Data Generator Source")
+                .keyBy(v -> 0)
+                .map(new NullUnsafeStatefulMapper(true))
+                .sinkTo(new DiscardingSink<>());
+
+        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+        MiniCluster miniCluster = cluster.getMiniCluster();
+        miniCluster.submitJob(jobGraph).get();
+
+        JobID jobID = jobGraph.getJobID();
+        NullUnsafeStatefulMapper.firstRunFuture.get(2, TimeUnit.MINUTES);
+
+        if (snapshotType.isLeft()) {
+            cluster.getClusterClient()
+                    .triggerCheckpoint(jobID, snapshotType.left())
+                    .get(2, TimeUnit.MINUTES);
+            String checkpointPath =
+                    CommonTestUtils.getLatestCompletedCheckpointPath(jobID, miniCluster)
+                            .orElseThrow(
+                                    () ->
+                                            new NoSuchElementException(
+                                                    "No checkpoint was created yet"));
+            cluster.getClusterClient().cancel(jobID);
+            return checkpointPath;
+        } else {
+            return cluster.getClusterClient()
+                    .stopWithSavepoint(jobID, false, null, snapshotType.right())
+                    .get(2, TimeUnit.MINUTES);
+        }
+    }
+
+    /** Restores from the snapshot and checks that the null-valued entry survived. */
+    private void restoreAndVerifyNullUnsafeSerializer(String savepointPath) throws Exception {
+        Configuration conf = new Configuration();
+        conf.set(
+                CheckpointingOptions.CHECKPOINTS_DIRECTORY,
+                TempDirUtils.newFolder(tmpFolder).toURI().toString());
+        conf.set(
+                CheckpointingOptions.SAVEPOINT_DIRECTORY,
+                TempDirUtils.newFolder(tmpFolder).toURI().toString());
+        conf.set(StateBackendOptions.STATE_BACKEND, stateBackend);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.setParallelism(1);
+
+        env.fromSource(createSource(), WatermarkStrategy.noWatermarks(), "Data Generator Source")
+                .keyBy(v -> 0)
+                .map(new NullUnsafeStatefulMapper(false))
+                .sinkTo(new DiscardingSink<>());
+
+        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+        jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
+
+        MiniCluster miniCluster = cluster.getMiniCluster();
+        miniCluster.submitJob(jobGraph).get();
+
+        Map<String, Integer> restoredState =
+                NullUnsafeStatefulMapper.secondRunFuture.get(2, TimeUnit.MINUTES);
+
+        assertThat(restoredState.get("key")).isEqualTo(42);
+        assertThat(restoredState.get("null-key")).isNull();
+        assertThat(restoredState).containsKey("null-key");
+    }
+
+    /**
+     * A stateful mapper using IntSerializer (null-unsafe) for the map state value type. This
+     * exercises the code path where serializers that cannot handle null will fail during
+     * savepoint/checkpoint if the null-handling logic is incorrect.
+     */
+    private static class NullUnsafeStatefulMapper extends RichMapFunction<Long, Long> {
+
+        static CompletableFuture<Void> firstRunFuture;
+        static CompletableFuture<Map<String, Integer>> secondRunFuture;
+
+        private final boolean isFirstRun;
+        private boolean hasPopulated;
+        private transient MapState<String, Integer> mapState;
+
+        NullUnsafeStatefulMapper(boolean isFirstRun) {
+            this.isFirstRun = isFirstRun;
+        }
+
+        @Override
+        public void open(OpenContext context) {
+            MapStateDescriptor<String, Integer> mapStateDescriptor =
+                    new MapStateDescriptor<>(
+                            "map-state-int",
+                            BasicTypeInfo.STRING_TYPE_INFO,
+                            BasicTypeInfo.INT_TYPE_INFO);
+            mapState = getRuntimeContext().getMapState(mapStateDescriptor);
+            hasPopulated = false;
+        }
+
+        @Override
+        public Long map(Long value) throws Exception {
+            if (hasPopulated) {
+                return value;
+            }
+            if (isFirstRun) {
+                mapState.put("key", 42);
+                mapState.put("null-key", null);
+                firstRunFuture.complete(null);
+            } else {
+                Map<String, Integer> restoredState = new HashMap<>();
+                restoredState.put("key", mapState.get("key"));
+                // Copy "null-key" only when the entry exists; otherwise the test's containsKey
+                // assertion would pass even if the null-valued entry was lost on restore.
+                if (mapState.contains("null-key")) {
+                    restoredState.put("null-key", mapState.get("null-key"));
+                }
+                secondRunFuture.complete(restoredState);
+            }
+            hasPopulated = true;
+            return value;
+        }
+    }
+
     private static class StatefulMapper extends RichMapFunction {
 
         static CompletableFuture firstRunFuture;