Commit 2ee9176

feat(pgbulk): add retry with filtered df for pg output write (#10)
1 parent 9ce2a5f commit 2ee9176

5 files changed: 173 additions & 40 deletions

pom.xml

Lines changed: 10 additions & 1 deletion
@@ -31,6 +31,7 @@
     <hadoop.version>2.6.5</hadoop.version>
     <spark-solr.version>4.0.4</spark-solr.version>
     <wiremock.version>3.3.1</wiremock.version>
+    <postgrestest.version>1.20.4</postgrestest.version>
     <!-- Sonar -->
    <sonar.qualitygate.wait>true</sonar.qualitygate.wait>
@@ -296,6 +297,12 @@
       <version>${wiremock.version}</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.testcontainers</groupId>
+      <artifactId>postgresql</artifactId>
+      <version>${postgrestest.version}</version>
+      <scope>test</scope>
+    </dependency>
@@ -446,7 +453,9 @@
         </execution>
       </executions>
       <configuration>
-        <argLine>--add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED</argLine>
+        <argLine>
+          --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED
+        </argLine>
       </configuration>
     </plugin>
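
Two build changes support the new integration test: the Testcontainers postgresql artifact (test scope) and extra --add-opens flags so Spark can reach JDK internals under newer Java versions. Below is a minimal sketch of the Testcontainers pattern the new PGToolTest relies on, assuming a local Docker daemon (the object name is illustrative, not from this commit):

import org.testcontainers.containers.PostgreSQLContainer

// Minimal Testcontainers sketch, assuming Docker is available locally.
object PostgresContainerSketch extends App {
  val container = new PostgreSQLContainer("postgres:15.3")
  container.start()
  try {
    // Host, port and credentials are assigned dynamically per run
    println(container.getJdbcUrl)
    println(container.getUsername)
    println(container.getPassword)
  } finally {
    container.stop()
  }
}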

src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala

Lines changed: 9 additions & 6 deletions
@@ -50,8 +50,8 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
   }

   /**
-   * This loads both a cohort and its definition into postgres and solr
-   */
+    * This loads both a cohort and its definition into postgres and solr
+    */
   override def updateCohort(cohortId: Long,
                             cohort: DataFrame,
                             sourcePopulation: SourcePopulation,
@@ -68,9 +68,9 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
       .withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
       .withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
       .select(F.col("_itemreferenceid"),
-        F.col("item__reference"),
-        F.col("_provider"),
-        F.col("_listid"))
+              F.col("item__reference"),
+              F.col("_provider"),
+              F.col("_listid"))

     uploadCohortTableToPG(dataframe)

@@ -172,7 +172,10 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
       ).toSet == df.columns.toSet,
       "cohort dataframe shall have _listid, _provider, _provider and item__reference"
     )
-    pg.outputBulk(cohort_item_table_rw, dfAddHash(df), Some(4))
+    pg.outputBulk(cohort_item_table_rw,
+                  dfAddHash(df),
+                  Some(4),
+                  primaryKeys = Seq("_listid", "_itemreferenceid", "_provider"))
   }

   /**
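
The effect: the bulk writer now knows the business key of the list table, so a COPY that aborts on duplicates can be retried with only the missing rows. A hedged usage sketch (pg and df as in the surrounding class; the table name is the value verified in PGCohortCreationTest):

// Sketch: declaring the key columns enables the filtered retry in PGTool.
pg.outputBulk(
  "list__entry_cohort360",   // resolved value of cohort_item_table_rw in the tests
  dfAddHash(df),             // carries _listid, _itemreferenceid, _provider, item__reference
  Some(4),                   // numPartitions for the COPY
  primaryKeys = Seq("_listid", "_itemreferenceid", "_provider")
)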

src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala

Lines changed: 85 additions & 33 deletions
@@ -14,6 +14,7 @@ import java.io._
 import java.sql._
 import java.util.Properties
 import java.util.UUID.randomUUID
+import scala.util.{Failure, Success, Try}

 sealed trait BulkLoadMode

@@ -35,8 +36,8 @@ class PGTool(

   private var password: String = ""

-  def setPassword(pwd: String = ""): PGTool = {
-    password = PGTool.passwordFromConn(url, pwd)
+  def setPassword(pwd: String): PGTool = {
+    password = pwd
     this
   }

@@ -73,7 +74,8 @@ class PGTool(
       table: String,
       df: Dataset[Row],
       numPartitions: Option[Int] = None,
-      reindex: Boolean = false
+      reindex: Boolean = false,
+      primaryKeys: Seq[String] = Seq.empty
   ): PGTool = {
     PGTool.outputBulk(
       spark,
@@ -84,7 +86,8 @@ class PGTool(
       numPartitions.getOrElse(8),
       password,
       reindex,
-      bulkLoadBufferSize
+      bulkLoadBufferSize,
+      primaryKeys = primaryKeys
     )
     this
   }
@@ -106,45 +109,44 @@ object PGTool extends java.io.Serializable with LazyLogging {
       url: String,
       tmpPath: String,
       bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
   ): PGTool = {
     new PGTool(
       spark,
       url,
       tmpPath + "/spark-postgres-" + randomUUID.toString,
       bulkLoadMode,
       bulkLoadBufferSize
-    ).setPassword()
+    ).setPassword(passwordFromConn(url, pgPassFile))
   }

-  def connOpen(url: String, password: String = ""): Connection = {
+  def connOpen(url: String, password: String): Connection = {
     val prop = new Properties()
     prop.put("driver", "org.postgresql.Driver")
-    prop.put("password", passwordFromConn(url, password))
+    prop.put("password", password)
     DriverManager.getConnection(url, prop)
   }

-  def passwordFromConn(url: String, password: String): String = {
-    if (password.nonEmpty) {
-      return password
-    }
+  def passwordFromConn(url: String, pgPassFile: Path): String = {
     val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
     val pattern(host, port, database, username) = url
-    dbPassword(host, port, database, username)
+    dbPassword(host, port, database, username, pgPassFile)
   }

   private def dbPassword(
       hostname: String,
       port: String,
       database: String,
-      username: String
+      username: String,
+      pgPassFile: Path
   ): String = {
     // Usage: val thatPassWord = dbPassword(hostname,port,database,username)
     // .pgpass file format, hostname:port:database:username:password

     val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
     val reader = new BufferedReader(
-      new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
+      new InputStreamReader(fs.open(pgPassFile))
     )
     val content = Iterator
       .continually(reader.readLine())
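
Password resolution is now injectable: instead of always reading $HOME/.pgpass, callers can point PGTool.apply at any pgpass-format file. A small sketch of the format and the override, mirroring what PGToolTest below does (paths and credentials are illustrative):

import java.nio.file.Files
import org.apache.hadoop.fs.Path

// .pgpass format: hostname:port:database:username:password (database may be *).
val pgPassFile = Files.createTempFile("sketch", ".pgpass")
Files.write(pgPassFile, "localhost:5432:*:test:test".getBytes)

// Passing the file explicitly keeps tests independent of $HOME/.pgpass:
// PGTool(spark, url, tmpPath, pgPassFile = new Path(pgPassFile.toString))
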
@@ -185,7 +187,9 @@ object PGTool extends java.io.Serializable with LazyLogging {
       numPartitions: Int = 8,
       password: String = "",
       reindex: Boolean = false,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      withRetry: Boolean = true,
+      primaryKeys: Seq[String] = Seq.empty
   ): Unit = {
     logger.debug("using CSV strategy")
     try {
@@ -208,7 +212,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
         .mode(org.apache.spark.sql.SaveMode.Overwrite)
         .save(path)

-      outputBulkCsvLow(
+      val success = outputBulkCsvLow(
         spark,
         url,
         table,
@@ -220,6 +224,45 @@ object PGTool extends java.io.Serializable with LazyLogging {
         password,
         bulkLoadBufferSize
       )
+      if (!success) {
+        if (!withRetry) {
+          throw new Exception("Bulk load failed")
+        } else {
+          logger.warn(
+            "Bulk load failed, retrying with filtering existing items"
+          )
+          // try again with filtering the original dataframe with existing items
+          val selectedColumns = if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
+          val existingItems = sqlExecWithResult(
+            spark,
+            url,
+            s"SELECT $selectedColumns FROM $table",
+            password
+          )
+          val existingItemsSet = existingItems.collect().map(_.mkString(",")).toSet
+          val dfWithSelectedColumns = if (primaryKeys.isEmpty) df else df.select(primaryKeys.map(col): _*)
+          val dfFiltered = dfWithSelectedColumns
+            .filter(
+              row =>
+                !existingItemsSet.contains(
+                  row.mkString(",")
+                )
+            )
+          outputBulk(
+            spark,
+            url,
+            table,
+            dfFiltered,
+            path,
+            numPartitions,
+            password,
+            reindex,
+            bulkLoadBufferSize,
+            withRetry = false
+          )
+        }
+      }
+
     } finally {
       if (reindex)
         indexReactivate(url, table, password)
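
The retry collects the table's existing key tuples as comma-joined strings, drops incoming rows whose tuple is already present, and re-runs the bulk load once with withRetry = false so a second failure surfaces as an exception. A toy, self-contained sketch of that filtering step (values are illustrative):

import org.apache.spark.sql.SparkSession

// Toy sketch of the retry filter above, assuming a local SparkSession.
// Keys already in Postgres are collected as comma-joined strings; incoming
// rows whose key tuple matches are dropped before the second COPY attempt.
object RetryFilterSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  val existingKeys = Set("1,1", "2,2") // as returned by SELECT id,id_2 FROM ...

  val incoming = Seq((1, 1), (2, 2), (3, 3)).toDF("id", "id_2")
  val toRetry = incoming.filter(row => !existingKeys.contains(row.mkString(",")))

  toRetry.show() // only the (3, 3) row is re-sent, with withRetry = false
}

Two properties worth noting: the existing keys are collected to the driver, so the approach assumes the key set fits in driver memory; and when primaryKeys is non-empty, the retried dataframe is reduced to the key columns before being re-sent.
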
@@ -235,7 +278,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
       numPartitions: Int = 8,
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {

     // load the csv files from hdfs in parallel
     val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@ object PGTool extends java.io.Serializable with LazyLogging {
       .rdd
       .partitionBy(new ExactPartitioner(numPartitions))

-    rdd.foreachPartition(x => {
+    val statusRdd = rdd.mapPartitions(x => {
       val conn = connOpen(url, password)
-      x.foreach { s =>
-        {
-          val stream: InputStream = FileSystem
-            .get(new Configuration())
-            .open(new Path(s._2))
-            .getWrappedStream
-          val copyManager: CopyManager =
-            new CopyManager(conn.asInstanceOf[BaseConnection])
-          copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
-        }
+      val res = Try {
+        x.map { s =>
+          {
+            val stream: InputStream = FileSystem
+              .get(new Configuration())
+              .open(new Path(s._2))
+              .getWrappedStream
+            val copyManager: CopyManager =
+              new CopyManager(conn.asInstanceOf[BaseConnection])
+            copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
+          }
+        }.toList
       }
       conn.close()
-      x.toIterator
+      res match {
+        case Success(_) => Iterator(true) // Partition succeeded
+        case Failure(error) => {
+          logger.error("Partition output loading failed", error)
+          Iterator(false) // Partition failed
+        }
+      }
     })
+    !statusRdd.collect().contains(false)
   }

   def outputBulkCsvLow(
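
outputBulkFileLow switches from foreachPartition to mapPartitions so each partition can report a success flag instead of failing the whole job on the first COPY error. A self-contained sketch of that status pattern (data and transformation are illustrative):

import scala.util.Try
import org.apache.spark.sql.SparkSession

// Each partition yields exactly one Boolean; the driver ANDs them together.
object PartitionStatusSketch extends App {
  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  val rdd = spark.sparkContext.parallelize(1 to 10, numSlices = 2)

  val statusRdd = rdd.mapPartitions { it =>
    val res = Try(it.map(_ * 2).toList) // .toList forces the lazy iterator inside the Try
    Iterator(res.isSuccess)
  }

  println(!statusRdd.collect().contains(false)) // true only if every partition succeeded
}

The .toList in the committed code serves the same purpose: without forcing the iterator, the COPY calls would run lazily after the Try (and after conn.close()), and their failures would escape the status check.
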
@@ -280,7 +332,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
       extensionPattern: String = ".*.csv",
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {
     val csvSqlCopy =
       s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
     outputBulkFileLow(
@@ -301,7 +353,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
     schema
   }

-  def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexDeactivate(url: String, table: String, password: String): Unit = {
     val schema = getSchema(url)
     val query =
       s"""
@@ -316,7 +368,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
     logger.debug(s"Deactivating indexes from $schema.$table")
   }

-  def indexReactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexReactivate(url: String, table: String, password: String): Unit = {

     val schema = getSchema(url)
     val query =

src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
       ArgumentMatchers.eq("list__entry_cohort360"),
       df.capture(),
       ArgumentMatchers.eq(Some(4)),
+      ArgumentMatchersSugar.*,
       ArgumentMatchersSugar.*
     )
     assertSmallDatasetEquality(df.getValue.asInstanceOf[DataFrame], expectedDf)
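
The added wildcard matcher lines up with the new primaryKeys argument on outputBulk: the mocked call now carries one more parameter, so the verification needs an extra ArgumentMatchersSugar.* to keep matching the widened signature.
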
src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+package fr.aphp.id.eds.requester.cohort.pg
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.testcontainers.containers.PostgreSQLContainer
+import org.scalatest.funsuite.AnyFunSuiteLike
+
+import java.nio.file.{Files, Path}
+
+class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
+  val sparkSession: SparkSession = SparkSession
+    .builder()
+    .master("local[*]")
+    .getOrCreate()
+  private var tempDir: java.nio.file.Path = _
+  private val postgresContainer = new PostgreSQLContainer("postgres:15.3")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    tempDir = Files.createTempDirectory("test-temp-dir")
+    postgresContainer.start()
+    postgresContainer.withPassword("test")
+    postgresContainer.withUsername("test")
+    val pgPassFile = tempDir.resolve(".pgpass")
+    Files.write(pgPassFile, s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes)
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    FileUtils.deleteDirectory(tempDir.toFile)
+    postgresContainer.stop()
+  }
+
+  test("testOutputBulk") {
+    import sparkSession.implicits._
+    val pgUrl = s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}&currentSchema=public"
+    val pgTool = PGTool(sparkSession, pgUrl, tempDir.toString, pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString))
+    val createTableQuery = """
+      CREATE TABLE test_table (
+        id INT PRIMARY KEY,
+        value TEXT,
+        id_2 INT
+      )
+    """
+    pgTool.sqlExec(createTableQuery)
+
+    val insertDataQuery = """
+      INSERT INTO test_table (id, value, id_2) VALUES
+      (1, '1', 1),
+      (2, '2', 2)
+    """
+    pgTool.sqlExec(insertDataQuery)
+    val baseContent = pgTool.sqlExecWithResult("select * from test_table")
+    baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)
+
+    // generate a new dataframe containing 100 elements with 2 columns id and value that will be written to the database
+    val data = sparkSession.range(100).toDF("id").withColumn("value", 'id.cast("string")).withColumn("id_2", col("id"))
+    pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
+    val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
+    updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
+  }
+
+}
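
End to end, the test exercises the retry path: the table is pre-seeded with ids 1 and 2, so the bulk load of ids 0-99 hits the primary-key collision and reports failure; the retry then re-sends only the rows whose (id, id_2) tuple is not already in the table, and the final assertion sees ids 0 through 99. Running the suite requires a Docker daemon, since Testcontainers provisions the postgres:15.3 container on the fly.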
