diff --git a/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFn.java b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFn.java new file mode 100644 index 0000000000..e74e991f1a --- /dev/null +++ b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFn.java @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.dofn; + +import com.google.cloud.teleport.v2.templates.CdcDataGeneratorOptions.SinkType; +import com.google.cloud.teleport.v2.templates.model.DataGeneratorSchema; +import com.google.cloud.teleport.v2.templates.model.DataGeneratorTable; +import com.google.cloud.teleport.v2.templates.model.GeneratedRecord; +import com.google.cloud.teleport.v2.templates.model.LifecycleEvent; +import com.google.cloud.teleport.v2.templates.sink.DataWriter; +import com.google.cloud.teleport.v2.templates.sink.DataWriterFactory; +import com.google.cloud.teleport.v2.templates.utils.FailureRecord; +import com.google.cloud.teleport.v2.templates.utils.SchemaUtils; +import com.google.common.annotations.VisibleForTesting; +import java.util.List; +import java.util.function.Consumer; +import net.datafaker.Faker; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.Row; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Stateful {@link DoFn} that manages persistence lifecycles, state fields, and timers, delegating + * core traversal to {@link DataGeneratorEngine} and mutation batch writing to {@link + * MutationBatcher}. 
+ *
+ * <p>Failures raised while generating data are not rethrown: they are logged, counted in the
+ * {@code generationFailures} metric and emitted as JSON strings on the main output.
+ *
+ * <p>NOTE(review): Beam user state is scoped per key, yet {@code insertTopoOrderState} is written
+ * only while the in-memory cache is empty. A key first seen after the cache was filled never
+ * persists the order, so {@code onTimer} may read {@code null} for it on a fresh instance.
+ * Confirm whether that can happen in practice.
+ *
+ * <p>NOTE(review): the {@code KV} key type is assumed to be {@code Integer}; confirm against the
+ * transform that feeds this DoFn.
+ */
+public class BatchAndWriteFn extends DoFn<KV<Integer, GeneratedRecord>, String> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(BatchAndWriteFn.class);
+
+  private final SinkType sinkType;
+  private final String sinkOptionsPath;
+  private final int batchSize;
+  private final Integer jdbcPoolSize;
+  private final Integer updateInterval;
+  private final Integer deleteInterval;
+  private final PCollectionView<DataGeneratorSchema> schemaView;
+
+  // Transient collaborators and caches: created in setup() or lazily on first use, never
+  // serialized with the DoFn.
+  private transient DataWriter writer;
+  private transient Faker faker;
+  private transient volatile DataGeneratorSchema schema;
+  private transient volatile List<String> insertTopoOrder;
+
+  private transient DataGeneratorEngine dataGeneratorEngine;
+  private transient MutationBatcher batcher;
+
+  @StateId("eventQueue")
+  private final StateSpec<MapState<Long, List<LifecycleEvent>>> eventQueueSpec =
+      StateSpecs.map(VarLongCoder.of(), ListCoder.of(SerializableCoder.of(LifecycleEvent.class)));
+
+  @StateId("activeTimestamps")
+  private final StateSpec<ValueState<List<Long>>> activeTimestampsSpec =
+      StateSpecs.value(ListCoder.of(VarLongCoder.of()));
+
+  @StateId("tableMapState")
+  private final StateSpec<MapState<String, DataGeneratorTable>> tableMapSpec =
+      StateSpecs.map(StringUtf8Coder.of(), SerializableCoder.of(DataGeneratorTable.class));
+
+  @StateId("insertTopoOrderState")
+  private final StateSpec<ValueState<List<String>>> insertTopoOrderSpec =
+      StateSpecs.value(ListCoder.of(StringUtf8Coder.of()));
+
+  @TimerId("eventTimer")
+  private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+  /**
+   * Stores the configuration; no collaborator is created until {@link #setup()}.
+   *
+   * @param sinkType sink kind handed to {@link DataWriterFactory#createWriter}
+   * @param sinkOptionsPath sink configuration path handed to the writer factory
+   * @param batchSize batch size handed to {@link MutationBatcher}; must not be {@code null}
+   *     because it is unboxed into an {@code int} field
+   * @param jdbcPoolSize pool size handed to {@link MutationBatcher}
+   * @param updateInterval update interval handed to {@link DataGeneratorEngine}
+   * @param deleteInterval delete interval handed to {@link DataGeneratorEngine}
+   * @param schemaView side input the {@link DataGeneratorSchema} is read from
+   */
+  public BatchAndWriteFn(
+      SinkType sinkType,
+      String sinkOptionsPath,
+      Integer batchSize,
+      Integer jdbcPoolSize,
+      Integer updateInterval,
+      Integer deleteInterval,
+      PCollectionView<DataGeneratorSchema> schemaView) {
+    this.sinkType = sinkType;
+    this.sinkOptionsPath = sinkOptionsPath;
+    this.batchSize = batchSize;
+    this.jdbcPoolSize = jdbcPoolSize;
+    this.updateInterval = updateInterval;
+    this.deleteInterval = deleteInterval;
+    this.schemaView = schemaView;
+  }
+
+  /**
+   * Clears the cached schema and topological order, creates the writer and faker unless one was
+   * already injected, and builds a fresh batcher and generator engine.
+   */
+  @Setup
+  public void setup() {
+    this.schema = null;
+    this.insertTopoOrder = null;
+    if (writer == null) {
+      writer = DataWriterFactory.createWriter(sinkType, sinkOptionsPath);
+    }
+    if (faker == null) {
+      faker = new Faker();
+    }
+
+    this.batcher = new MutationBatcher(batchSize, jdbcPoolSize, writer);
+    this.dataGeneratorEngine = new DataGeneratorEngine(updateInterval, deleteInterval, faker);
+  }
+
+  /** Forwards the bundle start to the batcher. */
+  @StartBundle
+  public void startBundle() {
+    this.batcher.startBundle();
+  }
+
+  /**
+   * Hands one generated record to the engine.
+   *
+   * <p>An exception from the engine is logged, counted and appended to the batcher's failed
+   * records. Every failed record held by the batcher is then emitted and the list is cleared.
+   */
+  @ProcessElement
+  public void processElement(
+      ProcessContext c,
+      @StateId("eventQueue") MapState<Long, List<LifecycleEvent>> eventQueueState,
+      @StateId("activeTimestamps") ValueState<List<Long>> activeTimestamps,
+      @StateId("tableMapState") MapState<String, DataGeneratorTable> tableMapState,
+      @StateId("insertTopoOrderState") ValueState<List<String>> insertTopoOrderState,
+      @TimerId("eventTimer") Timer eventTimer) {
+
+    ensureSchemaInitialized(c, insertTopoOrderState);
+
+    GeneratedRecord record = c.element().getValue();
+    String tableName = record.tableName();
+    Row pkValues = record.primaryKeyValues();
+
+    try {
+      dataGeneratorEngine.processRecord(
+          tableName,
+          pkValues,
+          eventQueueState,
+          activeTimestamps,
+          tableMapState,
+          eventTimer,
+          schema,
+          batcher,
+          insertTopoOrder);
+    } catch (Exception genError) {
+      LOG.error("Generation failed for table {}", tableName, genError);
+      Metrics.counter(BatchAndWriteFn.class, "generationFailures").inc();
+      batcher
+          .getFailedRecords()
+          .add(
+              FailureRecord.toJson(
+                  tableName, FailureRecord.OPERATION_GENERATION, pkValues, genError));
+    }
+
+    writeFailedRecords(c::output);
+  }
+
+  /**
+   * Processes the scheduled events when the event timer fires.
+   *
+   * <p>The topological order is read from state when it is not cached in memory. Failures are
+   * handled as in {@link #processElement}, except that the table is reported as
+   * {@code UNKNOWN_TABLE} and no primary key is attached.
+   */
+  @OnTimer("eventTimer")
+  public void onTimer(
+      OnTimerContext c,
+      @StateId("eventQueue") MapState<Long, List<LifecycleEvent>> eventQueueState,
+      @StateId("activeTimestamps") ValueState<List<Long>> activeTimestamps,
+      @StateId("tableMapState") MapState<String, DataGeneratorTable> tableMapState,
+      @StateId("insertTopoOrderState") ValueState<List<String>> insertTopoOrderState,
+      @TimerId("eventTimer") Timer eventTimer) {
+
+    if (this.insertTopoOrder == null) {
+      // May still be null afterwards when nothing was persisted for this key.
+      this.insertTopoOrder = insertTopoOrderState.read();
+    }
+
+    try {
+      dataGeneratorEngine.processScheduledEvents(
+          eventQueueState,
+          activeTimestamps,
+          tableMapState,
+          eventTimer,
+          batcher,
+          batcher.getFailedRecords(),
+          this.insertTopoOrder);
+    } catch (Exception timerError) {
+      LOG.error("Scheduled events generation failed during timer processing", timerError);
+      Metrics.counter(BatchAndWriteFn.class, "generationFailures").inc();
+      // Same operation tag as processElement, so both generation paths report consistently.
+      batcher
+          .getFailedRecords()
+          .add(
+              FailureRecord.toJson(
+                  "UNKNOWN_TABLE", FailureRecord.OPERATION_GENERATION, null, timerError));
+    }
+
+    writeFailedRecords(c::output);
+  }
+
+  /**
+   * Flushes inserts, then updates, then deletes, and emits any failed records that are still
+   * pending, stamped with the current time in the global window.
+   */
+  @FinishBundle
+  public void finishBundle(FinishBundleContext c) {
+    batcher.flushInsertsInTopoOrder(insertTopoOrder);
+    batcher.flushUpdates();
+    batcher.flushDeletesInReverseTopoOrder(insertTopoOrder);
+
+    List<String> pendingDlq = batcher.getFailedRecords();
+    if (pendingDlq != null && !pendingDlq.isEmpty()) {
+      Instant now = Instant.now();
+      for (String record : pendingDlq) {
+        c.output(record, now, GlobalWindow.INSTANCE);
+      }
+      batcher.clearDlq();
+    }
+  }
+
+  /**
+   * Closes the writer if one exists.
+   *
+   * @throws RuntimeException wrapping any exception thrown by {@code close()}
+   */
+  @Teardown
+  public void teardown() {
+    if (writer != null) {
+      try {
+        writer.close();
+      } catch (Exception e) {
+        throw new RuntimeException("Failed to close writer", e);
+      }
+    }
+  }
+
+  /**
+   * Loads the schema from the side input and derives the insert order, unless both are already
+   * cached. The derived order is written to state only on the path that loads it.
+   */
+  private void ensureSchemaInitialized(
+      ProcessContext c, ValueState<List<String>> insertTopoOrderState) {
+    if (schema != null && insertTopoOrder != null) {
+      return;
+    }
+    DataGeneratorSchema loaded = c.sideInput(schemaView);
+    this.insertTopoOrder = SchemaUtils.buildInsertTopoOrder(loaded);
+    insertTopoOrderState.write(this.insertTopoOrder);
+    // Assigned last: the guard above tests schema first.
+    this.schema = loaded;
+  }
+
+  /** Passes every failed record held by the batcher to {@code sink}, then clears them. */
+  private void writeFailedRecords(Consumer<String> sink) {
+    List<String> dlq = batcher.getFailedRecords();
+    if (dlq == null || dlq.isEmpty()) {
+      return;
+    }
+    for (String record : dlq) {
+      sink.accept(record);
+    }
+    batcher.clearDlq();
+  }
+
+  @VisibleForTesting
+  DataWriter getWriter() {
+    return writer;
+  }
+
+  @VisibleForTesting
+  void setWriter(DataWriter writer) {
+    this.writer = writer;
+  }
+
+  @VisibleForTesting
+  Faker getFaker() {
+    return faker;
+  }
+
+  @VisibleForTesting
+  void setFaker(Faker faker) {
+    this.faker = faker;
+  }
+
+  @VisibleForTesting
+  DataGeneratorSchema getSchema() {
+    return schema;
+  }
+
+  @VisibleForTesting
+  void setSchema(DataGeneratorSchema schema) {
+    this.schema = schema;
+  }
+
+  @VisibleForTesting
+  List<String> getInsertTopoOrder() {
+    return insertTopoOrder;
+  }
+
+  @VisibleForTesting
+  void setInsertTopoOrder(List<String> insertTopoOrder) {
+    this.insertTopoOrder = insertTopoOrder;
+  }
+
+  @VisibleForTesting
+  DataGeneratorEngine getDataGeneratorEngine() {
+    return dataGeneratorEngine;
+  }
+
+  @VisibleForTesting
+  void setDataGeneratorEngine(DataGeneratorEngine dataGeneratorEngine) {
+    this.dataGeneratorEngine = dataGeneratorEngine;
+  }
+
+  @VisibleForTesting
+  MutationBatcher getBatcher() {
+    return batcher;
+  }
+
+  @VisibleForTesting
+  void setBatcher(MutationBatcher batcher) {
+    this.batcher = batcher;
+  }
+}
diff --git a/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/model/GeneratedRecord.java b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/model/GeneratedRecord.java
new file mode 100644
index 0000000000..198ce0abe6
--- /dev/null
+++ b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/model/GeneratedRecord.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.teleport.v2.templates.model;
+
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+import org.apache.beam.sdk.values.Row;
+
+/**
+ * Serializable value class holding a table name together with a {@link Row} of primary key
+ * values. Instances are obtained through {@link #create(String, Row)}.
+ */
+@AutoValue
+public abstract class GeneratedRecord implements Serializable {
+
+  /** Returns a record for the given table name and primary key values. */
+  public static GeneratedRecord create(String tableName, Row primaryKeyValues) {
+    return new AutoValue_GeneratedRecord(tableName, primaryKeyValues);
+  }
+
+  // Accessor order is significant: AutoValue derives the generated constructor from it.
+
+  /** Returns the table name. */
+  public abstract String tableName();
+
+  /** Returns the primary key values. */
+  public abstract Row primaryKeyValues();
+}
diff --git a/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactory.java b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactory.java
new file mode 100644
index 0000000000..77f4768373
--- /dev/null
+++ b/v2/cdc-data-generator/src/main/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactory.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.teleport.v2.templates.sink;
+
+import com.google.cloud.teleport.v2.templates.CdcDataGeneratorOptions.SinkType;
+import com.google.cloud.teleport.v2.templates.mysql.MySqlDataWriter;
+import com.google.cloud.teleport.v2.templates.spanner.SpannerDataWriter;
+
+/**
+ * Factory class for creating {@link DataWriter} instances based on the configured {@link SinkType}.
+ *
+ * <p>This class only exposes a static factory method and cannot be instantiated.
+ */
+public final class DataWriterFactory {
+
+  private DataWriterFactory() {}
+
+  /**
+   * Creates a {@link DataWriter} for the specified sink type and configuration path.
+   *
+   * @param type the sink type to create a writer for; must not be {@code null}
+   * @param configPath the path to the sink configuration document
+   * @return a new {@link DataWriter} instance
+   * @throws NullPointerException if {@code type} is {@code null}
+   * @throws IllegalArgumentException if the sink type is unsupported
+   */
+  public static DataWriter createWriter(SinkType type, String configPath) {
+    // Switching on a null enum throws a NullPointerException without a message; raise the same
+    // exception type explicitly so the failure names the offending argument.
+    if (type == null) {
+      throw new NullPointerException("Sink type must not be null");
+    }
+    switch (type) {
+      case MYSQL:
+        return new MySqlDataWriter(configPath);
+      case SPANNER:
+        return new SpannerDataWriter(configPath);
+      default:
+        throw new IllegalArgumentException("Unsupported sink type: " + type);
+    }
+  }
+}
diff --git a/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFnTest.java b/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFnTest.java
new file mode 100644
index 0000000000..4cbbd6ce18
--- /dev/null
+++ b/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/dofn/BatchAndWriteFnTest.java
@@ -0,0 +1,426 @@
+/*
+ * Copyright (C) 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.teleport.v2.templates.dofn;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.cloud.teleport.v2.templates.CdcDataGeneratorOptions.SinkType;
+import com.google.cloud.teleport.v2.templates.model.DataGeneratorColumn;
+import com.google.cloud.teleport.v2.templates.model.DataGeneratorSchema;
+import com.google.cloud.teleport.v2.templates.model.DataGeneratorTable;
+import com.google.cloud.teleport.v2.templates.model.GeneratedRecord;
+import com.google.cloud.teleport.v2.templates.model.LogicalType;
+import com.google.cloud.teleport.v2.templates.sink.DataWriter;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import net.datafaker.Faker;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.state.MapState;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
+import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
+import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.Row;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for the lifecycle methods of {@link BatchAndWriteFn}: setup, bundle start/finish,
+ * element processing, timer processing and teardown.
+ *
+ * <p>Beam contexts, state cells and timers are Mockito mocks of raw types, which is why unchecked
+ * warnings are suppressed for the whole class.
+ */
+@RunWith(JUnit4.class)
+@SuppressWarnings("unchecked")
+public class BatchAndWriteFnTest {
+
+  @Test
+  public void constructor_nonPositiveBatchSize_fallsBackToDefault() {
+    // NOTE(review): this only proves that construction with batchSize = 0 does not throw; the
+    // fallback named in the method is not observable from here. Confirm where it is covered.
+    BatchAndWriteFn fn =
+        new BatchAndWriteFn(SinkType.SPANNER, "{}", 0, null, 10, 10, mock(PCollectionView.class));
+    assertNotNull(fn);
+  }
+
+  @Test
+  public void testSetup_initializesDefaultWriterAndFaker() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    fn.setup();
+
+    assertNotNull(fn.getWriter());
+    assertNotNull(fn.getFaker());
+    assertNotNull(fn.getBatcher());
+    assertNotNull(fn.getDataGeneratorEngine());
+  }
+
+  @Test
+  public void testSetup_retainsInjectedWriterAndFaker() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    DataWriter mockWriter = mock(DataWriter.class);
+    Faker mockFaker = mock(Faker.class);
+
+    fn.setWriter(mockWriter);
+    fn.setFaker(mockFaker);
+
+    fn.setup();
+
+    assertEquals(mockWriter, fn.getWriter());
+    assertEquals(mockFaker, fn.getFaker());
+  }
+
+  @Test
+  public void testStartBundle_callsBatcherStartBundle() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    MutationBatcher mockBatcher = mock(MutationBatcher.class);
+    fn.setBatcher(mockBatcher);
+
+    fn.startBundle();
+
+    verify(mockBatcher).startBundle();
+  }
+
+  @Test
+  public void testProcessElement_initializesSchemaWhenNull() throws Exception {
+    DataGeneratorTable users = simpleUsersTable();
+    PCollectionView<DataGeneratorSchema> schemaView = mock(PCollectionView.class);
+    ProcessContext c = usersContext(schemaView, schemaOf(users));
+
+    DataWriter mockWriter = mock(DataWriter.class);
+    BatchAndWriteFn fn = startedFn(schemaView, mockWriter);
+
+    ValueState<List<String>> mockInsertTopoOrderState = mock(ValueState.class);
+
+    process(fn, c, mockInsertTopoOrderState);
+
+    verify(c).sideInput(schemaView);
+    verify(mockInsertTopoOrderState).write(any(List.class));
+    verify(mockWriter).insert(any(), eq(users), any(), anyInt());
+  }
+
+  @Test
+  public void testProcessElement_skipsSchemaInitializationWhenAlreadyLoaded() throws Exception {
+    DataGeneratorTable users = simpleUsersTable();
+    DataGeneratorSchema schema = schemaOf(users);
+    PCollectionView<DataGeneratorSchema> schemaView = mock(PCollectionView.class);
+    ProcessContext c = usersContext(schemaView, schema);
+
+    DataWriter mockWriter = mock(DataWriter.class);
+    BatchAndWriteFn fn = startedFn(schemaView, mockWriter);
+
+    // Pre-populate schema and insertTopoOrder
+    fn.setSchema(schema);
+    fn.setInsertTopoOrder(ImmutableList.of("Users"));
+
+    ValueState<List<String>> mockInsertTopoOrderState = mock(ValueState.class);
+
+    process(fn, c, mockInsertTopoOrderState);
+
+    // verify sideInput was never called since schema was already initialized
+    verify(c, never()).sideInput(any());
+    verify(mockInsertTopoOrderState, never()).write(any());
+    verify(mockWriter).insert(any(), eq(users), any(), anyInt());
+  }
+
+  @Test
+  public void testProcessElement_normalExecution_flushesDlqWhenPresent() throws Exception {
+    PCollectionView<DataGeneratorSchema> schemaView = mock(PCollectionView.class);
+    ProcessContext c = usersContext(schemaView, schemaOf(simpleUsersTable()));
+
+    BatchAndWriteFn fn = startedFn(schemaView, mock(DataWriter.class));
+
+    // Replace the real batcher only after setup(), which would otherwise overwrite it.
+    MutationBatcher mockBatcher = mock(MutationBatcher.class);
+    List<String> dlq = new ArrayList<>();
+    dlq.add("dlq_record_1");
+    when(mockBatcher.getFailedRecords()).thenReturn(dlq);
+    fn.setBatcher(mockBatcher);
+
+    process(fn, c, mock(ValueState.class));
+
+    verify(c).output(eq("dlq_record_1"));
+    verify(mockBatcher).clearDlq();
+  }
+
+  @Test
+  public void testProcessElement_engineFailure_catchesAndOutputsToDlq() throws Exception {
+    PCollectionView<DataGeneratorSchema> schemaView = mock(PCollectionView.class);
+    ProcessContext c = usersContext(schemaView, schemaOf(simpleUsersTable()));
+
+    DataWriter mockWriter = mock(DataWriter.class);
+    doThrow(new RuntimeException("simulated sink failure"))
+        .when(mockWriter)
+        .insert(any(), any(), any(), anyInt());
+    BatchAndWriteFn fn = startedFn(schemaView, mockWriter);
+
+    process(fn, c, mock(ValueState.class));
+
+    verify(c).output(any(String.class));
+  }
+
+  @Test
+  public void testOnTimer_restoresInsertTopoOrderFromStateWhenNull() throws Exception {
+    BatchAndWriteFn fn = startedFn(mock(PCollectionView.class), mock(DataWriter.class));
+
+    fn.setDataGeneratorEngine(mock(DataGeneratorEngine.class));
+    // ensure insertTopoOrder is null in memory
+    fn.setInsertTopoOrder(null);
+
+    ValueState<List<String>> mockInsertTopoOrderState = mock(ValueState.class);
+    when(mockInsertTopoOrderState.read()).thenReturn(ImmutableList.of("TableA", "TableB"));
+
+    fireTimer(fn, mock(OnTimerContext.class), mockInsertTopoOrderState);
+
+    verify(mockInsertTopoOrderState).read();
+    assertEquals(ImmutableList.of("TableA", "TableB"), fn.getInsertTopoOrder());
+  }
+
+  @Test
+  public void testOnTimer_skipsStateReadWhenInsertTopoOrderIsPresent() throws Exception {
+    BatchAndWriteFn fn = startedFn(mock(PCollectionView.class), mock(DataWriter.class));
+
+    fn.setDataGeneratorEngine(mock(DataGeneratorEngine.class));
+    // Pre-populate insertTopoOrder in memory
+    fn.setInsertTopoOrder(ImmutableList.of("TableA"));
+
+    ValueState<List<String>> mockInsertTopoOrderState = mock(ValueState.class);
+
+    fireTimer(fn, mock(OnTimerContext.class), mockInsertTopoOrderState);
+
+    verify(mockInsertTopoOrderState, never()).read();
+  }
+
+  @Test
+  public void testOnTimer_exceptionRoutesToDlq() throws Exception {
+    BatchAndWriteFn fn = startedFn(mock(PCollectionView.class), mock(DataWriter.class));
+
+    DataGeneratorEngine mockEngine = mock(DataGeneratorEngine.class);
+    doThrow(new RuntimeException("timer failure"))
+        .when(mockEngine)
+        .processScheduledEvents(any(), any(), any(), any(), any(), any(), any());
+    fn.setDataGeneratorEngine(mockEngine);
+
+    OnTimerContext c = mock(OnTimerContext.class);
+
+    fireTimer(fn, c, mock(ValueState.class));
+
+    verify(c).output(any(String.class));
+  }
+
+  @Test
+  public void testFinishBundle_flushesBatcherAndEmitsPendingDlq() throws Exception {
+    BatchAndWriteFn fn = startedFn(mock(PCollectionView.class), mock(DataWriter.class));
+
+    List<String> topoOrder = ImmutableList.of("TableA", "TableB");
+    fn.setInsertTopoOrder(topoOrder);
+
+    MutationBatcher mockBatcher = mock(MutationBatcher.class);
+    when(mockBatcher.getFailedRecords()).thenReturn(Arrays.asList("dlq_record_1"));
+    fn.setBatcher(mockBatcher);
+
+    FinishBundleContext context = mock(FinishBundleContext.class);
+
+    fn.finishBundle(context);
+
+    verify(mockBatcher).flushInsertsInTopoOrder(eq(topoOrder));
+    verify(mockBatcher).flushUpdates();
+    verify(mockBatcher).flushDeletesInReverseTopoOrder(eq(topoOrder));
+    verify(context).output(eq("dlq_record_1"), any(Instant.class), eq(GlobalWindow.INSTANCE));
+    verify(mockBatcher).clearDlq();
+  }
+
+  @Test
+  public void testTeardown_closesWriterSuccessfully() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    DataWriter mockWriter = mock(DataWriter.class);
+    fn.setWriter(mockWriter);
+
+    fn.teardown();
+
+    verify(mockWriter).close();
+  }
+
+  @Test
+  public void testTeardown_nullWriterDoesNothing() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    fn.setWriter(null);
+
+    // Passes as long as teardown() does not throw.
+    fn.teardown();
+  }
+
+  @Test(expected = RuntimeException.class)
+  public void testTeardown_writerCloseThrowsException() throws Exception {
+    BatchAndWriteFn fn = newFn(mock(PCollectionView.class));
+    DataWriter mockWriter = mock(DataWriter.class);
+    doThrow(new RuntimeException("simulated close error")).when(mockWriter).close();
+    fn.setWriter(mockWriter);
+
+    fn.teardown();
+  }
+
+  // ===========================================================================
+  // Helpers
+  // ===========================================================================
+
+  /** Creates a DoFn with the constructor arguments shared by the tests (batch size 1). */
+  private static BatchAndWriteFn newFn(PCollectionView<DataGeneratorSchema> schemaView) {
+    return new BatchAndWriteFn(SinkType.SPANNER, "{}", 1, null, 10, 10, schemaView);
+  }
+
+  /** Creates a DoFn, injects {@code writer}, then runs {@code setup()} and {@code startBundle()}. */
+  private static BatchAndWriteFn startedFn(
+      PCollectionView<DataGeneratorSchema> schemaView, DataWriter writer) {
+    BatchAndWriteFn fn = newFn(schemaView);
+    fn.setWriter(writer);
+    fn.setup();
+    fn.startBundle();
+    return fn;
+  }
+
+  /** Wraps a single table in a schema keyed by the table's name. */
+  private static DataGeneratorSchema schemaOf(DataGeneratorTable table) {
+    return DataGeneratorSchema.builder().tables(ImmutableMap.of(table.name(), table)).build();
+  }
+
+  /**
+   * Mocks a process context whose side input returns {@code schema} and whose element is a
+   * {@code Users} record with primary key {@code id = 1}, keyed by {@code 0}.
+   */
+  private static ProcessContext usersContext(
+      PCollectionView<DataGeneratorSchema> schemaView, DataGeneratorSchema schema) {
+    ProcessContext c = mock(ProcessContext.class);
+    when(c.sideInput(schemaView)).thenReturn(schema);
+
+    Schema rowSchema = Schema.builder().addInt64Field("id").build();
+    Row row = Row.withSchema(rowSchema).addValue(1L).build();
+    when(c.element()).thenReturn(KV.of(0, GeneratedRecord.create("Users", row)));
+    return c;
+  }
+
+  /** Calls {@code processElement} with fresh mocks for every state cell except the topo order. */
+  private static void process(
+      BatchAndWriteFn fn, ProcessContext c, ValueState<List<String>> insertTopoOrderState) {
+    fn.processElement(
+        c,
+        mock(MapState.class),
+        mock(ValueState.class),
+        mock(MapState.class),
+        insertTopoOrderState,
+        mock(Timer.class));
+  }
+
+  /** Calls {@code onTimer} with fresh mocks for every state cell except the topo order. */
+  private static void fireTimer(
+      BatchAndWriteFn fn, OnTimerContext c, ValueState<List<String>> insertTopoOrderState) {
+    fn.onTimer(
+        c,
+        mock(MapState.class),
+        mock(ValueState.class),
+        mock(MapState.class),
+        insertTopoOrderState,
+        mock(Timer.class));
+  }
+
+  private static DataGeneratorTable simpleUsersTable() {
+    return DataGeneratorTable.builder()
+        .name("Users")
+        .columns(ImmutableList.of(intColumn("id")))
+        .primaryKeys(ImmutableList.of("id"))
+        .foreignKeys(ImmutableList.of())
+        .uniqueKeys(ImmutableList.of())
+        .insertQps(1)
+        .updateQps(0)
+        .deleteQps(0)
+        .isRoot(true)
+        .recordsPerTick(1.0)
+        .build();
+  }
+
+  private static DataGeneratorColumn intColumn(String name) {
+    return DataGeneratorColumn.builder()
+        .name(name)
+        .logicalType(LogicalType.INT64)
+        .isPrimaryKey(false)
+        .isNullable(false)
+        .isSkipped(false)
+        .isGenerated(false)
+        .size(null)
+        .precision(null)
+        .scale(null)
+        .build();
+  }
+}
diff --git 
a/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactoryTest.java b/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactoryTest.java
new file mode 100644
index 0000000000..68d26dcc2d
--- /dev/null
+++ b/v2/cdc-data-generator/src/test/java/com/google/cloud/teleport/v2/templates/sink/DataWriterFactoryTest.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.teleport.v2.templates.sink;
+
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.teleport.v2.templates.CdcDataGeneratorOptions.SinkType;
+import com.google.cloud.teleport.v2.templates.mysql.MySqlDataWriter;
+import com.google.cloud.teleport.v2.templates.spanner.SpannerDataWriter;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link DataWriterFactory#createWriter}. */
+@RunWith(JUnit4.class)
+public class DataWriterFactoryTest {
+
+  /** Configuration argument shared by every test case. */
+  private static final String CONFIG = "{}";
+
+  /** A {@code MYSQL} sink type yields a {@link MySqlDataWriter}. */
+  @Test
+  public void testCreateWriter_mySql() {
+    DataWriter created = DataWriterFactory.createWriter(SinkType.MYSQL, CONFIG);
+    assertNotNull(created);
+    assertTrue(created instanceof MySqlDataWriter);
+  }
+
+  /** A {@code SPANNER} sink type yields a {@link SpannerDataWriter}. */
+  @Test
+  public void testCreateWriter_spanner() {
+    DataWriter created = DataWriterFactory.createWriter(SinkType.SPANNER, CONFIG);
+    assertNotNull(created);
+    assertTrue(created instanceof SpannerDataWriter);
+  }
+
+  /** A {@code null} sink type is expected to surface as a {@link NullPointerException}. */
+  @Test(expected = NullPointerException.class)
+  public void testCreateWriter_unsupportedThrowsException() {
+    DataWriterFactory.createWriter(null, CONFIG);
+  }
+}