Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions database/src/main/scala/no/ndla/database/TableIdType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Part of NDLA database
* Copyright (C) 2026 NDLA
*
* See LICENSE
*
*/

package no.ndla.database

import scalikejdbc.*

import java.util.UUID as JavaUUID

sealed trait TableIdType {
type ScalaType
def zeroValueScala: ScalaType
def zeroValueSql: SQLSyntax
def fromResultSet(rs: WrappedResultSet): ScalaType
}

object TableIdType {
case object Bigint extends TableIdType {
override type ScalaType = Long
override def zeroValueSql: SQLSyntax = sqls"0::bigint"
override def zeroValueScala: ScalaType = 0L
override def fromResultSet(rs: WrappedResultSet): ScalaType = rs.long("id")
}

case object UUID extends TableIdType {
override type ScalaType = JavaUUID
override def zeroValueSql: SQLSyntax = sqls"'00000000-0000-0000-0000-000000000000'::uuid"
override def zeroValueScala: ScalaType = JavaUUID(0L, 0L)
override def fromResultSet(rs: WrappedResultSet): ScalaType = JavaUUID.fromString(rs.string("id"))
}
}
44 changes: 23 additions & 21 deletions database/src/main/scala/no/ndla/database/TableMigration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,43 @@ package no.ndla.database
import org.flywaydb.core.api.migration.{BaseJavaMigration, Context}
import scalikejdbc.*

/** Base class for Scala-based migrations.
*
* **NOTE:** If the table you are migrating does not use `bigint` as the ID type, you must override `tableIdType` to
* return the correct [[TableIdType]].
*/
abstract class TableMigration[ROW_DATA] extends BaseJavaMigration {
val tableName: String
lazy val whereClause: SQLSyntax
val chunkSize: Int = 1000
def extractRowData(rs: WrappedResultSet): ROW_DATA
def updateRow(rowData: ROW_DATA)(implicit session: DBSession): Int
lazy val tableNameSQL: SQLSyntax = SQLSyntax.createUnsafely(tableName)

private def countAllRows(implicit session: DBSession): Option[Long] = {
sql"select count(*) from $tableNameSQL where $whereClause".map(rs => rs.long("count")).single()
}

private def allRows(offset: Long)(implicit session: DBSession): Seq[ROW_DATA] = {
sql"select * from $tableNameSQL where $whereClause order by id limit $chunkSize offset $offset"
.map(rs => extractRowData(rs))
.list()
}
val tableIdType: TableIdType = TableIdType.Bigint

override def migrate(context: Context): Unit = DB(context.getConnection)
.autoClose(false)
.withinTx { session =>
migrateRows(using session)
}

protected def migrateRows(implicit session: DBSession): Unit = {
val count = countAllRows.get
var numPagesLeft = (count / chunkSize) + 1
var offset = 0L

while (numPagesLeft > 0) {
allRows(offset * chunkSize).map { rowData =>
updateRow(rowData)
}: Unit
numPagesLeft -= 1
offset += 1
protected def migrateRows(implicit session: DBSession): Unit = Iterator
.unfold(tableIdType.zeroValueScala) { lastId =>
getRowChunk(lastId) match {
case Nil => None
case chunk => Some((chunk, chunk.last._1))
}
}
.takeWhile(_.nonEmpty)
.foreach { chunk =>
chunk.foreach((_, rowData) => updateRow(rowData))
}

private def getRowChunk(
lastId: tableIdType.ScalaType
)(implicit session: DBSession): Seq[(tableIdType.ScalaType, ROW_DATA)] = {
sql"select * from $tableNameSQL where $whereClause and id > $lastId order by id limit $chunkSize"
.map(rs => (tableIdType.fromResultSet(rs), extractRowData(rs)))
.list()
}
}
157 changes: 157 additions & 0 deletions database/src/test/scala/no/ndla/database/TableMigrationTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Part of NDLA database
* Copyright (C) 2026 NDLA
*
* See LICENSE
*
*/

package no.ndla.database

import no.ndla.scalatestsuite.{DatabaseIntegrationSuite, UnitTestSuite}
import org.flywaydb.core.Flyway
import scalikejdbc.*

import java.util.UUID

class TableMigrationTest extends DatabaseIntegrationSuite, UnitTestSuite, TestEnvironment {
val dataSource: DataSource = testDataSource.get
val schema: String = "testschema"
val schemaSql: SQLSyntax = SQLSyntax.createUnsafely(schema)
val intTableName: String = "test"
val intTableNameSql: SQLSyntax = SQLSyntax.createUnsafely(intTableName)
val uuidTableName: String = "test2"
val uuidTableNameSql: SQLSyntax = SQLSyntax.createUnsafely(uuidTableName)

override def beforeAll(): Unit = {
super.beforeAll()

dataSource.connectToDatabase()
}

override def beforeEach(): Unit = {
super.beforeEach()

DB.autoCommit { implicit session =>
sql"""
drop schema if exists $schemaSql cascade;
create schema $schemaSql;
create table $intTableNameSql (id int primary key, data text);
create table $uuidTableNameSql (id uuid primary key, data text);""".execute()
}
}

private def insertIdsFromRange(range: Range): Unit = {
DB.autoCommit { implicit session =>
val sqlInsertParts = range.map(id => sqls"insert into $intTableNameSql (id, data) values ($id, ${"row" + id})")
val joinedSqlInsert = SQLSyntax.join(sqlInsertParts, sqls";")
sql"$joinedSqlInsert".execute()
}
}

private def runMigration[A](migration: TableMigration[A]): Unit = {
val flyway = Flyway
.configure()
.javaMigrations(migration)
.dataSource(dataSource)
.schemas(schema)
.baselineVersion("00")
.baselineOnMigrate(true)
.load()

flyway.migrate()
}

test("that all rows are updated with no where clause") {
insertIdsFromRange(1 to 50)

class V01__Foo extends TableMigration[Long] {
override val tableName: String = intTableName
override lazy val whereClause: SQLSyntax = sqls"true"
override val chunkSize: Int = 10

override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")

override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
sql"update $intTableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
}
}

runMigration(V01__Foo())

DB.readOnly { implicit session =>
val updatedRowsCount = sql"select count(*) from $intTableNameSql where data like 'updated_row%'"
.map(_.int(1))
.single()
.get
updatedRowsCount should be(50)
}
}

test("that keyset pagination works correctly") {
val step = 3
insertIdsFromRange(100 to 1 by -step)
val maxIdToUpdate = 50
val expectedUpdateCount = (maxIdToUpdate / step) + 1

class V01__Foo extends TableMigration[Long] {
override val tableName: String = intTableName
override lazy val whereClause: SQLSyntax = sqls"id < $maxIdToUpdate"
override val chunkSize: Int = 10

override def extractRowData(rs: WrappedResultSet): Long = rs.long("id")

override def updateRow(rowData: Long)(implicit session: DBSession): Int = {
sql"update $intTableNameSql set data = ${"updated_row" + rowData} where id = $rowData".update()
}
}

runMigration(V01__Foo())

DB.readOnly { implicit session =>
val updatedIds = sql"select id from $intTableNameSql where data like 'updated_row%' order by id"
.map(_.int("id"))
.list()
all(updatedIds) should be < maxIdToUpdate
updatedIds.length should be(expectedUpdateCount)
}
}

test("that migration works with UUIDs as primary keys") {
val numRows = 50
DB.autoCommit { implicit session =>
val sqlInsertParts = (
1 to numRows
).map { i =>
val uuid = UUID.randomUUID()
sqls"insert into $uuidTableNameSql (id, data) values ($uuid, ${"row" + i})"
}
val joinedSqlInsert = SQLSyntax.join(sqlInsertParts, sqls";")
sql"$joinedSqlInsert".execute()
}

class V01__Foo extends TableMigration[(UUID, String)] {
override val tableName: String = uuidTableName
override lazy val whereClause: SQLSyntax = sqls"true"
override val chunkSize: Int = 10
override val tableIdType: TableIdType = TableIdType.UUID

override def extractRowData(rs: WrappedResultSet): (UUID, String) =
(UUID.fromString(rs.string("id")), rs.string("data"))

override def updateRow(rowData: (UUID, String))(implicit session: DBSession): Int = {
sql"update $uuidTableNameSql set data = ${"updated_row " + rowData._2} where id = ${rowData._1}".update()
}
}

runMigration(V01__Foo())

DB.readOnly { implicit session =>
val updatedRowsCount = sql"select count(*) from $uuidTableNameSql where data like 'updated_row %'"
.map(_.int(1))
.single()
.get
updatedRowsCount should be(numRows)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@

package no.ndla.myndlaapi.db.migrationwithdependencies

import no.ndla.database.TableMigration
import no.ndla.database.{TableIdType, TableMigration}
import no.ndla.myndlaapi.integration.TaxonomyApiClient
import scalikejdbc.{DBSession, WrappedResultSet, scalikejdbcSQLInterpolationImplicitDef}

import java.util.UUID

class V16__MigrateResourcePaths(using taxonomyApiClient: TaxonomyApiClient) extends TableMigration[ResourceRow] {
override val tableName: String = "resources"
override lazy val whereClause: scalikejdbc.SQLSyntax = sqls"path is not null"
override val chunkSize: Int = 1000
override val tableName: String = "resources"
override lazy val whereClause: scalikejdbc.SQLSyntax = sqls"path is not null"
override val chunkSize: Int = 1000
override val tableIdType: TableIdType = TableIdType.UUID

override def extractRowData(rs: WrappedResultSet): ResourceRow =
ResourceRow(UUID.fromString(rs.string("id")), rs.string("resource_type"), rs.string("path"))
override def updateRow(rowData: ResourceRow)(implicit session: DBSession): Int = {
Expand Down