@@ -14,6 +14,7 @@ import java.io._
 import java.sql._
 import java.util.Properties
 import java.util.UUID.randomUUID
+import scala.util.{Failure, Success, Try}
 
 sealed trait BulkLoadMode
 
@@ -35,8 +36,8 @@ class PGTool(
 
   private var password: String = ""
 
-  def setPassword(pwd: String = ""): PGTool = {
-    password = PGTool.passwordFromConn(url, pwd)
+  def setPassword(pwd: String): PGTool = {
+    password = pwd
     this
   }
 
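With this change `setPassword` no longer performs the `.pgpass` lookup itself; the caller passes a fully resolved password. A minimal sketch of the new call site (URL, path, and password are illustrative):

```scala
// Hypothetical usage: resolve the password up front, then hand it to the tool.
// Before this commit, setPassword() silently read ~/.pgpass on its own.
val pg = PGTool(spark, "jdbc:postgresql://localhost:5432/db?user=alice", "/tmp/spark-pg")
  .setPassword("s3cret") // explicit; no hidden file access here anymore
```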
@@ -73,7 +74,8 @@ class PGTool(
       table: String,
       df: Dataset[Row],
       numPartitions: Option[Int] = None,
-      reindex: Boolean = false
+      reindex: Boolean = false,
+      primaryKeys: Seq[String] = Seq.empty
   ): PGTool = {
     PGTool.outputBulk(
       spark,
@@ -84,7 +86,8 @@ class PGTool(
       numPartitions.getOrElse(8),
       password,
       reindex,
-      bulkLoadBufferSize
+      bulkLoadBufferSize,
+      primaryKeys = primaryKeys
     )
     this
   }
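The new `primaryKeys` parameter is threaded through to the bulk loader so that a failed load can be retried without re-inserting rows that already made it in. A hedged usage sketch (the table name and dataframe are invented):

```scala
// Sketch: rows are identified by the assumed key column "id" on retry,
// so a partially successful COPY is not duplicated when re-run.
val eventsDf = spark.range(10).toDF("id") // stand-in Dataset[Row]
pg.outputBulk(
  table = "events",          // hypothetical target table
  df = eventsDf,
  numPartitions = Some(4),
  primaryKeys = Seq("id")
)
```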
@@ -106,45 +109,44 @@ object PGTool extends java.io.Serializable with LazyLogging {
       url: String,
       tmpPath: String,
       bulkLoadMode: BulkLoadMode = defaultBulkLoadStrategy,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      pgPassFile: Path = new Path(scala.sys.env("HOME"), ".pgpass")
   ): PGTool = {
     new PGTool(
       spark,
       url,
       tmpPath + "/spark-postgres-" + randomUUID.toString,
       bulkLoadMode,
       bulkLoadBufferSize
-    ).setPassword()
+    ).setPassword(passwordFromConn(url, pgPassFile))
   }
 
-  def connOpen(url: String, password: String = ""): Connection = {
+  def connOpen(url: String, password: String): Connection = {
     val prop = new Properties()
     prop.put("driver", "org.postgresql.Driver")
-    prop.put("password", passwordFromConn(url, password))
+    prop.put("password", password)
     DriverManager.getConnection(url, prop)
   }
 
-  def passwordFromConn(url: String, password: String): String = {
-    if (password.nonEmpty) {
-      return password
-    }
+  def passwordFromConn(url: String, pgPassFile: Path): String = {
     val pattern = "jdbc:postgresql://(.*):(\\d+)/(\\w+)[?]user=(\\w+).*".r
     val pattern(host, port, database, username) = url
-    dbPassword(host, port, database, username)
+    dbPassword(host, port, database, username, pgPassFile)
   }
 
   private def dbPassword(
       hostname: String,
       port: String,
       database: String,
-      username: String
+      username: String,
+      pgPassFile: Path
   ): String = {
     // Usage: val thatPassWord = dbPassword(hostname,port,database,username)
     // .pgpass file format, hostname:port:database:username:password
 
     val fs = FileSystem.get(new java.net.URI("file:///"), new Configuration)
     val reader = new BufferedReader(
-      new InputStreamReader(fs.open(new Path(scala.sys.env("HOME"), ".pgpass")))
+      new InputStreamReader(fs.open(pgPassFile))
     )
     val content = Iterator
       .continually(reader.readLine())
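Together these hunks make password resolution explicit: the `.pgpass` location becomes a parameter whose default preserves the old `$HOME/.pgpass` behavior, and `connOpen` trusts whatever password it is given. A hedged sketch of the resulting flow (URL and path are illustrative):

```scala
import org.apache.hadoop.fs.Path

// Sketch: resolve the password once from an explicit .pgpass-style file,
// then reuse it for every connection. The URL must match the parsing regex:
// jdbc:postgresql://<host>:<port>/<database>?user=<username>...
val url      = "jdbc:postgresql://db.example.com:5432/mydb?user=alice" // hypothetical
val pgPass   = new Path("/secrets/pgpass")                             // hypothetical path
val password = PGTool.passwordFromConn(url, pgPass)
val conn     = PGTool.connOpen(url, password) // no implicit ~/.pgpass lookup
```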
@@ -185,7 +187,9 @@ object PGTool extends java.io.Serializable with LazyLogging {
       numPartitions: Int = 8,
       password: String = "",
       reindex: Boolean = false,
-      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
+      bulkLoadBufferSize: Int = defaultBulkLoadBufferSize,
+      withRetry: Boolean = true,
+      primaryKeys: Seq[String] = Seq.empty
   ): Unit = {
     logger.debug("using CSV strategy")
     try {
@@ -208,7 +212,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
         .mode(org.apache.spark.sql.SaveMode.Overwrite)
         .save(path)
 
-      outputBulkCsvLow(
+      val success = outputBulkCsvLow(
         spark,
         url,
         table,
@@ -220,6 +224,45 @@ object PGTool extends java.io.Serializable with LazyLogging {
         password,
         bulkLoadBufferSize
       )
+      if (!success) {
+        if (!withRetry) {
+          throw new Exception("Bulk load failed")
+        } else {
+          logger.warn(
+            "Bulk load failed, retrying with filtering existing items"
+          )
+          // retry after filtering out rows that already exist in the target table
+          val selectedColumns =
+            if (primaryKeys.isEmpty) "*" else primaryKeys.map(sanP).mkString(",")
+          val existingItems = sqlExecWithResult(
+            spark,
+            url,
+            s"SELECT $selectedColumns FROM $table",
+            password
+          )
+          val existingItemsSet =
+            existingItems.collect().map(_.mkString(",")).toSet
+          val dfWithSelectedColumns =
+            if (primaryKeys.isEmpty) df else df.select(primaryKeys.map(col): _*)
+          val dfFiltered = dfWithSelectedColumns
+            .filter(row => !existingItemsSet.contains(row.mkString(",")))
+          outputBulk(
+            spark,
+            url,
+            table,
+            dfFiltered,
+            path,
+            numPartitions,
+            password,
+            reindex,
+            bulkLoadBufferSize,
+            withRetry = false
+          )
+        }
+      }
+
     } finally {
       if (reindex)
         indexReactivate(url, table, password)
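The retry path compares rows by joining the selected key columns into a comma-separated string and checking membership in the set of already-loaded rows, i.e. a best-effort dedup rather than a keyed anti-join. A small illustration of the matching rule (values are made up):

```scala
import org.apache.spark.sql.Row

// Sketch of the dedup rule above: a row is skipped on retry when the
// comma-joined rendering of its key columns is already present in the table.
val existing = Set("1,alice", "2,bob") // hypothetical SELECT id,name FROM events
val candidate = Row(2, "bob")
val alreadyLoaded = existing.contains(candidate.mkString(",")) // true -> filtered out
```

Two caveats worth noting: values that themselves contain commas can collide under this encoding, and the whole key set is collected to the driver, so this suits modest tables. Also, reading the diff, when `primaryKeys` is non-empty the retry re-loads the projected key columns only (`dfFiltered` is built from `dfWithSelectedColumns`), which callers should keep in mind. The recursive call passes `withRetry = false`, guaranteeing at most one retry.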
@@ -235,7 +278,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
       numPartitions: Int = 8,
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {
 
     // load the csv files from hdfs in parallel
     val fs = FileSystem.get(new Configuration())
@@ -251,22 +294,31 @@ object PGTool extends java.io.Serializable with LazyLogging {
       .rdd
       .partitionBy(new ExactPartitioner(numPartitions))
 
-    rdd.foreachPartition(x => {
+    val statusRdd = rdd.mapPartitions(x => {
       val conn = connOpen(url, password)
-      x.foreach { s =>
-        {
-          val stream: InputStream = FileSystem
-            .get(new Configuration())
-            .open(new Path(s._2))
-            .getWrappedStream
-          val copyManager: CopyManager =
-            new CopyManager(conn.asInstanceOf[BaseConnection])
-          copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
-        }
+      val res = Try {
+        x.map { s =>
+          {
+            val stream: InputStream = FileSystem
+              .get(new Configuration())
+              .open(new Path(s._2))
+              .getWrappedStream
+            val copyManager: CopyManager =
+              new CopyManager(conn.asInstanceOf[BaseConnection])
+            copyManager.copyIn(sqlCopy, stream, bulkLoadBufferSize)
+          }
+        }.toList
       }
       conn.close()
-      x.toIterator
+      res match {
+        case Success(_) => Iterator(true) // Partition succeeded
+        case Failure(error) => {
+          logger.error("Partition output loading failed", error)
+          Iterator(false) // Partition failed
+        }
+      }
     })
+    !statusRdd.collect().contains(false)
   }
 
   def outputBulkCsvLow(
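Switching from `foreachPartition` to `mapPartitions` lets every partition report a success flag that the driver folds into one `Boolean`; the `.toList` inside the `Try` forces the lazy iterator so a failing `copyIn` actually surfaces there rather than being silently skipped. A standalone sketch of the same pattern (the RDD and the per-element work are invented; `sc` stands in for an available `SparkContext`):

```scala
import scala.util.{Failure, Success, Try}

// Sketch: each partition emits exactly one Boolean; the driver ANDs them.
val status = sc.parallelize(1 to 100, 4).mapPartitions { elems =>
  Try(elems.map(n => n * 2).toList) match { // .toList forces the lazy work
    case Success(_) => Iterator(true)
    case Failure(_) => Iterator(false)
  }
}
val allSucceeded = !status.collect().contains(false) // same reduction as in the diff
```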
@@ -280,7 +332,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
       extensionPattern: String = ".*.csv",
       password: String = "",
       bulkLoadBufferSize: Int = defaultBulkLoadBufferSize
-  ): Unit = {
+  ): Boolean = {
     val csvSqlCopy =
       s"""COPY "$table" ($columns) FROM STDIN WITH CSV DELIMITER '$delimiter' NULL '' ESCAPE '"' QUOTE '"' """
     outputBulkFileLow(
@@ -301,7 +353,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
     schema
   }
 
-  def indexDeactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexDeactivate(url: String, table: String, password: String): Unit = {
     val schema = getSchema(url)
     val query =
       s"""
@@ -316,7 +368,7 @@ object PGTool extends java.io.Serializable with LazyLogging {
     logger.debug(s"Deactivating indexes from $schema.$table")
   }
 
-  def indexReactivate(url: String, table: String, password: String = ""): Unit = {
+  def indexReactivate(url: String, table: String, password: String): Unit = {
 
     val schema = getSchema(url)
     val query =
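An end-to-end sketch of the revised API, putting the explicit pgpass file, the key-based retry, and index handling together (all names and paths are illustrative; `spark` is an assumed `SparkSession`):

```scala
import org.apache.hadoop.fs.Path

// Sketch: construct the tool with an explicit pgpass file, then bulk-load
// with reindex (deactivates/reactivates indexes around the COPY) and a retry key.
val eventsDf = spark.range(10).toDF("id") // stand-in Dataset[Row]
val pg = PGTool(
  spark,
  "jdbc:postgresql://db.example.com:5432/mydb?user=alice", // hypothetical URL
  "/tmp/bulk",                                             // scratch path for CSV parts
  pgPassFile = new Path("/secrets/pgpass")                 // hypothetical
)
pg.outputBulk("events", eventsDf, Some(8), reindex = true, primaryKeys = Seq("id"))
```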