11 changes: 10 additions & 1 deletion pom.xml
@@ -31,6 +31,7 @@
<hadoop.version>2.6.5</hadoop.version>
<spark-solr.version>4.0.4</spark-solr.version>
<wiremock.version>3.3.1</wiremock.version>
<postgrestest.version>1.20.4</postgrestest.version>
<!-- Sonar -->
<sonar.qualitygate.wait>true</sonar.qualitygate.wait>

@@ -296,6 +297,12 @@
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<version>${postgrestest.version}</version>
<scope>test</scope>
</dependency>



@@ -446,7 +453,9 @@
</execution>
</executions>
<configuration>
<argLine>--add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED</argLine>
<argLine>
--add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED
</argLine>
</configuration>
</plugin>

@@ -50,8 +50,8 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
}

/**
* This loads both a cohort and its definition into postgres and solr
*/
* This loads both a cohort and its definition into postgres and solr
*/
override def updateCohort(cohortId: Long,
cohort: DataFrame,
sourcePopulation: SourcePopulation,
@@ -68,9 +68,9 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
.withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
.withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
.select(F.col("_itemreferenceid"),
F.col("item__reference"),
F.col("_provider"),
F.col("_listid"))
F.col("item__reference"),
F.col("_provider"),
F.col("_listid"))

uploadCohortTableToPG(dataframe)

@@ -172,7 +172,10 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
).toSet == df.columns.toSet,
"cohort dataframe shall have _listid, _provider, _provider and item__reference"
)
pg.outputBulk(cohort_item_table_rw, dfAddHash(df), Some(4))
pg.outputBulk(cohort_item_table_rw,
dfAddHash(df),
Some(4),
primaryKeys = Seq("_listid", "_itemreferenceid", "_provider"))
}

/**
118 changes: 85 additions & 33 deletions src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGTools.scala
@@ -14,6 +14,7 @@ import java.io._
import java.sql._
import java.util.Properties
import java.util.UUID.randomUUID
import scala.util.{Failure, Success, Try}

sealed trait BulkLoadMode

@@ -35,8 +36,8 @@ class PGTool(

private var password: String = ""

def setPassword(pwd: String = ""): PGTool = {
password = PGTool.passwordFromConn(url, pwd)
def setPassword(pwd: String): PGTool = {
password = pwd
this
}

@@ -73,7 +74,8 @@
table: String,
df: Dataset[Row],
numPartitions: Option[Int] = None,
reindex: Boolean = false
reindex: Boolean = false,
primaryKeys: Seq[String] = Seq.empty
): PGTool = {
PGTool.outputBulk(
spark,
@@ -84,7 +86,8 @@
numPartitions.getOrElse(8),
password,
reindex,
bulkLoadBufferSize
bulkLoadBufferSize,
primaryKeys = primaryKeys
)
this
}
@@ -106,45 +109,44 @@ object PGTool extends java.io.Serializable with LazyLogging {
url: String,
tmpPath: String,
bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
): PGTool = {
new PGTool(
spark,
url,
tmpPath + "/spark-postgres-" + randomUUID.toString,
bulkLoadMode,
bulkLoadBufferSize
).setPassword()
).setPassword(passwordFromConn(url, pgPassFile))
}

def connOpen(url: String, password: String = ""): Connection = {
def connOpen(url: String, password: String): Connection = {
val prop = new Properties()
prop.put("driver", "org.postgresql.Driver")
prop.put("password", passwordFromConn(url, password))
prop.put("password", password)
DriverManager.getConnection(url, prop)
}

def passwordFromConn(url: String, password: String): String = {
if (password.nonEmpty) {
return password
}
def passwordFromConn(url: String, pgPassFile: Path): String = {
val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
val pattern(host, port, database, username) = url
dbPassword(host, port, database, username)
dbPassword(host, port, database, username, pgPassFile)
}

private def dbPassword(
hostname: String,
port: String,
database: String,
username: String
username: String,
pgPassFile: Path
): String = {
// Usage: val thatPassWord = dbPassword(hostname, port, database, username, pgPassFile)
// .pgpass file format, hostname:port:database:username:password
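// e.g. a single line such as db.example.com:5432:cohortdb:etl_user:s3cret (hypothetical values)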

val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
val reader = new BufferedReader(
new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
new InputStreamReader(fs.open(pgPassFile))
)
val content = Iterator
.continually(reader.readLine())
@@ -185,7 +187,9 @@ object PGTool extends java.io.Serializable with LazyLogging {
numPartitions: Int = 8,
password: String = "",
reindex: Boolean = false,
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
withRetry: Boolean = true,
primaryKeys: Seq[String] = Seq.empty
): Unit = {
logger.debug("using CSV strategy")
try {
@@ -208,7 +212,7 @@
.mode(org.apache.spark.sql.SaveMode.Overwrite)
.save(path)

outputBulkCsvLow(
val success = outputBulkCsvLow(
spark,
url,
table,
@@ -220,6 +224,45 @@
password,
bulkLoadBufferSize
)
if (!success) {
if (!withRetry) {
throw new Exception("Bulk load failed")
} else {
logger.warn(
"Bulk load failed, retrying with filtering existing items"
)
// try again with filtering the original dataframe with existing items
val selectedColumns = if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
val existingItems = sqlExecWithResult(
spark,
url,
s"SELECT $selectedColumns FROM $table",
password
)
val existingItemsSet = existingItems.collect().map(_.mkString(",")).toSet
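// existing and incoming rows are compared by their comma-joined key values (the full row when no primary keys are given)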
val dfWithSelectedColumns = if (primaryKeys.isEmpty) df else df.select(primaryKeys.map(col): _*)
val dfFiltered = dfWithSelectedColumns
.filter(
row =>
!existingItemsSet.contains(
row.mkString(",")
)
)
outputBulk(
spark,
url,
table,
dfFiltered,
path,
numPartitions,
password,
reindex,
bulkLoadBufferSize,
withRetry = false
)
}
}

} finally {
if (reindex)
indexReactivate(url, table, password)
@@ -235,7 +278,7 @@
numPartitions: Int = 8,
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
): Unit = {
): Boolean = {

// load the csv files from hdfs in parallel
val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@
.rdd
.partitionBy(new ExactPartitioner(numPartitions))

rdd.foreachPartition(x => {
val statusRdd = rdd.mapPartitions(x => {
val conn = connOpen(url, password)
x.foreach { s =>
{
val stream: InputStream = FileSystem
.get(new Configuration())
.open(new Path(s._2))
.getWrappedStream
val copyManager: CopyManager =
new CopyManager(conn.asInstanceOf[BaseConnection])
copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
}
val res = Try {
x.map { s =>
{
val stream: InputStream = FileSystem
.get(new Configuration())
.open(new Path(s._2))
.getWrappedStream
val copyManager: CopyManager =
new CopyManager(conn.asInstanceOf[BaseConnection])
copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
}
}.toList
}
conn.close()
x.toIterator
res match {
case Success(_) => Iterator(true) // Partition succeeded
case Failure(error) => {
logger.error("Partition output loading failed", error)
Iterator(false) // Partition failed
}
}
})
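// collect() triggers the copy on every partition; the load succeeds only if no partition reported a failure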
!statusRdd.collect().contains(false)
}

def outputBulkCsvLow(
@@ -280,7 +332,7 @@
extensionPattern: String = ".*.csv",
password: String = "",
bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
): Unit = {
): Boolean = {
val csvSqlCopy =
s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
outputBulkFileLow(
@@ -301,7 +353,7 @@
schema
}

def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
def indexDeactivate(url: String, table: String, password: String): Unit = {
val schema = getSchema(url)
val query =
s"""
@@ -316,7 +368,7 @@
logger.debug(s"Deactivating indexes from $schema.$table")
}

def indexReactivate(url: String, table: String, password: String = ""): Unit = {
def indexReactivate(url: String, table: String, password: String): Unit = {

val schema = getSchema(url)
val query =
@@ -114,6 +114,7 @@ class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
ArgumentMatchers.eq("list__entry_cohort360"),
df.capture(),
ArgumentMatchers.eq(Some(4)),
ArgumentMatchersSugar.*,
ArgumentMatchersSugar.*
)
assertSmallDatasetEquality(df.getValue.asInstanceOf[DataFrame], expectedDf)
68 changes: 68 additions & 0 deletions src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGToolTest.scala
@@ -0,0 +1,68 @@
package fr.aphp.id.eds.requester.cohort.pg

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.scalatest.BeforeAndAfterAll
import org.scalatest.matchers.should.Matchers
import org.testcontainers.containers.PostgreSQLContainer
import org.scalatest.funsuite.AnyFunSuiteLike

import java.nio.file.Files

class PGToolTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterAll {
val sparkSession: SparkSession = SparkSession
.builder()
.master("local[*]")
.getOrCreate()
private var tempDir: java.nio.file.Path = _
private val postgresContainer = new PostgreSQLContainer("postgres:15.3")

override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Files.createTempDirectory("test-temp-dir")
postgresContainer.withPassword("test")
postgresContainer.withUsername("test")
postgresContainer.start()
val pgPassFile = tempDir.resolve(".pgpass")
Files.write(pgPassFile, s"${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}:*:${postgresContainer.getUsername}:${postgresContainer.getPassword}".getBytes)
}

override def afterAll(): Unit = {
super.afterAll()
FileUtils.deleteDirectory(tempDir.toFile)
postgresContainer.stop()
}

test("testOutputBulk") {
import sparkSession.implicits._
val pgUrl = s"jdbc:postgresql://${postgresContainer.getHost}:${postgresContainer.getFirstMappedPort}/${postgresContainer.getDatabaseName}?user=${postgresContainer.getUsername}&currentSchema=public"
val pgTool = PGTool(sparkSession, pgUrl, tempDir.toString, pgPassFile = new org.apache.hadoop.fs.Path(tempDir.resolve(".pgpass").toString))
val createTableQuery = """
CREATE TABLE test_table (
id INT PRIMARY KEY,
value TEXT,
id_2 INT
)
"""
pgTool.sqlExec(createTableQuery)

val insertDataQuery = """
INSERT INTO test_table (id, value, id_2) VALUES
(1, '1', 1),
(2, '2', 2)
"""
pgTool.sqlExec(insertDataQuery)
val baseContent = pgTool.sqlExecWithResult("select * from test_table")
baseContent.collect().map(_.getInt(0)) should contain theSameElementsAs Array(1, 2)

// generate a new dataframe containing 100 elements with 2 columns id and value that will be written to the database
val data = sparkSession.range(100).toDF("id").withColumn("value", 'id.cast("string")).withColumn("id_2", col("id"))
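// ids 1 and 2 already exist, so the first COPY attempt fails on the primary key and outputBulk retries after filtering out the rows already present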
pgTool.outputBulk("test_table", data, primaryKeys = Seq("id", "id_2"))
val updatedContent = pgTool.sqlExecWithResult("select * from test_table")
updatedContent.collect().map(_.getInt(0)) should contain theSameElementsAs (0 until 100)
}

}