From bcded842fdaa877c311750d524deaad00adeb20d Mon Sep 17 00:00:00 2001
From: Paul Bui-Quang
Date: Mon, 3 Mar 2025 16:32:45 +0100
Subject: [PATCH] feat(create): add new createDiff mode

---
 README.md                                     |   5 +-
 .../aphp/id/eds/requester/CreateQuery.scala   |  69 ++-
 .../eds/requester/cohort/CohortCreation.scala |   6 +
 .../cohort/fhir/FhirCohortCreation.scala      |  61 ++-
 .../cohort/pg/PGCohortCreation.scala          |  69 ++-
 .../requester/jobs/SparkJobParameter.scala    |   1 +
 .../eds/requester/server/JobController.scala  |   1 +
 .../id/eds/requester/CreateQueryTest.scala    | 145 +++++-
 .../cohort/fhir/FhirCohortCreationTest.scala  | 429 ++++++++++++++++++
 .../cohort/pg/PGCohortCreationTest.scala      | 147 +++++-
 10 files changed, 872 insertions(+), 61 deletions(-)
 create mode 100644 src/test/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreationTest.scala

diff --git a/README.md b/README.md
index a70e7a3..12bb0a1 100755
--- a/README.md
+++ b/README.md
@@ -65,7 +65,9 @@ The job query format is as follows :
     // or "ratio", this will activate a detailed count of final matched patients per criteria
     "details": "",
    // optional sampling ratio value between 0.0 and 1.0 to limit the number of patients of the cohort to create (it can be used to sample an existing cohort)
-    "sampling": ""
+    "sampling": "",
+    // cohort id to use as the base for the "create_diff" mode (required in that mode)
+    "baseCohortId": ""
  },
  "callbackUrl": ""  // optional callback url to retrieve the result
}

@@ -75,6 +77,7 @@
 with `mode` being one of the following values:
 - `count` : Return the number of patients that match the criteria of the `cohortDefinitionSyntax`
 - `create`: Create a cohort of patients that match the criteria of the `cohortDefinitionSyntax`
+- `create_diff`: Create a change list relative to a base cohort (referenced by `baseCohortId` in `modeOptions`): patients matching the `cohortDefinitionSyntax` criteria that are not in the base cohort are recorded as additions, and base-cohort patients that no longer match are recorded as deletions
 
 and `cohortDefinitionSyntax` being a JSON string that represents the criteria described in the [query format section](#query-format).

diff --git a/src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala b/src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala
index 5a415f1..c27d305 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala
@@ -8,12 +8,18 @@ import fr.aphp.id.eds.requester.tools.JobUtils.addEmptyGroup
 import fr.aphp.id.eds.requester.tools.{JobUtils, JobUtilsService, StageDetails}
 import org.apache.log4j.Logger
 import org.apache.spark.sql.{SparkSession, functions => F}
+import org.hl7.fhir.r4.model.ListResource.ListMode
 
 object CreateOptions extends Enumeration {
   type CreateOptions = String
   val sampling = "sampling"
 }
 
+object CreateDiffOptions extends Enumeration {
+  type CreateDiffOptions = String
+  val baseCohortId = "baseCohortId"
+}
+
 case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
                        jobUtilsService: JobUtilsService = JobUtils)
     extends JobBase {
@@ -36,10 +42,13 @@
       runtime: JobEnv,
       data: SparkJobParameter
   ): JobBaseResult = {
-    implicit val (request, criterionTagsMap, omopTools, resourceResolver, cacheEnabled) =
+    implicit val (request, criterionTagsMap, omopTools, resourceResolver, cacheEnabled) = {
       jobUtilsService.initSparkJobRequest(logger, spark, runtime, data)
+    }
+    implicit val sparkSession: SparkSession = spark
 
     validateRequestOrThrow(request)
+    validateModeOptionsOrThrow(data)
 
     // Init values here because we are in an object (i.e a singleton) and not a class
     var status: String = ""
@@ -72,12 +81,13 @@
           .filter(c => cohort.columns.contains(c))
           .map(c => F.col(c)): _*)
       .dropDuplicates()
+      .withColumn("deleted", F.lit(false))
 
     if (data.modeOptions.contains(CreateOptions.sampling)) {
       val sampling = data.modeOptions(CreateOptions.sampling).toDouble
       // https://stackoverflow.com/questions/37416825/dataframe-sample-in-apache-spark-scala#comment62349780_37418684
       // to be sure to have the right number of rows
-      cohort = cohort.sample(sampling+0.1).limit((sampling * cohort.count()).round.toInt)
+      cohort = cohort.sample(sampling + 0.1).limit((sampling * cohort.count()).round.toInt)
     }
     cohort.cache()
     count = cohort.count()
@@ -93,8 +103,17 @@
           data.cohortDefinitionSyntax,
           data.ownerEntityId,
           request.resourceType,
+          if (data.mode == JobType.createDiff && data.modeOptions.contains(
+                CreateDiffOptions.baseCohortId))
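+            // In create_diff mode the new cohort is registered against its base cohort and
+            // stored as a list in "changes" mode (see CohortCreation.createCohort below).
+            // Illustrative job parameters: mode = JobType.createDiff,
+            // modeOptions = Map(CreateDiffOptions.baseCohortId -> "1234")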
+            Some(data.modeOptions(CreateDiffOptions.baseCohortId).toLong)
+          else None,
+          if (data.mode == JobType.createDiff) {
+            ListMode.CHANGES
+          } else {
+            ListMode.SNAPSHOT
+          },
           count
-        ))
+          ))
       .getOrElse(-1L)
     }
     // get a new cohortId
@@ -107,12 +126,34 @@
 
     // upload into pg and solr
     if (omopTools.isDefined) {
+      val cohortToUpload =
+        if (data.mode == JobType.createDiff && data.modeOptions.contains(
+              CreateDiffOptions.baseCohortId)) {
+          val baseCohortItems =
+            omopTools.get.readCohortEntries(data.modeOptions(CreateDiffOptions.baseCohortId).toLong)
+          baseCohortItems
+            .join(cohort,
+                  baseCohortItems("_itemreferenceid") === cohort(ResultColumn.SUBJECT),
+                  "full_outer")
+            .filter(
+              baseCohortItems("_itemreferenceid").isNull || F
+                .col(ResultColumn.SUBJECT)
+                .isNull)
+            .select(
+              F.coalesce(baseCohortItems("_itemreferenceid"), cohort(ResultColumn.SUBJECT))
+                .as(ResultColumn.SUBJECT),
+              F.when(cohort(ResultColumn.SUBJECT).isNull, true).otherwise(false).as("deleted")
+            )
+        } else {
+          cohort
+        }
+
       omopTools.get.updateCohort(
         cohortDefinitionId,
-        cohort,
+        cohortToUpload,
         completeRequest.sourcePopulation,
         count,
-        cohortSizeBiggerThanLimit,
+        cohortSizeBiggerThanLimit || data.mode == JobType.createDiff,
         request.resourceType
       )
     }
@@ -120,6 +161,24 @@
     getCreationResult(cohortDefinitionId, count, status)
   }
 
+  private def validateModeOptionsOrThrow(data: SparkJobParameter): Unit = {
+    val modeOptions = data.modeOptions
+    if (data.mode == JobType.createDiff) {
+      if (modeOptions.contains(CreateOptions.sampling)) {
+        throw new RuntimeException("Sampling cannot be used with the create_diff mode")
+      }
+      if (!modeOptions.contains(CreateDiffOptions.baseCohortId)) {
+        throw new RuntimeException("baseCohortId is required for the create_diff mode")
+      }
+    }
+    if (modeOptions.contains(CreateOptions.sampling)) {
+      val sampling = modeOptions(CreateOptions.sampling).toDouble
+      if (sampling <= 0 || sampling > 1) {
+        throw new RuntimeException("Sampling value must be greater than 0 and at most 1")
+      }
+    }
+  }
+
   private def getCreationResult(cohortDefinitionId: Long,
                                 count: Long,
                                 status: String): JobBaseResult = {
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/CohortCreation.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/CohortCreation.scala
index 953e72c..027b1e1 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/CohortCreation.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/CohortCreation.scala
@@ -7,6 +7,8 @@ import fr.aphp.id.eds.requester.query.model.SourcePopulation
 import fr.aphp.id.eds.requester.query.resolver.rest.DefaultRestFhirClient
 import fr.aphp.id.eds.requester.{AppConfig, FhirServerConfig, PGConfig}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.hl7.fhir.r4.model.ListResource.ListMode
+
 
 trait CohortCreation {
 
@@ -24,6 +26,8 @@
                    cohortDefinitionSyntax: String,
                    ownerEntityId: String,
                    resourceType: String,
+                   baseCohortId: Option[Long],
+                   mode: ListMode,
                    size: Long): Long
 
   def updateCohort(cohortId: Long,
@@ -33,6 +37,8 @@
                    delayCohortCreation: Boolean,
                    resourceType: String): Unit
 
+  def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame
+
 }
 
 object CohortCreation {
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreation.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreation.scala
index 8fa9194..16936ef 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreation.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreation.scala
@@ -1,13 +1,15 @@
 package fr.aphp.id.eds.requester.cohort.fhir
 
+import ca.uhn.fhir.rest.api.{SortOrderEnum, SortSpec}
 import fr.aphp.id.eds.requester.ResultColumn
 import fr.aphp.id.eds.requester.cohort.CohortCreation
 import fr.aphp.id.eds.requester.query.model.SourcePopulation
 import fr.aphp.id.eds.requester.query.resolver.rest.RestFhirClient
-import org.apache.spark.sql.{DataFrame, Row}
-import org.hl7.fhir.r4.model.{ListResource, Reference}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.hl7.fhir.r4.model.ListResource.ListMode
+import org.hl7.fhir.r4.model.{Bundle, Identifier, ListResource, Reference}
 
-import scala.jdk.CollectionConverters.seqAsJavaListConverter
+import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}
 
 class FhirCohortCreation(restFhirClient: RestFhirClient) extends CohortCreation {
 
@@ -25,10 +27,19 @@
                             cohortDefinitionSyntax: String,
                             ownerEntityId: String,
                             resourceType: String,
+                            baseCohortId: Option[Long],
+                            mode: ListMode,
                             size: Long): Long = {
     val list = new ListResource()
     list.setTitle(cohortDefinitionName)
-    restFhirClient.getClient.create().resource(list).execute().getId.getIdPartAsLong
+    list.setMode(mode)
+    if (baseCohortId.isDefined) {
+      list.setIdentifier(
+        List(new Identifier().setValue(baseCohortId.get.toString)).asJava
+      )
+    }
+    val fhirCreatedResource = restFhirClient.getClient.create().resource(list).execute()
+    fhirCreatedResource.getResource.getIdElement.getIdPartAsLong
   }
 
   override def updateCohort(cohortId: Long,
@@ -47,12 +58,50 @@
     restFhirClient.getClient.update().resource(list).execute()
   }
 
+  override def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame = {
+    val baseList = restFhirClient.getClient
+      .read()
+      .resource(classOf[ListResource])
+      .withId(cohortId.toString)
+      .execute()
+      .getEntry
+
+    val diffListsResults: Bundle = restFhirClient.getClient
+      .search()
+      .forResource(classOf[ListResource])
+      .where(ListResource.IDENTIFIER.exactly().code(cohortId.toString))
+      .sort(new SortSpec("date", SortOrderEnum.ASC))
+      .execute()
+    val diffLists = diffListsResults.getEntry.asScala
+      .map(_.getResource.asInstanceOf[ListResource])
+      .filter(l => l.hasMode && l.getMode.equals(ListMode.CHANGES))
+
+    val diffEntries = diffLists.flatMap(_.getEntry.asScala)
+    val result = diffEntries.foldLeft(baseList.asScala.map(_.getItem.getReference).toSet) {
+      (acc, entry) =>
+        val itemId = entry.getItem.getReference
+        val deleted = entry.getDeleted
+        if (deleted) {
+          acc - itemId
+        } else {
+          acc + itemId
+        }
+    }
+
+    import spark.implicits._
+
+    result.toSeq.map(id => id.split("/").last.toLong).toDF("_itemreferenceid")
+  }
+
   private def createEntry(row: Row): ListResource.ListEntryComponent = {
     val patient = new Reference()
-    val patientId = row.getAs[String](ResultColumn.SUBJECT)
+    val patientId = row.getAs[Long](ResultColumn.SUBJECT)
+    val deleted = row.getAs[Boolean]("deleted")
     patient.setReference("Patient/" + patientId)
-    patient.setId(patientId)
+    patient.setId(patientId.toString)
     val entry = new ListResource.ListEntryComponent()
     entry.setItem(patient)
+    entry.setDeleted(deleted)
+    entry
   }
 }
diff --git a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
index afddc6a..901f957 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
@@ -8,7 +8,8 @@ import fr.aphp.id.eds.requester.query.model.SourcePopulation
 import fr.aphp.id.eds.requester.tools.SolrTools
 import fr.aphp.id.eds.requester.{AppConfig, ResultColumn}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, functions => F}
+import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
+import org.hl7.fhir.r4.model.ListResource.ListMode
 
 /**
   * @param pg pgTool obj
@@ -24,12 +25,19 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
                             cohortDefinitionSyntax: String,
                             ownerEntityId: String,
                             resourceType: String,
+                            baseCohortId: Option[Long],
+                            mode: ListMode,
                             size: Long): Long = {
+    val (identifier_col, identifier_val) = if (baseCohortId.isDefined) {
+      (" identifier,", s" ${baseCohortId.get.toString},")
+    } else {
+      ("", "")
+    }
     val stmt =
       s"""
          |insert into ${cohort_table_rw}
-        |(hash, title, ${note_text_column_name}, _sourcereferenceid, source__reference, _provider, source__type, mode, status, subject__type, date, _size)
-        |values (-1, ?, ?, ?, ?, '$cohort_provider_name', 'Practitioner', 'snapshot', '${CohortStatus.RUNNING}', ?, now(), ?)
+        |(hash, title, ${note_text_column_name},${identifier_col} _sourcereferenceid, source__reference, _provider, source__type, mode, status, subject__type, date, _size)
+        |values (-1, ?, ?,${identifier_val} ?, ?, '$cohort_provider_name', 'Practitioner', '${mode.toCode}', '${CohortStatus.RUNNING}', ?, now(), ?)
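+        |-- mode comes from ListMode.toCode: 'snapshot' for a plain create, 'changes' for a
+        |-- create_diff list; identifier (when present) stores the base cohort id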
         |returning id
         |""".stripMargin
     val result = pg
@@ -62,17 +70,22 @@
     uploadCount(cohortId, count)
     uploadRelationship(cohortId, sourcePopulation)
 
+    val withDeleteField = cohort.columns.contains("deleted")
+    val selectedColumns = List(
+      "_itemreferenceid",
+      "item__reference",
+      "_provider",
+      "_listid"
+    ) ++ (if (withDeleteField) List("deleted") else List())
+
     val dataframe = cohort
       .withColumn("_listid", lit(cohortId))
       .withColumn("_provider", lit(cohort_provider_name))
       .withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
       .withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
-      .select(F.col("_itemreferenceid"),
-              F.col("item__reference"),
-              F.col("_provider"),
-              F.col("_listid"))
+      .select(selectedColumns.map(F.col): _*)
 
-    uploadCohortTableToPG(dataframe)
+    uploadCohortTableToPG(dataframe, withDeleteField)
 
     if (!delayCohortCreation && resourceType == ResourceType.patient)
       uploadCohortTableToSolr(cohortId, dataframe, count)
@@ -87,6 +100,40 @@
     }
   }
 
+  override def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame = {
+    val stmt =
+      s"""
+         |select _itemreferenceid
+         |from ${cohort_item_table_rw}
+         |where _listid = $cohortId
+         |""".stripMargin
+    val baseCohort = pg.sqlExecWithResult(stmt)
+    val diffs = readCohortDiffEntries(cohortId)
+    val addedDiffs = diffs.filter(col("deleted").isNull || col("deleted") === false)
+    val deletedDiffs = diffs.filter(col("deleted") === true)
+
+    val result = baseCohort
+      .union(addedDiffs.select("_itemreferenceid"))
+      .except(deletedDiffs.select("_itemreferenceid"))
+
+    result
+  }
+
+  private def readCohortDiffEntries(cohortId: Long): DataFrame = {
+    val stmt =
+      s"""
+         |select date,_itemreferenceid,deleted
+         |from ${cohort_item_table_rw}
+         |join ${cohort_table_rw} on ${cohort_table_rw}.id = ${cohort_item_table_rw}._listid
+         |where ${cohort_table_rw}.identifier___official__value = '$cohortId'
+         |""".stripMargin
+    // keep only the most recent "deleted" status per item (rows are ordered by date
+    // before the aggregation so that last() picks the latest change)
+    pg.sqlExecWithResult(stmt)
+      .select(col("date"), col("_itemreferenceid"), col("deleted"))
+      .orderBy(col("date").asc)
+      .groupBy(col("_itemreferenceid"))
+      .agg(last(col("deleted")).as("deleted"))
+  }
+
   private def uploadRelationship(cohortDefinitionId: Long,
                                  sourcePopulation: SourcePopulation): Unit = {
     if (sourcePopulation.cohortList.isDefined) {
@@ -162,14 +209,14 @@
     pg.sqlExec(stmt)
   }
 
-  private def uploadCohortTableToPG(df: DataFrame): Unit = {
+  private def uploadCohortTableToPG(df: DataFrame, withDeleteField: Boolean = false): Unit = {
     require(
-      List(
+      (List(
         "_listid",
         "item__reference",
         "_provider",
         "_itemreferenceid"
-      ).toSet == df.columns.toSet,
+      ) ++ (if (withDeleteField) List("deleted") else List())).toSet == df.columns.toSet,
-      "cohort dataframe shall have _listid, _provider, _provider and item__reference"
+      "cohort dataframe shall have _listid, item__reference, _provider and _itemreferenceid (plus deleted when withDeleteField is set)"
     )
     pg.outputBulk(cohort_item_table_rw,
diff --git a/src/main/scala/fr/aphp/id/eds/requester/jobs/SparkJobParameter.scala b/src/main/scala/fr/aphp/id/eds/requester/jobs/SparkJobParameter.scala
index 4f1cf2b..a50e463 100644
--- a/src/main/scala/fr/aphp/id/eds/requester/jobs/SparkJobParameter.scala
+++ b/src/main/scala/fr/aphp/id/eds/requester/jobs/SparkJobParameter.scala
@@ -29,6 +29,7 @@ object JobType extends Enumeration {
   val countAll = "count_all"
   val countWithDetails = "count_with_details"
   val create = "create"
+  val createDiff = "create_diff"
"create_diff" val purgeCache = "purge_cache" } diff --git a/src/main/scala/fr/aphp/id/eds/requester/server/JobController.scala b/src/main/scala/fr/aphp/id/eds/requester/server/JobController.scala index 2453e93..293716a 100644 --- a/src/main/scala/fr/aphp/id/eds/requester/server/JobController.scala +++ b/src/main/scala/fr/aphp/id/eds/requester/server/JobController.scala @@ -78,6 +78,7 @@ class JobController(implicit val swagger: Swagger) case JobType.countAll => jobManager.execJob(JobsConfig.countJob, jobData) case JobType.countWithDetails => jobManager.execJob(JobsConfig.countJob, jobData) case JobType.create => jobManager.execJob(JobsConfig.createJob, jobData) + case JobType.createDiff => jobManager.execJob(JobsConfig.createJob, jobData) } } diff --git a/src/test/scala/fr/aphp/id/eds/requester/CreateQueryTest.scala b/src/test/scala/fr/aphp/id/eds/requester/CreateQueryTest.scala index 332236c..372bea3 100644 --- a/src/test/scala/fr/aphp/id/eds/requester/CreateQueryTest.scala +++ b/src/test/scala/fr/aphp/id/eds/requester/CreateQueryTest.scala @@ -1,25 +1,37 @@ package fr.aphp.id.eds.requester +import com.github.mrpowers.spark.fast.tests.DatasetComparer import fr.aphp.id.eds.requester.cohort.CohortCreation -import fr.aphp.id.eds.requester.jobs.{JobEnv, JobsConfig, SparkJobParameter} +import fr.aphp.id.eds.requester.jobs.{JobEnv, JobType, JobsConfig, SparkJobParameter} import fr.aphp.id.eds.requester.query.engine.QueryBuilder import fr.aphp.id.eds.requester.query.model.SourcePopulation import fr.aphp.id.eds.requester.query.resolver.{ResourceResolver, ResourceResolvers} import fr.aphp.id.eds.requester.tools.JobUtilsService import org.apache.spark.sql.{DataFrame, SparkSession} +import org.hl7.fhir.r4.model.ListResource.ListMode import org.mockito.{ArgumentCaptor, ArgumentMatchersSugar} import org.mockito.MockitoSugar.{mock, spy, verify, when} +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuiteLike import java.nio.file.Paths -class CreateQueryTest extends AnyFunSuiteLike { +class CreateQueryTest extends AnyFunSuiteLike with BeforeAndAfterAll with DatasetComparer { System.setProperty("config.resource", "application.test.conf") val pgpassFile = Paths.get(scala.sys.env("HOME"), ".pgpass").toFile if (!pgpassFile.exists()) { pgpassFile.createNewFile() } + implicit var sparkSession: SparkSession = _ + + override def beforeAll(): Unit = { + sparkSession = SparkSession + .builder() + .master("local[*]") + .appName("PGCohortCreationTest") + .getOrCreate() + } test("testCallbackUrl") { var callbackUrl = JobsConfig.createJob.callbackUrl( @@ -55,10 +67,6 @@ class CreateQueryTest extends AnyFunSuiteLike { } test("testRunJob") { - val sparkSession: SparkSession = SparkSession - .builder() - .master("local[*]") - .getOrCreate() val queryBuilderMock = mock[QueryBuilder] val omopTools = spy(mock[CohortCreation]) @@ -147,14 +155,17 @@ class CreateQueryTest extends AnyFunSuiteLike { ArgumentMatchersSugar.* ) ).thenReturn(expectedResult) - when(omopTools.createCohort( - ArgumentMatchersSugar.eqTo("testCohortSimple"), - ArgumentMatchersSugar.any[Option[String]], - ArgumentMatchersSugar.any[String], - ArgumentMatchersSugar.any[String], - ArgumentMatchersSugar.any[String], - ArgumentMatchersSugar.any[Long] - )).thenReturn(0) + when( + omopTools.createCohort( + ArgumentMatchersSugar.eqTo("testCohortSimple"), + ArgumentMatchersSugar.any[Option[String]], + ArgumentMatchersSugar.any[String], + ArgumentMatchersSugar.any[String], + ArgumentMatchersSugar.any[String], + 
+        ArgumentMatchersSugar.any[Option[Long]],
+        ArgumentMatchersSugar.any[ListMode],
+        ArgumentMatchersSugar.any[Long]
+      )).thenReturn(0)
     val request =
       """
        {"cohortUuid":"ecd89963-ac90-470d-a397-c846882615a6","sourcePopulation":{"caresiteCohortList":[31558]},"_type":"request","request":{"_type":"andGroup","_id":0,"isInclusive":true,"criteria":[{"_type":"basicResource","_id":1,"isInclusive":true,"resourceType":"patientAphp","filterSolr":"fq=gender:f&fq=deceased:false&fq=active:true","filterFhir":"active=true&gender=f&deceased=false&age-day=ge0&age-day=le130"}],"temporalConstraints":[]}}"
@@ -173,14 +184,17 @@
     assert(res.data("group.count") == "6")
     assert(res.data("group.id") == "0")
 
-    when(omopTools.createCohort(
-      ArgumentMatchersSugar.eqTo("testCohortSampling"),
-      ArgumentMatchersSugar.any[Option[String]],
-      ArgumentMatchersSugar.any[String],
-      ArgumentMatchersSugar.any[String],
-      ArgumentMatchersSugar.any[String],
-      ArgumentMatchersSugar.any[Long]
-    )).thenReturn(1L)
+    when(
+      omopTools.createCohort(
+        ArgumentMatchersSugar.eqTo("testCohortSampling"),
+        ArgumentMatchersSugar.any[Option[String]],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[Option[Long]],
+        ArgumentMatchersSugar.any[ListMode],
+        ArgumentMatchersSugar.any[Long]
+      )).thenReturn(1L)
     val sampled = createJob.runJob(
       sparkSession,
       JobEnv("someid", AppConfig.get),
@@ -207,6 +221,93 @@
     assert(sampled.status == "FINISHED")
     assert(sampled.data("group.count").toInt >= 1 && sampled.data("group.count").toInt <= 2)
     assert(sampled.data("group.id") == "1")
+
+  }
+
+  test("runCreateDiff") {
+    val queryBuilderMock = mock[QueryBuilder]
+    val omopTools = spy(mock[CohortCreation])
+    val resourceResolver = ResourceResolver.get(ResourceResolvers.solr)
+    class JobUtilsMock extends JobUtilsService {
+      override def getRandomIdNotInTabooList(allTabooId: List[Short], negative: Boolean): Short = 99
+
+      override def getCohortCreationService(data: SparkJobParameter,
+                                            spark: SparkSession): Option[CohortCreation] =
+        Some(omopTools)
+
+      override def getResourceResolver(data: SparkJobParameter): ResourceResolver = resourceResolver
+    }
+
+    val createJob = CreateQuery(queryBuilderMock, new JobUtilsMock)
+    val existingCohort = sparkSession
+      .createDataFrame(Seq(Tuple1("1"), Tuple1("2"), Tuple1("3")))
+      .toDF("_itemreferenceid")
+    val newResult = sparkSession
+      .createDataFrame(Seq(Tuple1("1"), Tuple1("3"), Tuple1("4")))
+      .toDF(ResultColumn.SUBJECT)
+    val expectedUpdateDf = sparkSession
+      .createDataFrame(Seq(Tuple2("2", true), Tuple2("4", false)))
+      .toDF(ResultColumn.SUBJECT, "deleted")
+    when(
+      queryBuilderMock.processRequest(
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*,
+        ArgumentMatchersSugar.*
+      )
+    ).thenReturn(newResult)
+    when(
+      omopTools.createCohort(
+        ArgumentMatchersSugar.eqTo("testCohortDiff"),
+        ArgumentMatchersSugar.any[Option[String]],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[String],
+        ArgumentMatchersSugar.any[Option[Long]],
+        ArgumentMatchersSugar.any[ListMode],
+        ArgumentMatchersSugar.any[Long]
+      )).thenReturn(0)
+    when(
+      omopTools.updateCohort(
+        ArgumentMatchersSugar.anyLong,
+        ArgumentMatchersSugar.any[DataFrame],
+        ArgumentMatchersSugar.any[SourcePopulation],
+        ArgumentMatchersSugar.anyLong,
+        ArgumentMatchersSugar.anyBoolean,
+        ArgumentMatchersSugar.any[String]
+      )).thenAnswer((invocation: org.mockito.invocation.InvocationOnMock) => {
+      val dataFrame = invocation.getArgument[DataFrame](1)
+      assertSmallDatasetEquality(dataFrame, expectedUpdateDf, orderedComparison = false)
+      dataFrame
+    })
+
+    when(
+      omopTools.readCohortEntries(
+        ArgumentMatchersSugar.eqTo(1L)
+      )(ArgumentMatchersSugar.eqTo(sparkSession))).thenReturn(existingCohort)
+    val request =
+      """
+        {"cohortUuid":"ecd89963-ac90-470d-a397-c846882615a6","sourcePopulation":{"caresiteCohortList":[31558]},"_type":"request","request":{"_type":"andGroup","_id":0,"isInclusive":true,"criteria":[{"_type":"basicResource","_id":1,"isInclusive":true,"resourceType":"patientAphp","filterSolr":"fq=gender:f&fq=deceased:false&fq=active:true","filterFhir":"active=true&gender=f&deceased=false&age-day=ge0&age-day=le130"}],"temporalConstraints":[]}}"
+      """.stripMargin
+    val res = createJob.runJob(
+      sparkSession,
+      JobEnv("someid", AppConfig.get),
+      SparkJobParameter(
+        "testCohortDiff",
+        None,
+        request,
+        "someOwnerId",
+        mode = JobType.createDiff,
+        modeOptions = Map(CreateDiffOptions.baseCohortId -> "1")
+      )
+    )
+    assert(res.status == "FINISHED")
+    assert(res.data("group.count") == "3")
+    assert(res.data("group.id") == "0")
   }
 }
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreationTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreationTest.scala
new file mode 100644
index 0000000..499a158
--- /dev/null
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreationTest.scala
@@ -0,0 +1,429 @@
+package fr.aphp.id.eds.requester.cohort.fhir
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType}
+import com.github.tomakehurst.wiremock.WireMockServer
+import com.github.tomakehurst.wiremock.client.WireMock
+import com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig
+import fr.aphp.id.eds.requester.{FhirServerConfig, ResultColumn}
+import fr.aphp.id.eds.requester.query.resolver.rest.DefaultRestFhirClient
+import org.apache.spark.sql.SparkSession
+import org.hl7.fhir.r4.model.ListResource.ListMode
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuiteLike
+import org.scalatest.matchers.should.Matchers
+
+class FhirCohortCreationTest extends AnyFunSuiteLike with Matchers with BeforeAndAfterEach {
+  val Port = 8080
+  val Host = "localhost"
+  val wireMockServer = new WireMockServer(wireMockConfig().port(Port))
+  val fhirCohortCreationService = new FhirCohortCreation(
+    new DefaultRestFhirClient(FhirServerConfig("http://" + Host + ":" + Port, None, None),
+                              cohortServer = true)
+  )
+  // For debugging
+  wireMockServer.addMockServiceRequestListener((request, response) => {
+    println(request)
+    println(response)
+  })
+
+  override def beforeEach: Unit = {
+    wireMockServer.start()
+    WireMock.configureFor(Host, Port)
+  }
+
+  override def afterEach: Unit = {
+    wireMockServer.stop()
+  }
+
+  test("testCreateCohort") {
+    wireMockServer.addStubMapping(
+      WireMock
+        .post(WireMock.urlEqualTo("/List"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(201)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "List",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlEqualTo("/metadata"))
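+        // the HAPI FHIR client fetches the server CapabilityStatement (/metadata)
+        // before its first real request, hence this minimal stub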
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "CapabilityStatement",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+    fhirCohortCreationService.createCohort("test",
+                                           Some("test"),
+                                           "test",
+                                           "test",
+                                           "test",
+                                           None,
+                                           ListMode.SNAPSHOT,
+                                           1) should be(1)
+    wireMockServer.verify(1, WireMock.postRequestedFor(WireMock.urlEqualTo("/List")))
+    wireMockServer.verify(1, WireMock.getRequestedFor(WireMock.urlEqualTo("/metadata")))
+    wireMockServer.listAllStubMappings().getMappings.forEach { mapping =>
+      wireMockServer.verify(WireMock.exactly(1), WireMock.requestMadeFor(mapping.getRequest))
+    }
+  }
+
+  test("testUpdateCohort") {
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlEqualTo("/List/1"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "List",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+    wireMockServer.addStubMapping(
+      WireMock
+        .put(WireMock.urlEqualTo("/List/1"))
+        .withRequestBody(
+          WireMock.equalToJson(
+            """
+{
+  "entry": [
+    {
+      "deleted": false,
+      "item": {
+        "id": "1",
+        "reference": "Patient/1"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "2",
+        "reference": "Patient/2"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "3",
+        "reference": "Patient/3"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "4",
+        "reference": "Patient/4"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "5",
+        "reference": "Patient/5"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "6",
+        "reference": "Patient/6"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "7",
+        "reference": "Patient/7"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "8",
+        "reference": "Patient/8"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "9",
+        "reference": "Patient/9"
+      }
+    },
+    {
+      "deleted": true,
+      "item": {
+        "id": "10",
+        "reference": "Patient/10"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "11",
+        "reference": "Patient/11"
+      }
+    },
+    {
+      "deleted": false,
+      "item": {
+        "id": "12",
+        "reference": "Patient/12"
+      }
+    }
+  ],
+  "id": "1",
+  "resourceType": "List"
+}
+            """,
+            true,
+            false
+          )
+        )
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "List",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlEqualTo("/metadata"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "CapabilityStatement",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+
+    implicit val sparkSession: SparkSession =
+      SparkSession.builder().master("local[*]").getOrCreate()
+
+    val data = Seq(
+      Row(1L, false),
+      Row(2L, false),
+      Row(3L, false),
+      Row(4L, false),
+      Row(5L, false),
+      Row(6L, false),
+      Row(7L, false),
+      Row(8L, false),
+      Row(9L, false),
+      Row(10L, true),
+      Row(11L, false),
+      Row(12L, false)
+    )
+    val schema = StructType(
+      Seq(
+        StructField(ResultColumn.SUBJECT, LongType, nullable = false),
+        StructField("deleted", BooleanType, nullable = false)
+      ))
+    val df: DataFrame = sparkSession.createDataFrame(
+      sparkSession.sparkContext.parallelize(data),
+      schema
+    )
+
+    fhirCohortCreationService.updateCohort(1, df, null, 12, delayCohortCreation = false, "Patient")
+
+    wireMockServer.verify(1, WireMock.putRequestedFor(WireMock.urlEqualTo("/List/1")))
+    wireMockServer.verify(1, WireMock.getRequestedFor(WireMock.urlEqualTo("/metadata")))
+
+  }
+
+  test("testReadCohortEntries") {
+    implicit val sparkSession: SparkSession =
+      SparkSession.builder().master("local[*]").getOrCreate()
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlEqualTo("/List/1"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "List",
+                "id": "1",
+                "entry": [
+                  {
+                    "item": {
+                      "reference": "Patient/1"
+                    }
+                  }
+                ]
+              }
+              """
+            )
+        )
+        .build()
+    )
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlPathEqualTo("/List"))
+        .withQueryParam("identifier", WireMock.equalTo("1"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "Bundle",
+                "id": "bundle-id",
+                "type": "searchset",
+                "entry": [
+                  {
+                    "resource": {
+                      "resourceType": "List",
+                      "mode": "snapshot",
+                      "id": "2",
+                      "entry": [
+                        {
+                          "item": {
+                            "reference": "Patient/4"
+                          }
+                        }
+                      ]
+                    }
+                  },{
+                    "resource": {
+                      "resourceType": "List",
+                      "mode": "changes",
+                      "id": "2",
+                      "entry": [
+                        {
+                          "item": {
+                            "reference": "Patient/1"
+                          },
+                          "deleted": true
+                        },
+                        {
+                          "item": {
+                            "reference": "Patient/2"
+                          }
+                        }
+                      ]
+                    }
+                  },
+                  {
+                    "resource": {
+                      "resourceType": "List",
+                      "mode": "changes",
+                      "id": "2",
+                      "entry": [
+                        {
+                          "item": {
+                            "reference": "Patient/3"
+                          }
+                        },
+                        {
+                          "item": {
+                            "reference": "Patient/2"
+                          },
+                          "deleted": true
+                        }
+                      ]
+                    }
+                  },
+                  {
+                    "resource": {
+                      "resourceType": "List",
+                      "mode": "changes",
+                      "id": "2",
+                      "entry": [
+                        {
+                          "item": {
+                            "reference": "Patient/1"
+                          }
+                        }
+                      ]
+                    }
+                  }
+                ]
+              }
+              """
+            )
+        )
+        .build()
+    )
+    wireMockServer.addStubMapping(
+      WireMock
+        .get(WireMock.urlEqualTo("/metadata"))
+        .willReturn(
+          WireMock
+            .aResponse()
+            .withStatus(200)
+            .withHeader("Content-Type", "application/json")
+            .withBody(
+              """
+              {
+                "resourceType": "CapabilityStatement",
+                "id": "1"
+              }
+              """
+            )
+        )
+        .build()
+    )
+
+    val df = fhirCohortCreationService.readCohortEntries(1)
+    assert(df.count() == 2)
+    assert(df.collect()(0).getAs[Long]("_itemreferenceid") == 3)
+    assert(df.collect()(1).getAs[Long]("_itemreferenceid") == 1)
+  }
+
+}
diff --git a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
index ece9eaf..e84c8f9 100644
--- a/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
+++ b/src/test/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreationTest.scala
@@ -2,17 +2,35 @@ package fr.aphp.id.eds.requester.cohort.pg
 
 import com.github.mrpowers.spark.fast.tests.DatasetComparer
 import fr.aphp.id.eds.requester.query.model.SourcePopulation
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.hl7.fhir.r4.model.ListResource.ListMode
 import org.mockito.{ArgumentCaptor, ArgumentMatchers, ArgumentMatchersSugar, MockitoSugar}
-import org.mockito.MockitoSugar.{atLeast, mock, verify, verifyNoMoreInteractions, when}
+import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuiteLike
+import org.scalatest.matchers.should.Matchers
 
 import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
 
-class PGCohortCreationTest extends AnyFunSuiteLike with DatasetComparer {
+class PGCohortCreationTest
+    extends AnyFunSuiteLike
+    with DatasetComparer
+    with Matchers
+    with MockitoSugar
+    with BeforeAndAfterAll {
   System.setProperty("config.resource", "application.test.conf")
 
+  implicit var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+    spark = SparkSession
+      .builder()
+      .master("local[*]")
+      .appName("PGCohortCreationTest")
+      .getOrCreate()
+  }
+
   test("testCreateCohort") {
     val pgTools = mock[PGTool]
     val pgCohortCreation = new PGCohortCreation(pgTools)
@@ -28,21 +46,26 @@
          |""".stripMargin,
       List("test", "test", "test", "Practitioner/test", "test", 1)
     )).thenReturn(expectedResult)
-    pgCohortCreation.createCohort("test", Some("test"), "test", "test", "test", 1)
+    pgCohortCreation.createCohort("test",
+                                  Some("test"),
+                                  "test",
+                                  "test",
+                                  "test",
+                                  None,
+                                  ListMode.SNAPSHOT,
+                                  1)
   }
 
   test("testUpdateCohort") {
-    val sparkSession: SparkSession = SparkSession
-      .builder()
-      .master("local[*]")
-      .getOrCreate()
-
     val pgTools = mock[PGTool]
     val pgCohortCreation = new PGCohortCreation(pgTools)
 
-    val cohortData = Seq(Tuple1(1), Tuple1(3), Tuple1(5))
-    val cohort: DataFrame = sparkSession.createDataFrame(cohortData).toDF("subject_id")
+    val cohortData = Seq(Tuple2(1, false), Tuple2(3, false), Tuple2(5, false))
+    val cohort: DataFrame = spark
+      .createDataFrame(cohortData)
+      .toDF("subject_id", "deleted")
+      .withColumn("deleted", col("deleted").cast(BooleanType))
 
     pgCohortCreation.updateCohort(12345,
                                   cohort,
@@ -95,17 +118,20 @@
         List(888, 12345)
       ))
 
-    val expectedDf = sparkSession
+    val expectedDf = spark
       .createDataFrame(
-        sparkSession.sparkContext.parallelize(
-          List(Row(1, "Patient/1", "Cohort360", 12345L, -1238008758),
-               Row(3, "Patient/3", "Cohort360", 12345L, -1332131217),
-               Row(5, "Patient/5", "Cohort360", 12345L, 399554890))),
+        spark.sparkContext.parallelize(
+          List(
+            Row(1, "Patient/1", "Cohort360", 12345L, false, -1662328687),
+            Row(3, "Patient/3", "Cohort360", 12345L, false, -1015512107),
+            Row(5, "Patient/5", "Cohort360", 12345L, false, -1473784149)
+          )),
         StructType(
           StructField("_itemreferenceid", IntegerType, nullable = false) ::
             StructField("item__reference", StringType, nullable = false) ::
             StructField("_provider", StringType, nullable = false) ::
             StructField("_listid", LongType, nullable = false) ::
+            StructField("deleted", BooleanType, nullable = false) ::
            StructField("hash", IntegerType, nullable = false) :: Nil
         )
       )
@@ -124,6 +150,95 @@
     verifyNoMoreInteractions(pgTools)
   }
 
+  test("readCohortEntries") {
+    // Mock PGTool
+    val mockPGTool = mock[PGTool]
+
+    val cohortId = 123L
+
+    // Create base cohort DataFrame
+    val baseCohortSchema = StructType(
+      Seq(
+        StructField("_itemreferenceid", StringType, nullable = false)
+      ))
+
+    val baseCohortData = Seq(
+      Row("patient1"),
+      Row("patient2"),
+      Row("patient3"),
+      Row("patient4")
+    )
+
+    val baseCohortDf = spark.createDataFrame(
+      spark.sparkContext.parallelize(baseCohortData),
+      baseCohortSchema
+    )
+
+    // Create diff DataFrame with additions and deletions
+    val diffSchema = StructType(
+      Seq(
+        StructField("date", TimestampType, nullable = false),
+        StructField("_itemreferenceid", StringType, nullable = false),
+        StructField("deleted", BooleanType, nullable = true)
+      ))
+
+    import java.sql.Timestamp
+    import java.time.Instant
+
+    val now = Timestamp.from(Instant.now())
+    val earlier = Timestamp.from(Instant.now().minusSeconds(3600))
+
+    val diffData = Seq(
+      // Patient4 was deleted
+      Row(now, "patient4", true),
+      // Patient5 was added
+      Row(now, "patient5", null),
+      // Patient6 was added and then deleted (should not be included)
+      Row(earlier, "patient6", null),
+      Row(now, "patient6", true),
+      // Patient7 was deleted and then added (should be included)
+      Row(earlier, "patient7", true),
+      Row(now, "patient7", false)
+    )
+
+    val diffDf = spark.createDataFrame(
+      spark.sparkContext.parallelize(diffData),
+      diffSchema
+    )
+
+    // Configure mock behavior
+    when(mockPGTool.sqlExecWithResult(ArgumentMatchers.contains("select _itemreferenceid"), ArgumentMatchers.any()))
+      .thenReturn(baseCohortDf)
+
+    when(
+      mockPGTool.sqlExecWithResult(
+        ArgumentMatchers.contains("select date,_itemreferenceid,deleted"), ArgumentMatchers.any()))
+      .thenReturn(diffDf)
+
+    // Create the cohort creation instance with our mock
+    val cohortCreation = new PGCohortCreation(mockPGTool)
+
+    // Call the method under test
+    val result = cohortCreation.readCohortEntries(cohortId)
+
+    // Expected result: base cohort minus deletions plus additions
+    // Expected: patient1, patient2, patient3, patient5, patient7
+    // Not expected: patient4 (deleted), patient6 (added then deleted)
+
+    val expectedIds = Set("patient1", "patient2", "patient3", "patient5", "patient7")
+
+    // Verify the result
+    val resultIds = result.collect().map(row => row.getAs[String]("_itemreferenceid")).toSet
+    resultIds should be(expectedIds)
+
+    // Verify that sqlExecWithResult was called exactly once for each of the two queries
+    verify(mockPGTool, times(1))
+      .sqlExecWithResult(ArgumentMatchers.contains("select _itemreferenceid"), ArgumentMatchers.any())
+    verify(mockPGTool, times(1))
+      .sqlExecWithResult(ArgumentMatchers.contains("select date,_itemreferenceid,deleted"), ArgumentMatchers.any())
+  }
+
+
   def compareList(actual: List[Any], expected: List[Any]): Unit = {
     assert(actual.size == expected.size)
     for (i <- actual.indices) {
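       // compare the two lists element by element (same index in both lists)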