diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/constants/Constants.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/constants/Constants.java index 0177383acb..6953a3ae79 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/constants/Constants.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/constants/Constants.java @@ -30,6 +30,9 @@ public class Constants { /* The value for Postgres databases in the source type key */ public static final String POSTGRES_SOURCE_TYPE = "postgresql"; + /* The value for Spanner databases in the source type key */ + public static final String SPANNER_SOURCE_TYPE = "spanner"; + /* The run mode for retryDLQ */ public static final String RUN_MODE_RETRY_DLQ = "retryDLQ"; diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShard.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShard.java new file mode 100644 index 0000000000..c7f2dfc621 --- /dev/null +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShard.java @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.shard; + +import java.util.Objects; + +/** + * Represents a shard targeting a Cloud Spanner database. The {@code projectId} is stored as a + * dedicated field; {@code instanceId} maps to the parent's {@code namespace} field and {@code + * databaseId} maps to the parent's {@code dbName} field. + */ +public class SpannerShard extends Shard { + + private final String projectId; + + public SpannerShard( + String logicalShardId, String projectId, String instanceId, String databaseId) { + super(); + this.projectId = projectId; + setLogicalShardId(logicalShardId); + setNamespace(instanceId); + setDbName(databaseId); + } + + public String getProjectId() { + return projectId; + } + + public String getInstanceId() { + return getNamespace(); + } + + public String getDatabaseId() { + return getDbName(); + } + + @Override + public String toString() { + return String.format( + "SpannerShard{logicalShardId='%s', projectId='%s', instanceId='%s', databaseId='%s'}", + getLogicalShardId(), projectId, getInstanceId(), getDatabaseId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SpannerShard)) { + return false; + } + SpannerShard that = (SpannerShard) o; + return Objects.equals(projectId, that.projectId) + && Objects.equals(getInstanceId(), that.getInstanceId()) + && Objects.equals(getDatabaseId(), that.getDatabaseId()); + } + + @Override + public int hashCode() { + return Objects.hash(projectId, getInstanceId(), getDatabaseId()); + } +} diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/SpannerShardFileReader.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/SpannerShardFileReader.java new file mode 100644 index 0000000000..30301282de --- /dev/null +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/utils/SpannerShardFileReader.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.utils; + +import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; +import com.google.gson.FieldNamingPolicy; +import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Type; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Reads a JSON array of Spanner shard configurations from GCS and returns a list of {@link + * SpannerShard} instances. + * + *

Each entry in the JSON file must contain the following fields: + * + *

+ * [
+ *   {
+ *     "logicalShardId": "shard1",
+ *     "projectId": "my-gcp-project",
+ *     "instanceId": "my-spanner-instance",
+ *     "databaseId": "my-database"
+ *   }
+ * ]
+ * 
+ */ +public class SpannerShardFileReader { + + private static final Logger LOG = LoggerFactory.getLogger(SpannerShardFileReader.class); + + /** + * Reads Spanner shard configuration from the given GCS file path. + * + * @param shardsFilePath GCS path to the JSON shard config file. + * @return list of {@link SpannerShard} objects, sorted by logicalShardId. + */ + public List getSpannerShards(String shardsFilePath) { + try (InputStream stream = + Channels.newInputStream( + FileSystems.open(FileSystems.matchNewResource(shardsFilePath, false)))) { + + String result = IOUtils.toString(stream, StandardCharsets.UTF_8); + Type listType = new TypeToken>>() {}.getType(); + List> shardConfigs = + new GsonBuilder() + .setFieldNamingPolicy(FieldNamingPolicy.IDENTITY) + .create() + .fromJson(result, listType); + + List shards = new ArrayList<>(); + for (Map config : shardConfigs) { + String logicalShardId = config.getOrDefault("logicalShardId", ""); + String projectId = config.get("projectId"); + String instanceId = config.get("instanceId"); + String databaseId = config.get("databaseId"); + if (projectId == null || instanceId == null || databaseId == null) { + throw new RuntimeException( + "SpannerShard config at '" + + shardsFilePath + + "' is missing one or more required fields: projectId, instanceId, databaseId"); + } + shards.add(new SpannerShard(logicalShardId, projectId, instanceId, databaseId)); + LOG.info( + "Loaded SpannerShard: logicalShardId={}, project={}, instance={}, database={}", + logicalShardId, + projectId, + instanceId, + databaseId); + } + + shards.sort(Comparator.comparing(Shard::getLogicalShardId)); + LOG.info("Read {} Spanner shard(s) from {}", shards.size(), shardsFilePath); + return shards; + + } catch (IOException e) { + throw new RuntimeException("Failed to read Spanner shard config file: " + shardsFilePath, e); + } + } +} diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SourceDatabaseType.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SourceDatabaseType.java index b28c9dd039..ffe6a63d08 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SourceDatabaseType.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SourceDatabaseType.java @@ -22,5 +22,6 @@ public enum SourceDatabaseType { CASSANDRA, ORACLE, SQLSERVER, + SPANNER, // Add more database types as needed } diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SpannerInformationSchemaScanner.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SpannerInformationSchemaScanner.java new file mode 100644 index 0000000000..b8ab1de2dd --- /dev/null +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/sourceddl/SpannerInformationSchemaScanner.java @@ -0,0 +1,170 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.sourceddl; + +import com.google.cloud.spanner.BatchClient; +import com.google.cloud.spanner.BatchReadOnlyTransaction; +import com.google.cloud.spanner.TimestampBound; +import com.google.cloud.teleport.v2.spanner.ddl.Column; +import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.IndexColumn; +import com.google.cloud.teleport.v2.spanner.ddl.InformationSchemaScanner; +import com.google.cloud.teleport.v2.spanner.ddl.Table; +import com.google.cloud.teleport.v2.spanner.type.Type; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.io.gcp.spanner.SpannerAccessor; +import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Scans a Cloud Spanner database's information schema and converts it into a {@link SourceSchema}. + * + *

Uses the existing {@link InformationSchemaScanner} to read the target Spanner DDL and then + * maps each {@link Table} and {@link Column} into the {@link SourceTable}/{@link SourceColumn} + * model so that the rest of the reverse-replication pipeline can treat Spanner as just another + * source type. + */ +public class SpannerInformationSchemaScanner implements SourceSchemaScanner { + + private static final Logger LOG = LoggerFactory.getLogger(SpannerInformationSchemaScanner.class); + + private final SpannerConfig spannerConfig; + private final SourceDatabaseType sourceType = SourceDatabaseType.SPANNER; + + public SpannerInformationSchemaScanner(SpannerConfig spannerConfig) { + this.spannerConfig = spannerConfig; + } + + @Override + public SourceSchema scan() { + SpannerAccessor accessor = SpannerAccessor.getOrCreate(spannerConfig); + try { + BatchClient batchClient = accessor.getBatchClient(); + BatchReadOnlyTransaction txn = batchClient.batchReadOnlyTransaction(TimestampBound.strong()); + InformationSchemaScanner scanner = new InformationSchemaScanner(txn); + Ddl ddl = scanner.scan(); + LOG.info("Scanned Spanner schema for database '{}'", spannerConfig.getDatabaseId().get()); + return convertDdlToSourceSchema(ddl); + } finally { + accessor.close(); + } + } + + private SourceSchema convertDdlToSourceSchema(Ddl ddl) { + Map tables = new HashMap<>(); + for (Table spannerTable : ddl.allTables()) { + SourceTable sourceTable = convertTable(spannerTable); + tables.put(sourceTable.name(), sourceTable); + } + return SourceSchema.builder(sourceType) + .databaseName(spannerConfig.getDatabaseId().get()) + .tables(ImmutableMap.copyOf(tables)) + .build(); + } + + private SourceTable convertTable(Table spannerTable) { + List pkColumns = new ArrayList<>(); + for (IndexColumn pk : spannerTable.primaryKeys()) { + pkColumns.add(pk.name()); + } + + List columns = new ArrayList<>(); + for (Column col : spannerTable.columns()) { + SourceColumn sourceCol = + SourceColumn.builder(sourceType) + .name(col.name()) + .type(spannerTypeToString(col.type())) + .isNullable(!col.notNull()) + .isPrimaryKey(pkColumns.contains(col.name())) + .isGenerated(col.isGenerated()) + .columnOptions(ImmutableList.of()) + .build(); + columns.add(sourceCol); + } + + return SourceTable.builder(sourceType) + .name(spannerTable.name()) + .columns(ImmutableList.copyOf(columns)) + .primaryKeyColumns(ImmutableList.copyOf(pkColumns)) + .foreignKeys(ImmutableList.of()) + .indexes(ImmutableList.of()) + .build(); + } + + /** + * Converts a Spanner {@link Type} to a canonical type-name string used in {@link SourceColumn}. + */ + static String spannerTypeToString(Type type) { + switch (type.getCode()) { + case BOOL: + return "BOOL"; + case INT64: + return "INT64"; + case FLOAT32: + return "FLOAT32"; + case FLOAT64: + return "FLOAT64"; + case STRING: + return "STRING"; + case BYTES: + return "BYTES"; + case DATE: + return "DATE"; + case TIMESTAMP: + return "TIMESTAMP"; + case NUMERIC: + return "NUMERIC"; + case JSON: + return "JSON"; + case PG_NUMERIC: + return "PG_NUMERIC"; + case PG_JSONB: + return "PG_JSONB"; + case PG_FLOAT4: + return "PG_FLOAT4"; + case PG_FLOAT8: + return "PG_FLOAT8"; + case PG_TEXT: + return "PG_TEXT"; + case PG_VARCHAR: + return "PG_VARCHAR"; + case PG_BOOL: + return "PG_BOOL"; + case PG_BYTEA: + return "PG_BYTEA"; + case PG_DATE: + return "PG_DATE"; + case PG_TIMESTAMPTZ: + return "PG_TIMESTAMPTZ"; + case PG_COMMIT_TIMESTAMP: + return "PG_COMMIT_TIMESTAMP"; + case PG_INT8: + return "PG_INT8"; + case ARRAY: + return "ARRAY<" + spannerTypeToString(type.getArrayElementType()) + ">"; + case PG_ARRAY: + return "PG_ARRAY<" + spannerTypeToString(type.getArrayElementType()) + ">"; + default: + return type.getCode().name(); + } + } +} diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShardTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShardTest.java new file mode 100644 index 0000000000..68104e84d1 --- /dev/null +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/shard/SpannerShardTest.java @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.spanner.migrations.shard; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SpannerShardTest { + + @Test + public void gettersReturnConstructorValues() { + SpannerShard shard = new SpannerShard("shard1", "p1", "i1", "d1"); + assertEquals("shard1", shard.getLogicalShardId()); + assertEquals("p1", shard.getProjectId()); + assertEquals("i1", shard.getInstanceId()); + assertEquals("d1", shard.getDatabaseId()); + } + + @Test + public void instanceIdMapsToNamespace() { + SpannerShard shard = new SpannerShard("s", "p", "myinstance", "d"); + assertEquals("myinstance", shard.getNamespace()); + } + + @Test + public void databaseIdMapsToDbName() { + SpannerShard shard = new SpannerShard("s", "p", "i", "mydb"); + assertEquals("mydb", shard.getDbName()); + } + + @Test + public void equalsReturnsTrueForSameInstance() { + SpannerShard a = new SpannerShard("s", "p", "i", "d"); + assertEquals(a, a); + } + + @Test + public void equalsIgnoresLogicalShardId() { + SpannerShard a = new SpannerShard("s1", "p", "i", "d"); + SpannerShard b = new SpannerShard("s2", "p", "i", "d"); + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + } + + @Test + public void equalsReturnsFalseForDifferentProject() { + SpannerShard a = new SpannerShard("s", "p1", "i", "d"); + SpannerShard b = new SpannerShard("s", "p2", "i", "d"); + assertNotEquals(a, b); + } + + @Test + public void equalsReturnsFalseForDifferentInstance() { + SpannerShard a = new SpannerShard("s", "p", "i1", "d"); + SpannerShard b = new SpannerShard("s", "p", "i2", "d"); + assertNotEquals(a, b); + } + + @Test + public void equalsReturnsFalseForDifferentDatabase() { + SpannerShard a = new SpannerShard("s", "p", "i", "d1"); + SpannerShard b = new SpannerShard("s", "p", "i", "d2"); + assertNotEquals(a, b); + } + + @Test + public void equalsReturnsFalseForNull() { + SpannerShard a = new SpannerShard("s", "p", "i", "d"); + assertNotEquals(a, null); + } + + @Test + public void equalsReturnsFalseForDifferentType() { + SpannerShard a = new SpannerShard("s", "p", "i", "d"); + assertNotEquals(a, "not-a-shard"); + } + + @Test + public void toStringIncludesAllFields() { + SpannerShard shard = new SpannerShard("shard1", "proj", "inst", "db"); + String s = shard.toString(); + assertTrue(s.contains("shard1")); + assertTrue(s.contains("proj")); + assertTrue(s.contains("inst")); + assertTrue(s.contains("db")); + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java index fff701cd9e..4336ac7b53 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/SpannerToSourceDb.java @@ -20,6 +20,7 @@ import static com.google.cloud.teleport.v2.spanner.migrations.constants.Constants.RUN_MODE_REGULAR; import static com.google.cloud.teleport.v2.spanner.migrations.constants.Constants.RUN_MODE_RETRY_ALL_DLQ; import static com.google.cloud.teleport.v2.spanner.migrations.constants.Constants.RUN_MODE_RETRY_DLQ; +import static com.google.cloud.teleport.v2.spanner.migrations.constants.Constants.SPANNER_SOURCE_TYPE; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; @@ -38,17 +39,20 @@ import com.google.cloud.teleport.v2.spanner.ddl.Ddl; import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation; import com.google.cloud.teleport.v2.spanner.migrations.utils.CassandraConfigFileReader; import com.google.cloud.teleport.v2.spanner.migrations.utils.CassandraDriverConfigLoader; import com.google.cloud.teleport.v2.spanner.migrations.utils.DataflowWorkerMachineTypeUtils; import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl; import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader; +import com.google.cloud.teleport.v2.spanner.migrations.utils.SpannerShardFileReader; import com.google.cloud.teleport.v2.spanner.sourceddl.CassandraInformationSchemaScanner; import com.google.cloud.teleport.v2.spanner.sourceddl.MySqlInformationSchemaScanner; import com.google.cloud.teleport.v2.spanner.sourceddl.PostgreSQLInformationSchemaScanner; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchemaScanner; +import com.google.cloud.teleport.v2.spanner.sourceddl.SpannerInformationSchemaScanner; import com.google.cloud.teleport.v2.templates.SpannerToSourceDb.Options; import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord; import com.google.cloud.teleport.v2.templates.constants.Constants; @@ -404,7 +408,8 @@ public interface Options extends PipelineOptions, StreamingOptions { enumOptions = { @TemplateEnumOption("mysql"), @TemplateEnumOption("cassandra"), - @TemplateEnumOption("postgresql") + @TemplateEnumOption("postgresql"), + @TemplateEnumOption("spanner") }, helpText = "The type of source database to reverse replicate to.") @Default.String("mysql") @@ -653,6 +658,12 @@ public static PipelineResult run(Options options) { shards = shardFileReader.getOrderedShardDetails(options.getSourceShardsFilePath()); shardingMode = Constants.SHARDING_MODE_MULTI_SHARD; + } else if (SPANNER_SOURCE_TYPE.equals(options.getSourceType())) { + SpannerShardFileReader spannerShardFileReader = new SpannerShardFileReader(); + shards = spannerShardFileReader.getSpannerShards(options.getSourceShardsFilePath()); + LOG.info("Spanner target shard config: {}", shards.get(0)); + shardingMode = Constants.SHARDING_MODE_SINGLE_SHARD; + } else { CassandraConfigFileReader cassandraConfigFileReader = new CassandraConfigFileReader(); shards = cassandraConfigFileReader.getCassandraShard(options.getSourceShardsFilePath()); @@ -1041,6 +1052,15 @@ private static SourceSchema fetchSourceSchema(Options options, List shard connection, shards.get(0).getDbName(), shards.get(0).getNamespace()); sourceSchema = scanner.scan(); connection.close(); + } else if (options.getSourceType().equals(SPANNER_SOURCE_TYPE)) { + SpannerShard spannerShard = (SpannerShard) shards.get(0); + SpannerConfig targetSpannerConfig = + SpannerConfig.create() + .withProjectId(spannerShard.getProjectId()) + .withInstanceId(spannerShard.getInstanceId()) + .withDatabaseId(spannerShard.getDatabaseId()); + scanner = new SpannerInformationSchemaScanner(targetSpannerConfig); + sourceSchema = scanner.scan(); } else { try (CqlSession session = createCqlSession((CassandraShard) shards.get(0))) { scanner = diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java index 2157aa8715..8b1012680d 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/constants/Constants.java @@ -79,6 +79,8 @@ public class Constants { public static final String SOURCE_POSTGRESQL = "postgresql"; + public static final String SOURCE_SPANNER = "spanner"; + // Message written to the file for filtered records public static final String FILTERED_TAG_MESSAGE = "Filtered record from custom transformation in reverse replication"; diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelper.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelper.java new file mode 100644 index 0000000000..bdb1dd8ad1 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelper.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.connection; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.teleport.v2.spanner.migrations.connection.ConnectionHelperRequest; +import com.google.cloud.teleport.v2.spanner.migrations.connection.IConnectionHelper; +import com.google.cloud.teleport.v2.spanner.migrations.exceptions.ConnectionException; +import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Manages {@link DatabaseClient} connections to target Cloud Spanner databases. */ +public class SpannerConnectionHelper implements IConnectionHelper { + + private static final Logger LOG = LoggerFactory.getLogger(SpannerConnectionHelper.class); + + private static Map clientMap = new ConcurrentHashMap<>(); + private static Spanner spannerService; + + @Override + public synchronized void init(ConnectionHelperRequest connectionHelperRequest) { + if (!clientMap.isEmpty()) { + LOG.info("Spanner connection pool is already initialized."); + return; + } + + List shards = connectionHelperRequest.getShards(); + String projectId = ((SpannerShard) shards.get(0)).getProjectId(); + spannerService = SpannerOptions.newBuilder().setProjectId(projectId).build().getService(); + + for (Shard shard : shards) { + if (!(shard instanceof SpannerShard)) { + throw new IllegalArgumentException( + "Expected SpannerShard but got: " + shard.getClass().getSimpleName()); + } + SpannerShard spannerShard = (SpannerShard) shard; + String key = connectionKey(spannerShard); + DatabaseClient client = + spannerService.getDatabaseClient( + DatabaseId.of( + spannerShard.getProjectId(), + spannerShard.getInstanceId(), + spannerShard.getDatabaseId())); + clientMap.put(key, client); + LOG.info("Initialized Spanner connection for key: {}", key); + } + } + + @Override + public DatabaseClient getConnection(String connectionRequestKey) throws ConnectionException { + if (clientMap.isEmpty()) { + throw new ConnectionException("Spanner connection pool is not initialized."); + } + DatabaseClient client = clientMap.get(connectionRequestKey); + if (client == null) { + throw new ConnectionException("No Spanner connection found for key: " + connectionRequestKey); + } + return client; + } + + @Override + public boolean isConnectionPoolInitialized() { + return !clientMap.isEmpty(); + } + + public static String connectionKey(SpannerShard shard) { + return shard.getProjectId() + "/" + shard.getInstanceId() + "/" + shard.getDatabaseId(); + } + + /** For unit testing only. */ + public void setClientMap(Map inputMap) { + clientMap = inputMap; + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDao.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDao.java new file mode 100644 index 0000000000..feec0e8ea6 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDao.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.dao.source; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.teleport.v2.spanner.migrations.connection.IConnectionHelper; +import com.google.cloud.teleport.v2.spanner.migrations.exceptions.ConnectionException; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; +import com.google.cloud.teleport.v2.templates.models.SpannerMutationResponse; +import com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DAO for writing reverse-replicated records to a target Cloud Spanner database. + * + *

Receives a {@link SpannerMutationResponse} from the DML generator and commits the contained + * {@link Mutation} via a {@link DatabaseClient} obtained from the {@link + * com.google.cloud.teleport.v2.templates.dbutils.connection.SpannerConnectionHelper}. + */ +public class SpannerTargetDao implements IDao { + + private static final Logger LOG = LoggerFactory.getLogger(SpannerTargetDao.class); + + private final String connectionKey; + private final IConnectionHelper connectionHelper; + + public SpannerTargetDao( + String connectionKey, IConnectionHelper connectionHelper) { + this.connectionKey = connectionKey; + this.connectionHelper = connectionHelper; + } + + @Override + public void write( + DMLGeneratorResponse dmlGeneratorResponse, TransactionalCheck transactionalCheck) + throws Exception { + if (transactionalCheck != null) { + throw new UnsupportedOperationException( + "TransactionalCheck is not supported for the Spanner target DAO."); + } + + if (!(dmlGeneratorResponse instanceof SpannerMutationResponse)) { + throw new IllegalArgumentException( + "Expected SpannerMutationResponse but received: " + + dmlGeneratorResponse.getClass().getSimpleName()); + } + + DatabaseClient client = connectionHelper.getConnection(connectionKey); + if (client == null) { + throw new ConnectionException("DatabaseClient is null for connection key: " + connectionKey); + } + + Mutation mutation = ((SpannerMutationResponse) dmlGeneratorResponse).getMutation(); + client.writeAtLeastOnce(ImmutableList.of(mutation)); + LOG.debug("Successfully wrote mutation via SpannerTargetDao for key: {}", connectionKey); + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGenerator.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGenerator.java new file mode 100644 index 0000000000..e5585eb07f --- /dev/null +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGenerator.java @@ -0,0 +1,624 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.dml; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.Key; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Value; +import com.google.cloud.teleport.v2.spanner.ddl.Column; +import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.IndexColumn; +import com.google.cloud.teleport.v2.spanner.ddl.Table; +import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; +import com.google.cloud.teleport.v2.spanner.type.Type; +import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; +import com.google.cloud.teleport.v2.templates.models.SpannerMutationResponse; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import org.json.JSONArray; +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Generates Spanner {@link Mutation} objects for reverse-replication to a Cloud Spanner target. + * + *

INSERT and UPDATE change-stream events produce an {@code insertOrUpdate} mutation. DELETE + * events produce a {@code delete} mutation keyed on the primary-key values from the change record. + * + *

Value conversion reads directly from the source Spanner DDL column type ({@link Type}), which + * avoids ambiguity when the source and target share the same type system. + */ +public class SpannerDMLGenerator implements IDMLGenerator { + + private static final Logger LOG = LoggerFactory.getLogger(SpannerDMLGenerator.class); + + @Override + public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest request) { + if (request == null) { + throw new InvalidDMLGenerationException( + "DMLGeneratorRequest is null. Cannot process the request."); + } + + String spannerTableName = request.getSpannerTableName(); + ISchemaMapper schemaMapper = request.getSchemaMapper(); + Ddl spannerDdl = request.getSpannerDdl(); + SourceSchema sourceSchema = request.getSourceSchema(); + + if (schemaMapper == null) { + throw new InvalidDMLGenerationException("SchemaMapper must not be null."); + } + if (spannerDdl == null) { + throw new InvalidDMLGenerationException("Spanner DDL must not be null."); + } + if (sourceSchema == null) { + throw new InvalidDMLGenerationException("SourceSchema must not be null."); + } + + Table spannerTable = spannerDdl.table(spannerTableName); + if (spannerTable == null) { + throw new InvalidDMLGenerationException( + "Spanner table '" + spannerTableName + "' not found in DDL."); + } + + String targetTableName; + try { + targetTableName = schemaMapper.getSourceTableName("", spannerTableName); + } catch (NoSuchElementException e) { + throw new InvalidDMLGenerationException( + "Could not find target table name for Spanner table: " + spannerTableName, e); + } + + SourceTable targetTable = sourceSchema.table(targetTableName); + if (targetTable == null) { + throw new InvalidDMLGenerationException( + "Target table '" + targetTableName + "' not found in SourceSchema."); + } + + if (targetTable.primaryKeyColumns() == null || targetTable.primaryKeyColumns().isEmpty()) { + throw new InvalidDMLGenerationException( + "Cannot reverse replicate to target table '" + + targetTableName + + "' without a primary key."); + } + + String modType = request.getModType(); + if ("INSERT".equals(modType) || "UPDATE".equals(modType)) { + return buildUpsertMutation(spannerTable, targetTable, schemaMapper, request, targetTableName); + } else if ("DELETE".equals(modType)) { + return buildDeleteMutation(spannerTable, targetTable, schemaMapper, request, targetTableName); + } else { + throw new InvalidDMLGenerationException( + "Unsupported modType '" + modType + "' for table " + spannerTableName); + } + } + + private static DMLGeneratorResponse buildUpsertMutation( + Table spannerTable, + SourceTable targetTable, + ISchemaMapper schemaMapper, + DMLGeneratorRequest request, + String targetTableName) { + + Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(targetTableName); + JSONObject newValuesJson = request.getNewValuesJson(); + JSONObject keyValuesJson = request.getKeyValuesJson(); + + for (SourceColumn targetCol : targetTable.columns()) { + if (targetCol.isGenerated()) { + continue; + } + + String targetColName = targetCol.name(); + + String sourceColName; + try { + sourceColName = schemaMapper.getSpannerColumnName("", targetTable.name(), targetColName); + } catch (NoSuchElementException e) { + continue; + } + + Column sourceCol = spannerTable.column(sourceColName); + if (sourceCol == null) { + continue; + } + + if (request.getCustomTransformationResponse() != null + && request.getCustomTransformationResponse().containsKey(targetColName)) { + Object customVal = request.getCustomTransformationResponse().get(targetColName); + if (customVal == null) { + setNullValue(builder, targetColName, sourceCol.type()); + } else { + setCustomColumnValue(builder, targetColName, sourceCol, customVal); + } + continue; + } + + JSONObject valuesJson = keyValuesJson.has(sourceColName) ? keyValuesJson : newValuesJson; + + if (!valuesJson.has(sourceColName)) { + continue; + } + + if (valuesJson.isNull(sourceColName)) { + setNullValue(builder, targetColName, sourceCol.type()); + } else { + setColumnValue(builder, targetColName, sourceCol, valuesJson); + } + } + + return new SpannerMutationResponse(builder.build()); + } + + private static DMLGeneratorResponse buildDeleteMutation( + Table spannerTable, + SourceTable targetTable, + ISchemaMapper schemaMapper, + DMLGeneratorRequest request, + String targetTableName) { + + JSONObject keyValuesJson = request.getKeyValuesJson(); + JSONObject newValuesJson = request.getNewValuesJson(); + + Key.Builder keyBuilder = Key.newBuilder(); + for (IndexColumn pkIndexCol : spannerTable.primaryKeys()) { + String sourceColName = pkIndexCol.name(); + + String targetColName; + try { + targetColName = schemaMapper.getSourceColumnName("", targetTableName, sourceColName); + } catch (NoSuchElementException e) { + targetColName = sourceColName; + } + + Column sourceCol = spannerTable.column(sourceColName); + if (sourceCol == null) { + throw new InvalidDMLGenerationException( + "Column '" + sourceColName + "' not found in Spanner DDL for table " + targetTableName); + } + + if (request.getCustomTransformationResponse() != null + && request.getCustomTransformationResponse().containsKey(targetColName)) { + Object customVal = request.getCustomTransformationResponse().get(targetColName); + appendCustomKeyComponent(keyBuilder, sourceCol, customVal); + continue; + } + + JSONObject valuesJson = keyValuesJson.has(sourceColName) ? keyValuesJson : newValuesJson; + + if (!valuesJson.has(sourceColName)) { + LOG.warn("Primary key column '{}' not found in change record for DELETE.", sourceColName); + throw new InvalidDMLGenerationException( + "Primary key column '" + + sourceColName + + "' missing from change record for table " + + targetTableName); + } + + if (valuesJson.isNull(sourceColName)) { + keyBuilder.append((String) null); + } else { + appendKeyComponent(keyBuilder, sourceCol, valuesJson, sourceColName); + } + } + + Mutation mutation = Mutation.delete(targetTableName, keyBuilder.build()); + return new SpannerMutationResponse(mutation); + } + + private static void setColumnValue( + Mutation.WriteBuilder builder, String targetColName, Column col, JSONObject valuesJson) { + String sourceColName = col.name(); + Type type = col.type(); + + switch (type.getCode()) { + case BOOL: + builder.set(targetColName).to(valuesJson.getBoolean(sourceColName)); + break; + case INT64: + builder.set(targetColName).to(Long.parseLong(valuesJson.getString(sourceColName))); + break; + case FLOAT64: + builder.set(targetColName).to(valuesJson.getBigDecimal(sourceColName).doubleValue()); + break; + case FLOAT32: + builder + .set(targetColName) + .to((float) valuesJson.getBigDecimal(sourceColName).doubleValue()); + break; + case STRING: + builder.set(targetColName).to(valuesJson.getString(sourceColName)); + break; + case JSON: + builder.set(targetColName).to(Value.json(valuesJson.getString(sourceColName))); + break; + case BYTES: + builder.set(targetColName).to(ByteArray.fromBase64(valuesJson.getString(sourceColName))); + break; + case DATE: + builder.set(targetColName).to(Date.parseDate(valuesJson.getString(sourceColName))); + break; + case TIMESTAMP: + builder + .set(targetColName) + .to(Timestamp.parseTimestamp(valuesJson.getString(sourceColName))); + break; + case NUMERIC: + builder.set(targetColName).to(new BigDecimal(valuesJson.getString(sourceColName))); + break; + case ARRAY: + builder + .set(targetColName) + .to( + buildArrayValue( + type.getArrayElementType(), valuesJson.getJSONArray(sourceColName))); + break; + default: + LOG.warn( + "Unrecognised Spanner type code {} for column '{}'; falling back to STRING.", + type.getCode(), + targetColName); + builder.set(targetColName).to(valuesJson.getString(sourceColName)); + } + } + + private static void setNullValue(Mutation.WriteBuilder builder, String targetColName, Type type) { + switch (type.getCode()) { + case BOOL: + builder.set(targetColName).to((Boolean) null); + break; + case INT64: + builder.set(targetColName).to((Long) null); + break; + case FLOAT64: + builder.set(targetColName).to((Double) null); + break; + case FLOAT32: + builder.set(targetColName).to((Float) null); + break; + case BYTES: + builder.set(targetColName).to((ByteArray) null); + break; + case DATE: + builder.set(targetColName).to((Date) null); + break; + case TIMESTAMP: + builder.set(targetColName).to((Timestamp) null); + break; + case NUMERIC: + builder.set(targetColName).to((BigDecimal) null); + break; + case JSON: + builder.set(targetColName).to(Value.json(null)); + break; + case ARRAY: + setNullArrayValue(builder, targetColName, type.getArrayElementType()); + break; + default: + builder.set(targetColName).to((String) null); + } + } + + /** + * Emits a typed NULL for an ARRAY column. The Spanner client requires the null value to carry the + * array element type, otherwise a commit-time type mismatch occurs (e.g. binding {@code + * Value.stringArray(null)} to an {@code ARRAY} column). + */ + private static void setNullArrayValue( + Mutation.WriteBuilder builder, String targetColName, Type elementType) { + switch (elementType.getCode()) { + case BOOL: + builder.set(targetColName).to(Value.boolArray((Iterable) null)); + break; + case INT64: + builder.set(targetColName).to(Value.int64Array((Iterable) null)); + break; + case FLOAT64: + builder.set(targetColName).to(Value.float64Array((Iterable) null)); + break; + case FLOAT32: + builder.set(targetColName).to(Value.float32Array((Iterable) null)); + break; + case BYTES: + builder.set(targetColName).to(Value.bytesArray((Iterable) null)); + break; + case DATE: + builder.set(targetColName).to(Value.dateArray((Iterable) null)); + break; + case TIMESTAMP: + builder.set(targetColName).to(Value.timestampArray((Iterable) null)); + break; + case NUMERIC: + builder.set(targetColName).to(Value.numericArray((Iterable) null)); + break; + case JSON: + builder.set(targetColName).to(Value.jsonArray(null)); + break; + default: + builder.set(targetColName).to(Value.stringArray((Iterable) null)); + } + } + + /** Appends a single primary-key component to the Key builder. */ + private static void appendKeyComponent( + Key.Builder keyBuilder, Column col, JSONObject valuesJson, String sourceColName) { + Type type = col.type(); + switch (type.getCode()) { + case BOOL: + keyBuilder.append(valuesJson.getBoolean(sourceColName)); + break; + case INT64: + keyBuilder.append(Long.parseLong(valuesJson.getString(sourceColName))); + break; + case FLOAT64: + keyBuilder.append(valuesJson.getBigDecimal(sourceColName).doubleValue()); + break; + case FLOAT32: + keyBuilder.append((float) valuesJson.getBigDecimal(sourceColName).doubleValue()); + break; + case BYTES: + keyBuilder.append(ByteArray.fromBase64(valuesJson.getString(sourceColName))); + break; + case DATE: + keyBuilder.append(Date.parseDate(valuesJson.getString(sourceColName))); + break; + case TIMESTAMP: + keyBuilder.append(Timestamp.parseTimestamp(valuesJson.getString(sourceColName))); + break; + case NUMERIC: + keyBuilder.append(new BigDecimal(valuesJson.getString(sourceColName))); + break; + default: + keyBuilder.append(valuesJson.getString(sourceColName)); + } + } + + /** + * Binds a custom-transformation {@link Object} to the mutation builder using the target column's + * Spanner type. Strings are coerced into the correct primitive when needed; already-typed values + * are passed through. + */ + private static void setCustomColumnValue( + Mutation.WriteBuilder builder, String targetColName, Column col, Object value) { + Type type = col.type(); + switch (type.getCode()) { + case BOOL: + if (value instanceof Boolean) { + builder.set(targetColName).to((Boolean) value); + } else { + builder.set(targetColName).to(Boolean.parseBoolean(value.toString())); + } + break; + case INT64: + if (value instanceof Number) { + builder.set(targetColName).to(((Number) value).longValue()); + } else { + builder.set(targetColName).to(Long.parseLong(value.toString())); + } + break; + case FLOAT64: + if (value instanceof Number) { + builder.set(targetColName).to(((Number) value).doubleValue()); + } else { + builder.set(targetColName).to(Double.parseDouble(value.toString())); + } + break; + case FLOAT32: + if (value instanceof Number) { + builder.set(targetColName).to(((Number) value).floatValue()); + } else { + builder.set(targetColName).to(Float.parseFloat(value.toString())); + } + break; + case STRING: + builder.set(targetColName).to(value.toString()); + break; + case JSON: + builder.set(targetColName).to(Value.json(value.toString())); + break; + case BYTES: + if (value instanceof byte[]) { + builder.set(targetColName).to(ByteArray.copyFrom((byte[]) value)); + } else if (value instanceof ByteArray) { + builder.set(targetColName).to((ByteArray) value); + } else { + builder.set(targetColName).to(ByteArray.fromBase64(value.toString())); + } + break; + case DATE: + if (value instanceof Date) { + builder.set(targetColName).to((Date) value); + } else { + builder.set(targetColName).to(Date.parseDate(value.toString())); + } + break; + case TIMESTAMP: + if (value instanceof Timestamp) { + builder.set(targetColName).to((Timestamp) value); + } else { + builder.set(targetColName).to(Timestamp.parseTimestamp(value.toString())); + } + break; + case NUMERIC: + if (value instanceof BigDecimal) { + builder.set(targetColName).to((BigDecimal) value); + } else { + builder.set(targetColName).to(new BigDecimal(value.toString())); + } + break; + default: + LOG.warn( + "Unrecognised Spanner type code {} for custom-transformation column '{}'; falling back to STRING.", + type.getCode(), + targetColName); + builder.set(targetColName).to(value.toString()); + } + } + + /** + * Appends a custom-transformation primary-key {@link Object} to the {@link Key.Builder} using the + * source column's Spanner type. Mirrors {@link #setCustomColumnValue} for the DELETE path. + */ + private static void appendCustomKeyComponent(Key.Builder keyBuilder, Column col, Object value) { + if (value == null) { + keyBuilder.append((String) null); + return; + } + Type type = col.type(); + switch (type.getCode()) { + case BOOL: + if (value instanceof Boolean) { + keyBuilder.append((Boolean) value); + } else { + keyBuilder.append(Boolean.parseBoolean(value.toString())); + } + break; + case INT64: + if (value instanceof Number) { + keyBuilder.append(((Number) value).longValue()); + } else { + keyBuilder.append(Long.parseLong(value.toString())); + } + break; + case FLOAT64: + if (value instanceof Number) { + keyBuilder.append(((Number) value).doubleValue()); + } else { + keyBuilder.append(Double.parseDouble(value.toString())); + } + break; + case FLOAT32: + if (value instanceof Number) { + keyBuilder.append(((Number) value).floatValue()); + } else { + keyBuilder.append(Float.parseFloat(value.toString())); + } + break; + case BYTES: + if (value instanceof byte[]) { + keyBuilder.append(ByteArray.copyFrom((byte[]) value)); + } else if (value instanceof ByteArray) { + keyBuilder.append((ByteArray) value); + } else { + keyBuilder.append(ByteArray.fromBase64(value.toString())); + } + break; + case DATE: + if (value instanceof Date) { + keyBuilder.append((Date) value); + } else { + keyBuilder.append(Date.parseDate(value.toString())); + } + break; + case TIMESTAMP: + if (value instanceof Timestamp) { + keyBuilder.append((Timestamp) value); + } else { + keyBuilder.append(Timestamp.parseTimestamp(value.toString())); + } + break; + case NUMERIC: + if (value instanceof BigDecimal) { + keyBuilder.append((BigDecimal) value); + } else { + keyBuilder.append(new BigDecimal(value.toString())); + } + break; + default: + keyBuilder.append(value.toString()); + } + } + + /** Builds a Spanner {@link Value} representing a Spanner ARRAY column. */ + private static Value buildArrayValue(Type elementType, JSONArray jsonArray) { + switch (elementType.getCode()) { + case BOOL: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : jsonArray.getBoolean(i)); + } + return Value.boolArray(vals); + } + case INT64: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : Long.parseLong(jsonArray.getString(i))); + } + return Value.int64Array(vals); + } + case FLOAT64: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : jsonArray.getBigDecimal(i).doubleValue()); + } + return Value.float64Array(vals); + } + case BYTES: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : ByteArray.fromBase64(jsonArray.getString(i))); + } + return Value.bytesArray(vals); + } + case DATE: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : Date.parseDate(jsonArray.getString(i))); + } + return Value.dateArray(vals); + } + case TIMESTAMP: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : Timestamp.parseTimestamp(jsonArray.getString(i))); + } + return Value.timestampArray(vals); + } + case NUMERIC: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : new BigDecimal(jsonArray.getString(i))); + } + return Value.numericArray(vals); + } + default: + { + List vals = new ArrayList<>(); + for (int i = 0; i < jsonArray.length(); i++) { + vals.add(jsonArray.isNull(i) ? null : jsonArray.getString(i)); + } + return Value.stringArray(vals); + } + } + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java index d95fba1f79..381fea6acd 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/InputRecordProcessor.java @@ -16,6 +16,7 @@ package com.google.cloud.teleport.v2.templates.dbutils.processor; import static com.google.cloud.teleport.v2.templates.constants.Constants.SOURCE_CASSANDRA; +import static com.google.cloud.teleport.v2.templates.constants.Constants.SOURCE_SPANNER; import com.google.cloud.teleport.v2.spanner.ddl.Ddl; import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException; @@ -48,6 +49,77 @@ public class InputRecordProcessor { Metrics.distribution( InputRecordProcessor.class, "apply_custom_transformation_impl_latency_ms"); + /** + * Generates the {@link DMLGeneratorResponse} for the given record without writing to any DAO. + * Returns {@code null} if the event is filtered by a custom transformation. + * + *

Use this when the write must be deferred to a point outside an enclosing transaction (e.g. + * to avoid nested Spanner transactions for the {@code SOURCE_SPANNER} path). + */ + public static DMLGeneratorResponse generateDMLResponse( + TrimmedShardedDataChangeRecord spannerRecord, + ISchemaMapper schemaMapper, + Ddl ddl, + SourceSchema sourceSchema, + String shardId, + String sourceDbTimezoneOffset, + IDMLGenerator dmlGenerator, + ISpannerMigrationTransformer spannerToSourceTransformer, + String source) + throws Exception { + + String tableName = spannerRecord.getTableName(); + String modType = spannerRecord.getModType().name(); + String keysJsonStr = spannerRecord.getMod().getKeysJson(); + String newValueJsonStr = spannerRecord.getMod().getNewValuesJson(); + JSONObject newValuesJson = new JSONObject(newValueJsonStr); + JSONObject keysJson = new JSONObject(keysJsonStr); + Map customTransformationResponse = null; + + if (spannerToSourceTransformer != null) { + org.joda.time.Instant startTimestamp = org.joda.time.Instant.now(); + Map mapRequest = + ChangeEventToMapConvertor.combineJsonObjects(keysJson, newValuesJson); + MigrationTransformationRequest migrationTransformationRequest = + new MigrationTransformationRequest(tableName, mapRequest, shardId, modType); + MigrationTransformationResponse migrationTransformationResponse = null; + try { + migrationTransformationResponse = + spannerToSourceTransformer.toSourceRow(migrationTransformationRequest); + } catch (Exception e) { + throw new InvalidTransformationException(e); + } + org.joda.time.Instant endTimestamp = org.joda.time.Instant.now(); + applyCustomTransformationResponseTimeMetric.update( + new Duration(startTimestamp, endTimestamp).getMillis()); + if (migrationTransformationResponse.isEventFiltered()) { + Metrics.counter(InputRecordProcessor.class, "filtered_events_" + shardId).inc(); + return null; + } + if (migrationTransformationResponse != null) { + customTransformationResponse = migrationTransformationResponse.getResponseRow(); + } + } + + DMLGeneratorRequest dmlGeneratorRequest = + new DMLGeneratorRequest.Builder( + modType, tableName, newValuesJson, keysJson, sourceDbTimezoneOffset) + .setSchemaMapper(schemaMapper) + .setCustomTransformationResponse(customTransformationResponse) + .setCommitTimestamp(spannerRecord.getCommitTimestamp()) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build(); + + DMLGeneratorResponse dmlGeneratorResponse = dmlGenerator.getDMLStatement(dmlGeneratorRequest); + + if (!SOURCE_SPANNER.equals(source) && dmlGeneratorResponse.getDmlStatement().isEmpty()) { + throw new InvalidDMLGenerationException("DML statement is empty for table: " + tableName); + } + + return dmlGeneratorResponse; + } + public static boolean processRecord( TrimmedShardedDataChangeRecord spannerRecord, ISchemaMapper schemaMapper, @@ -63,63 +135,22 @@ public static boolean processRecord( throws Exception { try { + DMLGeneratorResponse dmlGeneratorResponse = + generateDMLResponse( + spannerRecord, + schemaMapper, + ddl, + sourceSchema, + shardId, + sourceDbTimezoneOffset, + dmlGenerator, + spannerToSourceTransformer, + source); - String tableName = spannerRecord.getTableName(); - String modType = spannerRecord.getModType().name(); - String keysJsonStr = spannerRecord.getMod().getKeysJson(); - String newValueJsonStr = spannerRecord.getMod().getNewValuesJson(); - JSONObject newValuesJson = new JSONObject(newValueJsonStr); - JSONObject keysJson = new JSONObject(keysJsonStr); - Map customTransformationResponse = null; - - if (spannerToSourceTransformer != null) { - org.joda.time.Instant startTimestamp = org.joda.time.Instant.now(); - Map mapRequest = - ChangeEventToMapConvertor.combineJsonObjects(keysJson, newValuesJson); - MigrationTransformationRequest migrationTransformationRequest = - new MigrationTransformationRequest(tableName, mapRequest, shardId, modType); - MigrationTransformationResponse migrationTransformationResponse = null; - try { - migrationTransformationResponse = - spannerToSourceTransformer.toSourceRow(migrationTransformationRequest); - } catch (Exception e) { - throw new InvalidTransformationException(e); - } - org.joda.time.Instant endTimestamp = org.joda.time.Instant.now(); - applyCustomTransformationResponseTimeMetric.update( - new Duration(startTimestamp, endTimestamp).getMillis()); - if (migrationTransformationResponse.isEventFiltered()) { - Metrics.counter(InputRecordProcessor.class, "filtered_events_" + shardId).inc(); - return true; - } - if (migrationTransformationResponse != null) { - customTransformationResponse = migrationTransformationResponse.getResponseRow(); - } - } - DMLGeneratorRequest dmlGeneratorRequest = - new DMLGeneratorRequest.Builder( - modType, tableName, newValuesJson, keysJson, sourceDbTimezoneOffset) - .setSchemaMapper(schemaMapper) - .setCustomTransformationResponse(customTransformationResponse) - .setCommitTimestamp(spannerRecord.getCommitTimestamp()) - .setDdl(ddl) - .setSourceSchema(sourceSchema) - .build(); - - DMLGeneratorResponse dmlGeneratorResponse = dmlGenerator.getDMLStatement(dmlGeneratorRequest); - if (dmlGeneratorResponse.getDmlStatement().isEmpty()) { - throw new InvalidDMLGenerationException("DML statement is empty for table: " + tableName); + if (dmlGeneratorResponse == null) { + return true; // filtered } - // TODO we need to handle it as proper Interface Level as of now we have handle Prepared - // TODO Statement and Raw Statement Differently - /* - * TODO: - * Note: The `SOURCE_CASSANDRA` case not covered in the unit tests. - * Answer: Currently, we have implemented unit tests for the Input Record Processor under the SourceWrittenFn. - * These tests cover the majority of scenarios, but they are tightly coupled with the existing code. - * Adding unit tests for SOURCE_CASSANDRA would require a significant refactoring of the entire unit test file. - * Given the current implementation, such refactoring is deemed unnecessary as it would not provide substantial value or impact. - */ + switch (source) { case SOURCE_CASSANDRA: dao.write(dmlGeneratorResponse, null); diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactory.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactory.java index 67ff6702f8..cfdc22e557 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactory.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactory.java @@ -15,20 +15,25 @@ */ package com.google.cloud.teleport.v2.templates.dbutils.processor; +import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.teleport.v2.spanner.migrations.connection.ConnectionHelperRequest; import com.google.cloud.teleport.v2.spanner.migrations.connection.IConnectionHelper; import com.google.cloud.teleport.v2.spanner.migrations.connection.JdbcConnectionHelper; import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; import com.google.cloud.teleport.v2.templates.constants.Constants; import com.google.cloud.teleport.v2.templates.dbutils.connection.CassandraConnectionHelper; +import com.google.cloud.teleport.v2.templates.dbutils.connection.SpannerConnectionHelper; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.CassandraDao; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.IDao; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.JdbcDao; +import com.google.cloud.teleport.v2.templates.dbutils.dao.source.SpannerTargetDao; import com.google.cloud.teleport.v2.templates.dbutils.dml.CassandraDMLGenerator; import com.google.cloud.teleport.v2.templates.dbutils.dml.IDMLGenerator; import com.google.cloud.teleport.v2.templates.dbutils.dml.MySQLDMLGenerator; import com.google.cloud.teleport.v2.templates.dbutils.dml.PostgreSQLDMLGenerator; +import com.google.cloud.teleport.v2.templates.dbutils.dml.SpannerDMLGenerator; import com.google.cloud.teleport.v2.templates.exceptions.UnsupportedSourceException; import java.util.HashMap; import java.util.List; @@ -49,7 +54,9 @@ public class SourceProcessorFactory { Constants.SOURCE_CASSANDRA, "com.datastax.oss.driver.api.core.CqlSession", // Cassandra Session Class Constants.SOURCE_POSTGRESQL, - "org.postgresql.Driver" // PostgreSQL JDBC Driver + "org.postgresql.Driver", // PostgreSQL JDBC Driver + Constants.SOURCE_SPANNER, + "com.google.cloud.spanner.DatabaseClient" // Spanner DatabaseClient ); private static Map> connectionUrl = new HashMap<>(); @@ -58,10 +65,12 @@ public class SourceProcessorFactory { dmlGeneratorMap.put(Constants.SOURCE_MYSQL, new MySQLDMLGenerator()); dmlGeneratorMap.put(Constants.SOURCE_CASSANDRA, new CassandraDMLGenerator()); dmlGeneratorMap.put(Constants.SOURCE_POSTGRESQL, new PostgreSQLDMLGenerator()); + dmlGeneratorMap.put(Constants.SOURCE_SPANNER, new SpannerDMLGenerator()); connectionHelperMap.put(Constants.SOURCE_MYSQL, new JdbcConnectionHelper()); connectionHelperMap.put(Constants.SOURCE_CASSANDRA, new CassandraConnectionHelper()); connectionHelperMap.put(Constants.SOURCE_POSTGRESQL, new JdbcConnectionHelper()); + connectionHelperMap.put(Constants.SOURCE_SPANNER, new SpannerConnectionHelper()); connectionUrl.put( Constants.SOURCE_MYSQL, @@ -88,6 +97,12 @@ public class SourceProcessorFactory { + "/" + cassandraShard.getKeySpaceName(); }); + connectionUrl.put( + Constants.SOURCE_SPANNER, + shard -> { + SpannerShard spannerShard = (SpannerShard) shard; + return SpannerConnectionHelper.connectionKey(spannerShard); + }); } private static Map, Integer, ConnectionHelperRequest>> @@ -119,6 +134,15 @@ public class SourceProcessorFactory { maxConnections, driverMap.get(Constants.SOURCE_CASSANDRA), null, // No specific initialization query for Cassandra + null), + Constants.SOURCE_SPANNER, + (shards, maxConnections) -> + new ConnectionHelperRequest( + shards, + null, + maxConnections, + driverMap.get(Constants.SOURCE_SPANNER), + null, // No JDBC init query for Spanner null)); // for unit testing purposes @@ -195,6 +219,10 @@ private static Map createSourceDaoMap(String source, List s IDao sqlDao; if (source.equals(Constants.SOURCE_MYSQL) || source.equals(Constants.SOURCE_POSTGRESQL)) { sqlDao = new JdbcDao(connectionUrl, shard.getUserName(), getConnectionHelper(source)); + } else if (source.equals(Constants.SOURCE_SPANNER)) { + sqlDao = + new SpannerTargetDao( + connectionUrl, (IConnectionHelper) getConnectionHelper(source)); } else { sqlDao = new CassandraDao(connectionUrl, shard.getUserName(), getConnectionHelper(source)); } diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/SpannerMutationResponse.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/SpannerMutationResponse.java new file mode 100644 index 0000000000..52b399f1c7 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/models/SpannerMutationResponse.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.models; + +import com.google.cloud.spanner.Mutation; + +/** + * A {@link DMLGeneratorResponse} subclass that carries a Spanner {@link Mutation} instead of a raw + * SQL string. Used by the Spanner-to-Spanner reverse-replication path. + */ +public class SpannerMutationResponse extends DMLGeneratorResponse { + + private final Mutation mutation; + + public SpannerMutationResponse(Mutation mutation) { + super(""); + this.mutation = mutation; + } + + public Mutation getMutation() { + return mutation; + } +} diff --git a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java index 636f950311..ed4defb1b7 100644 --- a/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java +++ b/v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/transforms/SourceWriterFn.java @@ -46,6 +46,7 @@ import com.google.cloud.teleport.v2.templates.dbutils.processor.SourceProcessor; import com.google.cloud.teleport.v2.templates.dbutils.processor.SourceProcessorFactory; import com.google.cloud.teleport.v2.templates.exceptions.UnsupportedSourceException; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; import com.google.cloud.teleport.v2.templates.utils.SchemaMapperUtils; import com.google.cloud.teleport.v2.templates.utils.ShadowTableRecord; import com.google.cloud.teleport.v2.templates.utils.SpannerToSourceDbExceptionClassifier; @@ -59,6 +60,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.apache.beam.sdk.metrics.Counter; @@ -247,6 +249,12 @@ public void processElement(ProcessContext c) { ChangeEventSpannerConvertor.changeEventToPrimaryKey( tableName, ddl, keysJson, /* convertNameToLowerCase= */ false); String shadowTableName = shadowTablePrefix + tableName; + + // For the Spanner target the write must happen outside the shadow-table transaction + // (writeAtLeastOnce inside readWriteTransaction causes "Nested transactions not + // supported"). We capture the response here and apply it after the TX commits. + AtomicReference pendingSpannerWrite = new AtomicReference<>(); + Boolean transactionResult = spannerDao .getDatabaseClient() @@ -255,8 +263,6 @@ public void processElement(ProcessContext c) { (TransactionRunner.TransactionCallable) shadowTransaction -> { boolean isSourceAhead = false; - // Boolean reference to capture if the record was written in the - // transaction AtomicBoolean isRecordWritten = new AtomicBoolean(false); ShadowTableRecord shadowTableRecord = spannerDao.readShadowTableRecordWithExclusiveLock( @@ -266,12 +272,8 @@ public void processElement(ProcessContext c) { && ((shadowTableRecord .getProcessedCommitTimestamp() .compareTo(spannerRec.getCommitTimestamp()) - > 0) // either the source already has record with greater - // commit - // timestamp - || (shadowTableRecord // or the source has the same commit - // timestamp but - // greater record sequence + > 0) + || (shadowTableRecord .getProcessedCommitTimestamp() .compareTo(spannerRec.getCommitTimestamp()) == 0 @@ -279,41 +281,65 @@ public void processElement(ProcessContext c) { >= Long.parseLong(spannerRec.getRecordSequence()))); if (!isSourceAhead) { - IDao sourceDao = sourceProcessor.getSourceDao(shardId); - TransactionalCheck check = - () -> { - ShadowTableRecord newShadowTableRecord = - spannerDao.readShadowTableRecordWithExclusiveLock( - shadowTableName, - primaryKey, - shadowTableDdl, - shadowTransaction); - if (!ShadowTableRecord.isEquals( - shadowTableRecord, newShadowTableRecord)) { - throw new TransactionalCheckException( - "Shadow table sequence changed during transaction"); - } - }; - boolean isEventFiltered = - InputRecordProcessor.processRecord( - spannerRec, - schemaMapper, - ddl, - sourceSchema, - sourceDao, - shardId, - sourceDbTimezoneOffset, - sourceProcessor.getDmlGenerator(), - spannerToSourceTransformer, - this.source, - check); - isRecordWritten.set(!isEventFiltered); - if (isEventFiltered) { - outputWithTag( - c, - Constants.FILTERED_TAG, - Constants.FILTERED_TAG_MESSAGE, - spannerRec); + if (Constants.SOURCE_SPANNER.equals(source)) { + DMLGeneratorResponse response = + InputRecordProcessor.generateDMLResponse( + spannerRec, + schemaMapper, + ddl, + sourceSchema, + shardId, + sourceDbTimezoneOffset, + sourceProcessor.getDmlGenerator(), + spannerToSourceTransformer, + source); + if (response == null) { + outputWithTag( + c, + Constants.FILTERED_TAG, + Constants.FILTERED_TAG_MESSAGE, + spannerRec); + } else { + pendingSpannerWrite.set(response); + isRecordWritten.set(true); + } + } else { + IDao sourceDao = sourceProcessor.getSourceDao(shardId); + TransactionalCheck check = + () -> { + ShadowTableRecord newShadowTableRecord = + spannerDao.readShadowTableRecordWithExclusiveLock( + shadowTableName, + primaryKey, + shadowTableDdl, + shadowTransaction); + if (!ShadowTableRecord.isEquals( + shadowTableRecord, newShadowTableRecord)) { + throw new TransactionalCheckException( + "Shadow table sequence changed during transaction"); + } + }; + boolean isEventFiltered = + InputRecordProcessor.processRecord( + spannerRec, + schemaMapper, + ddl, + sourceSchema, + sourceDao, + shardId, + sourceDbTimezoneOffset, + sourceProcessor.getDmlGenerator(), + spannerToSourceTransformer, + this.source, + check); + isRecordWritten.set(!isEventFiltered); + if (isEventFiltered) { + outputWithTag( + c, + Constants.FILTERED_TAG, + Constants.FILTERED_TAG_MESSAGE, + spannerRec); + } } spannerDao.updateShadowTable( @@ -328,6 +354,13 @@ public void processElement(ProcessContext c) { } return isRecordWritten.get(); }); + + // Apply the deferred Spanner target write now that the shadow-table TX has committed. + if (pendingSpannerWrite.get() != null) { + IDao sourceDao = sourceProcessor.getSourceDao(shardId); + sourceDao.write(pendingSpannerWrite.get(), null); + } + if (Boolean.TRUE.equals(transactionResult)) { successRecordCountMetric.inc(); Counter recordsWrittenToSource = diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelperTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelperTest.java new file mode 100644 index 0000000000..46e4b4dc6b --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/connection/SpannerConnectionHelperTest.java @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.connection; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.teleport.v2.spanner.migrations.exceptions.ConnectionException; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; +import java.util.HashMap; +import java.util.Map; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public final class SpannerConnectionHelperTest { + + @After + public void tearDown() { + // Reset the static client map so tests don't leak between cases. + new SpannerConnectionHelper().setClientMap(new HashMap<>()); + } + + @Test + public void connectionKeyReturnsProjectInstanceDatabase() { + SpannerShard shard = new SpannerShard("shard1", "myproj", "myinst", "mydb"); + assertThat(SpannerConnectionHelper.connectionKey(shard)).isEqualTo("myproj/myinst/mydb"); + } + + @Test + public void connectionKeyHandlesDifferentValues() { + SpannerShard a = new SpannerShard("s", "p1", "i1", "d1"); + SpannerShard b = new SpannerShard("s", "p2", "i2", "d2"); + assertThat(SpannerConnectionHelper.connectionKey(a)) + .isNotEqualTo(SpannerConnectionHelper.connectionKey(b)); + } + + @Test + public void getConnectionReturnsClientForRegisteredKey() throws Exception { + SpannerConnectionHelper helper = new SpannerConnectionHelper(); + DatabaseClient client = Mockito.mock(DatabaseClient.class); + Map map = new HashMap<>(); + map.put("p/i/d", client); + helper.setClientMap(map); + + assertThat(helper.getConnection("p/i/d")).isSameAs(client); + } + + @Test + public void getConnectionThrowsWhenPoolEmpty() { + SpannerConnectionHelper helper = new SpannerConnectionHelper(); + helper.setClientMap(new HashMap<>()); + + assertThatThrownBy(() -> helper.getConnection("any-key")) + .isInstanceOf(ConnectionException.class); + } + + @Test + public void getConnectionThrowsForUnknownKey() { + SpannerConnectionHelper helper = new SpannerConnectionHelper(); + Map map = new HashMap<>(); + map.put("a/b/c", Mockito.mock(DatabaseClient.class)); + helper.setClientMap(map); + + assertThatThrownBy(() -> helper.getConnection("does/not/exist")) + .isInstanceOf(ConnectionException.class); + } + + @Test + public void isConnectionPoolInitializedReflectsClientMapState() { + SpannerConnectionHelper helper = new SpannerConnectionHelper(); + helper.setClientMap(new HashMap<>()); + assertThat(helper.isConnectionPoolInitialized()).isFalse(); + + Map map = new HashMap<>(); + map.put("p/i/d", Mockito.mock(DatabaseClient.class)); + helper.setClientMap(map); + assertThat(helper.isConnectionPoolInitialized()).isTrue(); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDaoTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDaoTest.java new file mode 100644 index 0000000000..ac93a0d12b --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dao/source/SpannerTargetDaoTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.dao.source; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyIterable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.teleport.v2.spanner.migrations.connection.IConnectionHelper; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; +import com.google.cloud.teleport.v2.templates.models.SpannerMutationResponse; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SpannerTargetDaoTest { + + private static final String CONNECTION_KEY = "my-project/my-instance/my-db"; + + @Test + @SuppressWarnings("unchecked") + public void writeDispatchesMutationToDatabaseClient() throws Exception { + IConnectionHelper connectionHelper = mock(IConnectionHelper.class); + DatabaseClient mockClient = mock(DatabaseClient.class); + when(connectionHelper.getConnection(CONNECTION_KEY)).thenReturn(mockClient); + + Mutation mutation = Mutation.newInsertOrUpdateBuilder("T").set("Id").to(1L).build(); + SpannerMutationResponse response = new SpannerMutationResponse(mutation); + + SpannerTargetDao dao = new SpannerTargetDao(CONNECTION_KEY, connectionHelper); + dao.write(response, null); + + verify(mockClient).writeAtLeastOnce(anyIterable()); + } + + @Test + @SuppressWarnings("unchecked") + public void transactionalCheckNotSupportedThrows() { + IConnectionHelper connectionHelper = mock(IConnectionHelper.class); + SpannerTargetDao dao = new SpannerTargetDao(CONNECTION_KEY, connectionHelper); + + Mutation mutation = Mutation.newInsertOrUpdateBuilder("T").set("Id").to(1L).build(); + SpannerMutationResponse response = new SpannerMutationResponse(mutation); + + assertThrows(UnsupportedOperationException.class, () -> dao.write(response, () -> {})); + } + + @Test + @SuppressWarnings("unchecked") + public void wrongResponseTypeThrows() throws Exception { + IConnectionHelper connectionHelper = mock(IConnectionHelper.class); + DatabaseClient mockClient = mock(DatabaseClient.class); + when(connectionHelper.getConnection(CONNECTION_KEY)).thenReturn(mockClient); + + DMLGeneratorResponse wrongResponse = new DMLGeneratorResponse("SELECT 1"); + + SpannerTargetDao dao = new SpannerTargetDao(CONNECTION_KEY, connectionHelper); + assertThrows(IllegalArgumentException.class, () -> dao.write(wrongResponse, null)); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGeneratorTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGeneratorTest.java new file mode 100644 index 0000000000..dabb64b8a4 --- /dev/null +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/SpannerDMLGeneratorTest.java @@ -0,0 +1,1139 @@ +/* + * Copyright (C) 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.dbutils.dml; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.spanner.Mutation; +import com.google.cloud.teleport.v2.spanner.ddl.Ddl; +import com.google.cloud.teleport.v2.spanner.ddl.Table; +import com.google.cloud.teleport.v2.spanner.migrations.schema.ISchemaMapper; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceColumn; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceDatabaseType; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceSchema; +import com.google.cloud.teleport.v2.spanner.sourceddl.SourceTable; +import com.google.cloud.teleport.v2.spanner.type.Type; +import com.google.cloud.teleport.v2.templates.exceptions.InvalidDMLGenerationException; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest; +import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse; +import com.google.cloud.teleport.v2.templates.models.SpannerMutationResponse; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SpannerDMLGeneratorTest { + + private static final SourceDatabaseType SRC_TYPE = SourceDatabaseType.SPANNER; + + /** Builds a simple Spanner DDL with one table: Singers(SingerId INT64 PK, FirstName STRING). */ + private static Ddl buildDdl() { + Ddl.Builder builder = Ddl.builder(); + Table.Builder tableBuilder = builder.createTable("Singers"); + tableBuilder.column("SingerId").int64().notNull().endColumn(); + tableBuilder.column("FirstName").string().max().endColumn(); + tableBuilder.column("LastName").string().max().endColumn(); + tableBuilder.primaryKey().asc("SingerId").end(); + tableBuilder.endTable(); + return builder.build(); + } + + /** Builds a SourceSchema (target Spanner) mirroring the DDL above. */ + private static SourceSchema buildSourceSchema() { + SourceColumn singerIdCol = + SourceColumn.builder(SRC_TYPE) + .name("SingerId") + .type("INT64") + .isPrimaryKey(true) + .isNullable(false) + .build(); + SourceColumn firstNameCol = + SourceColumn.builder(SRC_TYPE).name("FirstName").type("STRING").isNullable(true).build(); + SourceColumn lastNameCol = + SourceColumn.builder(SRC_TYPE).name("LastName").type("STRING").isNullable(true).build(); + + SourceTable table = + SourceTable.builder(SRC_TYPE) + .name("Singers") + .columns(ImmutableList.of(singerIdCol, firstNameCol, lastNameCol)) + .primaryKeyColumns(ImmutableList.of("SingerId")) + .foreignKeys(ImmutableList.of()) + .indexes(ImmutableList.of()) + .build(); + + return SourceSchema.builder(SRC_TYPE) + .databaseName("test-db") + .tables(ImmutableMap.of("Singers", table)) + .build(); + } + + /** Creates a schema mapper that maps Singers → Singers with identity column mapping. */ + private static ISchemaMapper buildIdentityMapper() throws Exception { + ISchemaMapper mapper = mock(ISchemaMapper.class); + when(mapper.getSourceTableName("", "Singers")).thenReturn("Singers"); + when(mapper.getSpannerColumnName("", "Singers", "SingerId")).thenReturn("SingerId"); + when(mapper.getSpannerColumnName("", "Singers", "FirstName")).thenReturn("FirstName"); + when(mapper.getSpannerColumnName("", "Singers", "LastName")).thenReturn("LastName"); + when(mapper.getSourceColumnName("", "Singers", "SingerId")).thenReturn("SingerId"); + when(mapper.getSourceColumnName("", "Singers", "FirstName")).thenReturn("FirstName"); + when(mapper.getSourceColumnName("", "Singers", "LastName")).thenReturn("LastName"); + return mapper; + } + + @Test + public void insertProducesInsertOrUpdateMutation() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + JSONObject newValues = new JSONObject("{\"FirstName\":\"John\",\"LastName\":\"Doe\"}"); + JSONObject keyValues = new JSONObject("{\"SingerId\":\"42\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "Singers", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + + assertNotNull(response); + SpannerMutationResponse mutationResponse = (SpannerMutationResponse) response; + Mutation mutation = mutationResponse.getMutation(); + assertEquals(Mutation.Op.INSERT_OR_UPDATE, mutation.getOperation()); + assertEquals("Singers", mutation.getTable()); + } + + @Test + public void updateProducesInsertOrUpdateMutation() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + JSONObject newValues = new JSONObject("{\"FirstName\":\"Jane\",\"LastName\":\"Smith\"}"); + JSONObject keyValues = new JSONObject("{\"SingerId\":\"7\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("UPDATE", "Singers", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + + SpannerMutationResponse mutationResponse = (SpannerMutationResponse) response; + assertEquals(Mutation.Op.INSERT_OR_UPDATE, mutationResponse.getMutation().getOperation()); + } + + @Test + public void deleteProducesDeleteMutation() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + JSONObject newValues = new JSONObject("{}"); + JSONObject keyValues = new JSONObject("{\"SingerId\":\"99\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("DELETE", "Singers", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + + SpannerMutationResponse mutationResponse = (SpannerMutationResponse) response; + assertEquals(Mutation.Op.DELETE, mutationResponse.getMutation().getOperation()); + assertEquals("Singers", mutationResponse.getMutation().getTable()); + } + + @Test + public void nullNonPkColumnIsIncludedInMutation() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + JSONObject newValues = new JSONObject(); + newValues.put("FirstName", JSONObject.NULL); + newValues.put("LastName", "Doe"); + JSONObject keyValues = new JSONObject("{\"SingerId\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "Singers", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void nullRequestThrows() { + assertThrows( + InvalidDMLGenerationException.class, () -> new SpannerDMLGenerator().getDMLStatement(null)); + } + + @Test + public void missingTableInDdlThrows() throws Exception { + Ddl ddl = Ddl.builder().build(); // empty DDL + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build())); + } + + @Test + public void unsupportedModTypeThrows() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema sourceSchema = buildSourceSchema(); + ISchemaMapper mapper = buildIdentityMapper(); + + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "UPSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(sourceSchema) + .build())); + } + + @Test + public void boolColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("BoolVal", Type.bool()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("BoolVal", "BOOL"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"BoolVal\":true}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void bytesColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("BytesVal", Type.bytes()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("BytesVal", "BYTES"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + String base64Hello = java.util.Base64.getEncoder().encodeToString("hello".getBytes()); + JSONObject newValues = new JSONObject("{\"BytesVal\":\"" + base64Hello + "\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void timestampColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("TsVal", Type.timestamp()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("TsVal", "TIMESTAMP"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"TsVal\":\"2024-01-15T10:30:00Z\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void dateColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("DateVal", Type.date()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("DateVal", "DATE"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"DateVal\":\"2024-06-15\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void numericColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("NumVal", Type.numeric()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("NumVal", "NUMERIC"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"NumVal\":\"123.456\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void int64ColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("IntVal", Type.int64()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("IntVal", "INT64"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"IntVal\":\"42\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void float64ColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("FloatVal", Type.float64()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("FloatVal", "FLOAT64"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"FloatVal\":3.14}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void float32ColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Float32Val", Type.float32()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Float32Val", "FLOAT32"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Float32Val\":1.5}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void stringColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("StrVal", Type.string()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("StrVal", "STRING"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"StrVal\":\"hello\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void jsonColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("JsonVal", Type.json()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("JsonVal", "JSON"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"JsonVal\":\"{\\\"k\\\":1}\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void arrayOfInt64ColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("ArrVal", Type.array(Type.int64())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("ArrVal", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"ArrVal\":[\"1\",\"2\",\"3\"]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + @Test + public void arrayOfStringColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("ArrVal", Type.array(Type.string())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("ArrVal", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"ArrVal\":[\"a\",\"b\"]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + assertNotNull(((SpannerMutationResponse) response).getMutation()); + } + + private static Ddl buildDdlWithSingleNonPkCol(String colName, Type colType) { + Ddl.Builder ddlBuilder = Ddl.builder(); + Table.Builder tableBuilder = ddlBuilder.createTable("T"); + tableBuilder.column("Id").int64().notNull().endColumn(); + tableBuilder.column(colName).type(colType).endColumn(); + tableBuilder.primaryKey().asc("Id").end(); + tableBuilder.endTable(); + return ddlBuilder.build(); + } + + private static SourceSchema buildSchemaWithSingleNonPkCol(String colName, String colType) { + SourceColumn idCol = + SourceColumn.builder(SRC_TYPE) + .name("Id") + .type("INT64") + .isPrimaryKey(true) + .isNullable(false) + .build(); + SourceColumn dataCol = + SourceColumn.builder(SRC_TYPE).name(colName).type(colType).isNullable(true).build(); + + SourceTable table = + SourceTable.builder(SRC_TYPE) + .name("T") + .columns(ImmutableList.of(idCol, dataCol)) + .primaryKeyColumns(ImmutableList.of("Id")) + .foreignKeys(ImmutableList.of()) + .indexes(ImmutableList.of()) + .build(); + + return SourceSchema.builder(SRC_TYPE) + .databaseName("test-db") + .tables(ImmutableMap.of("T", table)) + .build(); + } + + private static ISchemaMapper buildMapperForSingleColTable(SourceSchema schema) throws Exception { + SourceTable table = schema.tables().values().iterator().next(); + String tableName = table.name(); + ISchemaMapper mapper = mock(ISchemaMapper.class); + when(mapper.getSourceTableName("", tableName)).thenReturn(tableName); + for (SourceColumn col : table.columns()) { + when(mapper.getSpannerColumnName("", tableName, col.name())).thenReturn(col.name()); + when(mapper.getSourceColumnName("", tableName, col.name())).thenReturn(col.name()); + } + return mapper; + } + + @Test + public void customTransformationInt64IsBoundAsInt64() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Counter", Type.int64()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Counter", "INT64"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Counter\":\"1\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Counter", 42L); // custom returns a Long, not a String + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + // Type-aware binding: the value should be an INT64, not a STRING. + assertEquals(42L, mutation.asMap().get("Counter").getInt64()); + } + + @Test + public void customTransformationBoolIsBoundAsBool() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("IsActive", Type.bool()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("IsActive", "BOOL"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"IsActive\":false}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("IsActive", Boolean.TRUE); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals(true, mutation.asMap().get("IsActive").getBool()); + } + + @Test + public void customTransformationTimestampIsBoundAsTimestamp() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Ts", Type.timestamp()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Ts", "TIMESTAMP"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Ts\":\"2024-01-15T10:30:00Z\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Ts", "2025-06-01T00:00:00Z"); // custom returns a String for TIMESTAMP + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals( + com.google.cloud.Timestamp.parseTimestamp("2025-06-01T00:00:00Z"), + mutation.asMap().get("Ts").getTimestamp()); + } + + @Test + public void nullArrayOfInt64IsBoundAsTypedNullArray() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("ArrVal", Type.array(Type.int64())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("ArrVal", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("ArrVal", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("ArrVal"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + assertEquals( + com.google.cloud.spanner.Type.array(com.google.cloud.spanner.Type.int64()), v.getType()); + } + + @Test + public void nullArrayOfTimestampIsBoundAsTypedNullArray() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("ArrVal", Type.array(Type.timestamp())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("ArrVal", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("ArrVal", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("ArrVal"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + assertEquals( + com.google.cloud.spanner.Type.array(com.google.cloud.spanner.Type.timestamp()), + v.getType()); + } + + @Test + public void nullArrayOfBoolIsBoundAsTypedNullArray() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("ArrVal", Type.array(Type.bool())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("ArrVal", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("ArrVal", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("ArrVal"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + assertEquals( + com.google.cloud.spanner.Type.array(com.google.cloud.spanner.Type.bool()), v.getType()); + } + + @Test + public void customTransformationNullEmitsTypedNull() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Counter", Type.int64()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Counter", "INT64"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Counter\":\"1\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Counter", null); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Counter")); + org.junit.Assert.assertTrue(mutation.asMap().get("Counter").isNull()); + } + + @Test + public void deleteWithCustomTransformationInt64PkUsesTypedKey() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Data", Type.string()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Data", "STRING"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Id", 7L); // custom returns a Long, not String + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("DELETE", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals(Mutation.Op.DELETE, mutation.getOperation()); + // Key contains a typed INT64 part, not a STRING coercion. + assertEquals( + com.google.cloud.spanner.Key.of(7L).toString(), + mutation.getKeySet().getKeys().iterator().next().toString()); + } + + @Test + public void deleteWithCustomTransformationNullPk() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Data", Type.string()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Data", "STRING"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Id", null); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("DELETE", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals(Mutation.Op.DELETE, mutation.getOperation()); + } + + @Test + public void customTransformationStringValueIsBoundAsString() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Name", Type.string()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Name", "STRING"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Name\":\"original\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Name", "overridden"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals("overridden", mutation.asMap().get("Name").getString()); + } + + @Test + public void customTransformationFloat64ValueIsBoundAsFloat() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Ratio", Type.float64()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Ratio", "FLOAT64"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Ratio\":1.5}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Ratio", 3.14); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals(3.14, mutation.asMap().get("Ratio").getFloat64(), 0.0001); + } + + @Test + public void customTransformationNumericValueIsBoundAsNumeric() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Amount", Type.numeric()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Amount", "NUMERIC"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Amount\":\"1.0\"}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + java.util.Map custom = new java.util.HashMap<>(); + custom.put("Amount", new java.math.BigDecimal("12345.6789")); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .setCustomTransformationResponse(custom) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertEquals( + new java.math.BigDecimal("12345.6789"), mutation.asMap().get("Amount").getNumeric()); + } + + @Test + public void nullValueForBoolColumnIsTypedNull() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Flag", Type.bool()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Flag", "BOOL"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("Flag", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("Flag"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + assertEquals(com.google.cloud.spanner.Type.bool(), v.getType()); + } + + @Test + public void nullValueForDateColumnIsTypedNull() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Day", Type.date()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Day", "DATE"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("Day", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("Day"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + assertEquals(com.google.cloud.spanner.Type.date(), v.getType()); + } + + @Test + public void nullValueForJsonColumnIsTypedNull() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Payload", Type.json()); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Payload", "JSON"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject(); + newValues.put("Payload", JSONObject.NULL); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + com.google.cloud.spanner.Value v = + ((SpannerMutationResponse) response).getMutation().asMap().get("Payload"); + assertNotNull(v); + org.junit.Assert.assertTrue(v.isNull()); + } + + @Test + public void arrayOfBoolColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Flags", Type.array(Type.bool())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Flags", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Flags\":[true,false,true]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Flags")); + } + + @Test + public void arrayOfFloat64ColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Vals", Type.array(Type.float64())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Vals", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Vals\":[1.1, 2.2, 3.3]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Vals")); + } + + @Test + public void arrayOfTimestampColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Tss", Type.array(Type.timestamp())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Tss", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = + new JSONObject("{\"Tss\":[\"2024-01-01T00:00:00Z\",\"2024-06-15T12:00:00Z\"]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Tss")); + } + + @Test + public void arrayOfDateColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Days", Type.array(Type.date())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Days", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Days\":[\"2024-01-01\",\"2024-06-15\"]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Days")); + } + + @Test + public void arrayOfNumericColumnIsHandled() throws Exception { + Ddl ddl = buildDdlWithSingleNonPkCol("Nums", Type.array(Type.numeric())); + SourceSchema schema = buildSchemaWithSingleNonPkCol("Nums", "ARRAY"); + ISchemaMapper mapper = buildMapperForSingleColTable(schema); + + JSONObject newValues = new JSONObject("{\"Nums\":[\"1.1\",\"2.2\"]}"); + JSONObject keyValues = new JSONObject("{\"Id\":\"1\"}"); + + DMLGeneratorResponse response = + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder("INSERT", "T", newValues, keyValues, "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(schema) + .build()); + + Mutation mutation = ((SpannerMutationResponse) response).getMutation(); + assertNotNull(mutation.asMap().get("Nums")); + } + + @Test + public void missingTargetTableInSourceSchemaThrows() throws Exception { + Ddl ddl = buildDdl(); + SourceSchema emptySchema = + SourceSchema.builder(SRC_TYPE).databaseName("test-db").tables(ImmutableMap.of()).build(); + ISchemaMapper mapper = buildIdentityMapper(); + + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setSchemaMapper(mapper) + .setDdl(ddl) + .setSourceSchema(emptySchema) + .build())); + } + + @Test + public void nullSchemaMapperThrows() throws Exception { + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setDdl(buildDdl()) + .setSourceSchema(buildSourceSchema()) + .build())); + } + + @Test + public void nullDdlThrows() throws Exception { + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setSchemaMapper(buildIdentityMapper()) + .setSourceSchema(buildSourceSchema()) + .build())); + } + + @Test + public void nullSourceSchemaThrows() throws Exception { + assertThrows( + InvalidDMLGenerationException.class, + () -> + new SpannerDMLGenerator() + .getDMLStatement( + new DMLGeneratorRequest.Builder( + "INSERT", + "Singers", + new JSONObject("{}"), + new JSONObject("{\"SingerId\":\"1\"}"), + "+00:00") + .setSchemaMapper(buildIdentityMapper()) + .setDdl(buildDdl()) + .build())); + } +} diff --git a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java index c101b3d710..e159839381 100644 --- a/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java +++ b/v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/processor/SourceProcessorFactoryTest.java @@ -21,12 +21,16 @@ import com.google.cloud.teleport.v2.spanner.migrations.connection.JdbcConnectionHelper; import com.google.cloud.teleport.v2.spanner.migrations.shard.CassandraShard; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; +import com.google.cloud.teleport.v2.spanner.migrations.shard.SpannerShard; import com.google.cloud.teleport.v2.templates.constants.Constants; import com.google.cloud.teleport.v2.templates.dbutils.connection.CassandraConnectionHelper; +import com.google.cloud.teleport.v2.templates.dbutils.connection.SpannerConnectionHelper; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.CassandraDao; import com.google.cloud.teleport.v2.templates.dbutils.dao.source.JdbcDao; +import com.google.cloud.teleport.v2.templates.dbutils.dao.source.SpannerTargetDao; import com.google.cloud.teleport.v2.templates.dbutils.dml.CassandraDMLGenerator; import com.google.cloud.teleport.v2.templates.dbutils.dml.MySQLDMLGenerator; +import com.google.cloud.teleport.v2.templates.dbutils.dml.SpannerDMLGenerator; import com.google.cloud.teleport.v2.templates.exceptions.UnsupportedSourceException; import java.util.Arrays; import java.util.List; @@ -109,4 +113,25 @@ public void testCreateSourceProcessor_cassandra_validSource() throws Exception { Assert.assertEquals(1, processor.getSourceDaoMap().size()); Assert.assertTrue(processor.getSourceDaoMap().get("shard1") instanceof CassandraDao); } + + @Test + public void testCreateSourceProcessor_spanner_validSource() throws Exception { + SpannerShard spannerShard = + new SpannerShard("shard1", "my-project", "my-instance", "my-database"); + + List shards = List.of(spannerShard); + int maxConnections = 10; + SpannerConnectionHelper mockConnectionHelper = Mockito.mock(SpannerConnectionHelper.class); + doNothing().when(mockConnectionHelper).init(any()); + SourceProcessorFactory.setConnectionHelperMap( + Map.of(Constants.SOURCE_SPANNER, mockConnectionHelper)); + SourceProcessor processor = + SourceProcessorFactory.createSourceProcessor( + Constants.SOURCE_SPANNER, shards, maxConnections); + + Assert.assertNotNull(processor); + Assert.assertTrue(processor.getDmlGenerator() instanceof SpannerDMLGenerator); + Assert.assertEquals(1, processor.getSourceDaoMap().size()); + Assert.assertTrue(processor.getSourceDaoMap().get("shard1") instanceof SpannerTargetDao); + } }