diff --git a/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySQLEngine.java b/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySQLEngine.java
index e7054bff37..e1ed4c0f90 100644
--- a/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySQLEngine.java
+++ b/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySQLEngine.java
@@ -317,7 +317,11 @@ public SQLDatasetProducer getProducer(SQLPullRequest pullRequest, PullCapability
     String table = datasets.get(pullRequest.getDatasetName()).getBigQueryTable();
 
-    return new BigQuerySparkDatasetProducer(sqlEngineConfig, datasetProject, dataset, table);
+    return new BigQuerySparkDatasetProducer(sqlEngineConfig,
+                                            datasetProject,
+                                            dataset,
+                                            table,
+                                            pullRequest.getDatasetSchema());
   }
 
   @Override
diff --git a/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySparkDatasetProducer.java b/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySparkDatasetProducer.java
index aaad0952fd..b4de6b60f3 100644
--- a/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySparkDatasetProducer.java
+++ b/src/main/java/io/cdap/plugin/gcp/bigquery/sqlengine/BigQuerySparkDatasetProducer.java
@@ -16,17 +16,20 @@
 
 package io.cdap.plugin.gcp.bigquery.sqlengine;
 
+import io.cdap.cdap.api.data.schema.Schema;
 import io.cdap.cdap.etl.api.engine.sql.dataset.RecordCollection;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetDescription;
 import io.cdap.cdap.etl.api.engine.sql.dataset.SQLDatasetProducer;
 import io.cdap.cdap.etl.api.sql.engine.dataset.SparkRecordCollectionImpl;
-import io.cdap.plugin.gcp.common.GCPConfig;
 import org.apache.spark.SparkContext;
 import org.apache.spark.sql.DataFrameReader;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
@@ -39,6 +42,8 @@
 public class BigQuerySparkDatasetProducer implements SQLDatasetProducer, Serializable {
 
+  private static final Logger LOG = LoggerFactory.getLogger(BigQuerySparkDatasetProducer.class);
+
   private static final String FORMAT = "bigquery";
   private static final String CONFIG_CREDENTIALS_FILE = "credentialsFile";
   private static final String CONFIG_CREDENTIALS = "credentials";
@@ -47,15 +52,19 @@
   private String project;
   private String bqDataset;
   private String bqTable;
+  private Schema schema;
+
   public BigQuerySparkDatasetProducer(BigQuerySQLEngineConfig config,
                                       String project,
                                       String bqDataset,
-                                      String bqTable) {
+                                      String bqTable,
+                                      Schema schema) {
     this.config = config;
     this.project = project;
     this.bqDataset = bqDataset;
     this.bqTable = bqTable;
+    this.schema = schema;
   }
 
   @Override
@@ -87,6 +96,7 @@ public RecordCollection produce(SQLDataset sqlDataset) {
     // Load path into dataset.
     Dataset<Row> ds = bqReader.load(path);
+    ds = convertFieldTypes(ds);
 
     return new SparkRecordCollectionImpl(ds);
   }
@@ -95,4 +105,37 @@
   private String encodeBase64(String serviceAccountJson) {
     return Base64.getEncoder().encodeToString(serviceAccountJson.getBytes(StandardCharsets.UTF_8));
   }
+
+  /**
+   * Casts columns declared as int or float in the CDAP schema to the matching Spark types,
+   * since the BigQuery connector surfaces integer columns as longs and floating point
+   * columns as doubles.
+   *
+   * @param ds input dataframe
+   * @return dataframe with int and float columns cast to the expected types
+   */
+  private Dataset<Row> convertFieldTypes(Dataset<Row> ds) {
+    for (Schema.Field field : schema.getFields()) {
+      String fieldName = field.getName();
+      Schema fieldSchema = field.getSchema();
+
+      // For nullable types, check the underlying type.
+      if (fieldSchema.isNullable()) {
+        fieldSchema = fieldSchema.getNonNullable();
+      }
+
+      // Handle int types.
+      if (fieldSchema.getType() == Schema.Type.INT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Integer", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.IntegerType));
+      }
+
+      // Handle float types.
+      if (fieldSchema.getType() == Schema.Type.FLOAT && fieldSchema.getLogicalType() == null) {
+        LOG.trace("Converting field {} to Float", fieldName);
+        ds = ds.withColumn(fieldName, ds.col(fieldName).cast(DataTypes.FloatType));
+      }
+    }
+
+    return ds;
+  }
+
 }
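For anyone who wants to poke at the cast behavior in isolation, here is a minimal, self-contained sketch of the same `withColumn`/`cast` pattern. The local `SparkSession`, the `CastSketch` class name, and the hard-coded column names are illustration-only assumptions; the plugin itself derives the columns to cast from the CDAP `Schema`, as in `convertFieldTypes` above.

```java
import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class CastSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[*]")
      .appName("cast-sketch")
      .getOrCreate();

    // BigQuery only has 64-bit numeric scalars, so the connector surfaces
    // INT64 as LongType and FLOAT64 as DoubleType.
    StructType bqStyleSchema = new StructType(new StructField[] {
      new StructField("age", DataTypes.LongType, true, Metadata.empty()),
      new StructField("score", DataTypes.DoubleType, true, Metadata.empty())
    });
    Dataset<Row> ds = spark.createDataFrame(
      Arrays.asList(RowFactory.create(42L, 3.5d)), bqStyleSchema);

    // The same narrowing casts convertFieldTypes applies, with the column
    // names hard-coded here instead of taken from a CDAP Schema.
    ds = ds.withColumn("age", ds.col("age").cast(DataTypes.IntegerType));
    ds = ds.withColumn("score", ds.col("score").cast(DataTypes.FloatType));

    ds.printSchema(); // age: integer (nullable), score: float (nullable)
    spark.stop();
  }
}
```

One thing worth noting about the design: `cast` narrows silently in Spark's default (non-ANSI) mode, so a long outside the int range or a double that loses precision as a float will not fail the pipeline. That should be safe here because the casts are applied only to fields the CDAP schema already declares as int or float, but it is worth keeping in mind if the upstream schema is ever wrong.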