5 changes: 4 additions & 1 deletion README.md
@@ -65,7 +65,9 @@ The job query format is as follows:
// or "ratio", this will activate a detailed count of final matched patients per criteria
"details": "<details>",
// optional sampling ratio, a value between 0.0 and 1.0, used to limit the number of patients in the created cohort (it can also be used to sample an existing cohort)
"sampling": "<sampling>"
"sampling": "<sampling>",
// optional cohort id to use as the base for the "create_diff" mode
"baseCohortId": "<base cohort id>"
},
"callbackUrl": "<callback url>" // optional callback url to retrieve the result
}
@@ -75,6 +77,7 @@ The job query format is as follows:
with `mode` being one of the following values:
- `count`: Return the number of patients that match the criteria of the `cohortDefinitionSyntax`
- `create`: Create a cohort of patients that match the criteria of the `cohortDefinitionSyntax`
- `create_diff`: Create a change list against a base cohort (referenced via `baseCohortId` in the mode options), recording which patients matching the criteria of the `cohortDefinitionSyntax` are new and which have been removed (see the example below)

and `cohortDefinitionSyntax` being a JSON string that represents the criteria described in the [query format section](#query-format).
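
As an illustration, a `create_diff` job request might look like the following sketch; the envelope is assumed from the fragment above (the options object carrying `details`/`sampling` is named `modeOptions` here after the job parameter field), and the id and URL are placeholders:

```json
{
  "mode": "create_diff",
  "cohortDefinitionSyntax": "<json query string>",
  "modeOptions": {
    "baseCohortId": "1234"
  },
  "callbackUrl": "https://example.org/jobs/callback"
}
```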

69 changes: 64 additions & 5 deletions src/main/scala/fr/aphp/id/eds/requester/CreateQuery.scala
@@ -8,12 +8,18 @@ import fr.aphp.id.eds.requester.tools.JobUtils.addEmptyGroup
import fr.aphp.id.eds.requester.tools.{JobUtils, JobUtilsService, StageDetails}
import org.apache.log4j.Logger
import org.apache.spark.sql.{SparkSession, functions => F}
import org.hl7.fhir.r4.model.ListResource.ListMode

object CreateOptions extends Enumeration {
type CreateOptions = String
val sampling = "sampling"
}

object CreateDiffOptions extends Enumeration {
type CreateDiffOptions = String
val baseCohortId = "baseCohortId"
}

case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
jobUtilsService: JobUtilsService = JobUtils)
extends JobBase {
@@ -36,10 +42,13 @@ case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
runtime: JobEnv,
data: SparkJobParameter
): JobBaseResult = {
implicit val (request, criterionTagsMap, omopTools, resourceResolver, cacheEnabled) = {
jobUtilsService.initSparkJobRequest(logger, spark, runtime, data)
}
implicit val sparkSession: SparkSession = spark

validateRequestOrThrow(request)
validateModeOptionsOrThrow(data)

// Init values here because we are in an object (i.e. a singleton) and not a class
var status: String = ""
@@ -72,12 +81,13 @@ case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
.filter(c => cohort.columns.contains(c))
.map(c => F.col(c)): _*)
.dropDuplicates()
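// newly matched patients start as not deleted; the create_diff path below recomputes this flag when building the change list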
.withColumn("deleted", F.lit(false))

if (data.modeOptions.contains(CreateOptions.sampling)) {
val sampling = data.modeOptions(CreateOptions.sampling).toDouble
// https://stackoverflow.com/questions/37416825/dataframe-sample-in-apache-spark-scala#comment62349780_37418684
// to make sure we end up with the expected number of rows
cohort = cohort.sample(sampling + 0.1).limit((sampling * cohort.count()).round.toInt)
}
cohort.cache()
count = cohort.count()
@@ -93,8 +103,17 @@ case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),
data.cohortDefinitionSyntax,
data.ownerEntityId,
request.resourceType,
if (data.mode == JobType.createDiff && data.modeOptions.contains(
CreateDiffOptions.baseCohortId))
Some(data.modeOptions(CreateDiffOptions.baseCohortId).toLong)
else None,
if (data.mode == JobType.createDiff) {
ListMode.CHANGES
} else {
ListMode.SNAPSHOT
},
count
))
.getOrElse(-1L)
}
// get a new cohortId
@@ -107,19 +126,59 @@ case class CreateQuery(queryBuilder: QueryBuilder = new DefaultQueryBuilder(),

// upload into pg and solr
if (omopTools.isDefined) {
val cohortToUpload =
if (data.mode == JobType.createDiff && data.modeOptions.contains(
CreateDiffOptions.baseCohortId)) {
val baseCohortItems =
omopTools.get.readCohortEntries(data.modeOptions(CreateDiffOptions.baseCohortId).toLong)
baseCohortItems
.join(cohort,
baseCohortItems("_itemreferenceid") === cohort(ResultColumn.SUBJECT),
"full_outer")
.filter(
baseCohortItems("_itemreferenceid").isNull || F
.col(ResultColumn.SUBJECT)
.isNull)
.select(
F.coalesce(baseCohortItems("_itemreferenceid"), cohort(ResultColumn.SUBJECT))
.as(ResultColumn.SUBJECT),
F.when(cohort(ResultColumn.SUBJECT).isNull, true).otherwise(false).as("deleted")
)
} else {
cohort
}

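// a diff cohort, like an oversized one, skips the immediate Solr indexing (delayCohortCreation)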
omopTools.get.updateCohort(
cohortDefinitionId,
cohortToUpload,
completeRequest.sourcePopulation,
count,
cohortSizeBiggerThanLimit || data.mode == JobType.createDiff,
request.resourceType
)
}

getCreationResult(cohortDefinitionId, count, status)
}

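To make the change-list computation above concrete, here is a self-contained sketch with toy data; `subject` stands in for `ResultColumn.SUBJECT`, and the dataset is invented:

```scala
import org.apache.spark.sql.{SparkSession, functions => F}

object DiffJoinSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  val base   = Seq(1L, 2L, 3L).toDF("_itemreferenceid") // stored base cohort
  val cohort = Seq(2L, 3L, 4L).toDF("subject")          // newly matched patients

  // full outer join, then keep only rows present on exactly one side
  val changeList = base
    .join(cohort, base("_itemreferenceid") === cohort("subject"), "full_outer")
    .filter(base("_itemreferenceid").isNull || cohort("subject").isNull)
    .select(
      F.coalesce(base("_itemreferenceid"), cohort("subject")).as("subject"),
      // a patient absent from the new cohort has left it: mark as deleted
      F.when(cohort("subject").isNull, true).otherwise(false).as("deleted")
    )

  changeList.show() // 1 -> deleted = true, 4 -> deleted = false
}
```

Patients present on both sides are unchanged and drop out of the list; only arrivals and departures are uploaded.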
private def validateModeOptionsOrThrow(data: SparkJobParameter): Unit = {
val modeOptions = data.modeOptions
if (data.mode == JobType.createDiff) {
if (modeOptions.contains(CreateOptions.sampling)) {
throw new RuntimeException("Can't use sampling with createDiff mode")
}
if (!modeOptions.contains(CreateDiffOptions.baseCohortId)) {
throw new RuntimeException("baseCohortId is required for createDiff mode")
}
}
if (modeOptions.contains(CreateOptions.sampling)) {
val sampling = modeOptions(CreateOptions.sampling).toDouble
if (sampling <= 0 || sampling > 1) {
throw new RuntimeException("Sampling value should be between 0 and 1")
}
}
}
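
As a usage note, these checks mean a `create_diff` job must carry a `baseCohortId` and must not carry `sampling`; for example (hypothetical option maps):

```scala
Map(CreateDiffOptions.baseCohortId -> "42")  // accepted for create_diff
Map(CreateOptions.sampling -> "0.5")         // rejected: sampling with create_diff
Map(CreateOptions.sampling -> "1.5")         // rejected in any mode: not in (0, 1]
```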

private def getCreationResult(cohortDefinitionId: Long,
count: Long,
status: String): JobBaseResult = {
src/main/scala/fr/aphp/id/eds/requester/cohort/CohortCreation.scala
@@ -7,6 +7,8 @@ import fr.aphp.id.eds.requester.query.model.SourcePopulation
import fr.aphp.id.eds.requester.query.resolver.rest.DefaultRestFhirClient
import fr.aphp.id.eds.requester.{AppConfig, FhirServerConfig, PGConfig}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.hl7.fhir.r4.model.ListResource.ListMode


trait CohortCreation {

@@ -24,6 +26,8 @@ trait CohortCreation {
cohortDefinitionSyntax: String,
ownerEntityId: String,
resourceType: String,
baseCohortId: Option[Long],
mode: ListMode,
size: Long): Long

def updateCohort(cohortId: Long,
@@ -33,6 +37,8 @@ trait CohortCreation {
delayCohortCreation: Boolean,
resourceType: String): Unit

def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame

}

object CohortCreation {
src/main/scala/fr/aphp/id/eds/requester/cohort/fhir/FhirCohortCreation.scala
@@ -1,13 +1,15 @@
package fr.aphp.id.eds.requester.cohort.fhir

import ca.uhn.fhir.rest.api.{SortOrderEnum, SortSpec}
import fr.aphp.id.eds.requester.ResultColumn
import fr.aphp.id.eds.requester.cohort.CohortCreation
import fr.aphp.id.eds.requester.query.model.SourcePopulation
import fr.aphp.id.eds.requester.query.resolver.rest.RestFhirClient
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.hl7.fhir.r4.model.ListResource.ListMode
import org.hl7.fhir.r4.model.{Bundle, Identifier, ListResource, Reference}

import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}

class FhirCohortCreation(restFhirClient: RestFhirClient) extends CohortCreation {

@@ -25,10 +27,19 @@ class FhirCohortCreation(restFhirClient: RestFhirClient) extends CohortCreation
cohortDefinitionSyntax: String,
ownerEntityId: String,
resourceType: String,
baseCohortId: Option[Long],
mode: ListMode,
size: Long): Long = {
val list = new ListResource()
list.setTitle(cohortDefinitionName)
list.setMode(mode)
if (baseCohortId.isDefined) {
list.setIdentifier(
List(new Identifier().setValue(baseCohortId.get.toString)).asJava
)
}
val fhirCreatedResource = restFhirClient.getClient.create().resource(list).execute()
fhirCreatedResource.getResource.getIdElement.getIdPartAsLong
}

override def updateCohort(cohortId: Long,
@@ -47,12 +58,50 @@ class FhirCohortCreation(restFhirClient: RestFhirClient) extends CohortCreation
restFhirClient.getClient.update().resource(list).execute()
}

override def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame = {
val baseList = restFhirClient.getClient
.read()
.resource(classOf[ListResource])
.withId(cohortId.toString)
.execute()
.getEntry

val diffListsResults: Bundle = restFhirClient.getClient
.search()
.forResource(classOf[ListResource])
.where(ListResource.IDENTIFIER.exactly().code(cohortId.toString))
.sort(new SortSpec("date", SortOrderEnum.ASC))
.execute()
val diffLists = diffListsResults.getEntry.asScala
.map(_.getResource.asInstanceOf[ListResource])
.filter(l => l.hasMode && l.getMode.equals(ListMode.CHANGES))

val diffEntries = diffLists.flatMap(_.getEntry.asScala)
val result = diffEntries.foldLeft(baseList.asScala.map(_.getItem.getReference).toSet) {
(acc, entry) =>
val itemId = entry.getItem.getReference
val deleted = entry.getDeleted
if (deleted) {
acc - itemId
} else {
acc + itemId
}
}

import spark.implicits._

result.toSeq.map(id => id.split("/").last.toLong).toDF("_itemreferenceid")
}
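
To spell out the replay semantics: starting from the members of the base list, each `changes` list entry is applied in date order, with removals and additions distinguished by the `deleted` flag. A minimal sketch with invented references:

```scala
val base = Set("Patient/1", "Patient/2")
val diffEntries = Seq(          // (reference, deleted), already sorted by list date
  ("Patient/3", false),         // a later diff added Patient/3
  ("Patient/1", true)           // and removed Patient/1
)
val members = diffEntries.foldLeft(base) {
  case (acc, (ref, deleted)) => if (deleted) acc - ref else acc + ref
}
// members == Set("Patient/2", "Patient/3")
```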

private def createEntry(row: Row): ListResource.ListEntryComponent = {
val patient = new Reference()
val patientId = row.getAs[Long](ResultColumn.SUBJECT)
val deleted = row.getAs[Boolean]("deleted")
patient.setReference("Patient/" + patientId)
patient.setId(patientId.toString)
val entry = new ListResource.ListEntryComponent()
entry.setItem(patient)
entry.setDeleted(deleted)
entry
}
}
src/main/scala/fr/aphp/id/eds/requester/cohort/pg/PGCohortCreation.scala
@@ -8,7 +8,8 @@ import fr.aphp.id.eds.requester.query.model.SourcePopulation
import fr.aphp.id.eds.requester.tools.SolrTools
import fr.aphp.id.eds.requester.{AppConfig, ResultColumn}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import org.hl7.fhir.r4.model.ListResource.ListMode

/**
* @param pg pgTool obj
@@ -24,12 +25,19 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
cohortDefinitionSyntax: String,
ownerEntityId: String,
resourceType: String,
baseCohortId: Option[Long],
mode: ListMode,
size: Long): Long = {
val (identifier_col, identifier_val) = if (baseCohortId.isDefined) {
(" identifier,", s" ${baseCohortId.get.toString},")
} else {
("", "")
}
val stmt =
s"""
|insert into ${cohort_table_rw}
|(hash, title, ${note_text_column_name},${identifier_col} _sourcereferenceid, source__reference, _provider, source__type, mode, status, subject__type, date, _size)
|values (-1, ?, ?,${identifier_val} ?, ?, '$cohort_provider_name', 'Practitioner', '${mode.toCode}', '${CohortStatus.RUNNING}', ?, now(), ?)
|returning id
|""".stripMargin
val result = pg
@@ -62,17 +70,22 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
uploadCount(cohortId, count)
uploadRelationship(cohortId, sourcePopulation)

val withDeleteField = cohort.columns.contains("deleted")
val selectedColumns = List(
"_itemreferenceid",
"item__reference",
"_provider",
"_listid"
) ++ (if (withDeleteField) List("deleted") else List())

val dataframe = cohort
.withColumn("_listid", lit(cohortId))
.withColumn("_provider", lit(cohort_provider_name))
.withColumnRenamed(ResultColumn.SUBJECT, "_itemreferenceid")
.withColumn("item__reference", concat(lit(s"${resourceType}/"), col("_itemreferenceid")))
.select(F.col("_itemreferenceid"),
F.col("item__reference"),
F.col("_provider"),
F.col("_listid"))
.select(selectedColumns.map(F.col): _*)

uploadCohortTableToPG(dataframe, withDeleteField)

if (!delayCohortCreation && resourceType == ResourceType.patient)
uploadCohortTableToSolr(cohortId, dataframe, count)
@@ -87,6 +100,40 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
}
}

override def readCohortEntries(cohortId: Long)(implicit spark: SparkSession): DataFrame = {
val stmt =
s"""
|select _itemreferenceid
|from ${cohort_item_table_rw}
|where _listid = $cohortId
|""".stripMargin
val baseCohort = pg.sqlExecWithResult(stmt)
val diffs = readCohortDiffEntries(cohortId)
val addedDiffs = diffs.filter(col("deleted").isNull || col("deleted") === false)
val deletedDiffs = diffs.filter(col("deleted") === true)

val result = baseCohort
.union(addedDiffs.select("_itemreferenceid"))
.except(deletedDiffs.select("_itemreferenceid"))

result
}

private def readCohortDiffEntries(cohortId: Long): DataFrame = {
val stmt =
s"""
|select date,_itemreferenceid,deleted
|from ${cohort_item_table_rw}
|join ${cohort_table_rw} on ${cohort_table_rw}.id = ${cohort_item_table_rw}._listid
|where ${cohort_table_rw}.identifier___official__value = '$cohortId'
|""".stripMargin
pg.sqlExecWithResult(stmt)
.select(col("date"), col("_itemreferenceid"), col("deleted"))
.orderBy(col("date").asc)
.groupBy(col("_itemreferenceid"))
.agg(last(col("deleted")).as("deleted"))
}
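
The PG-side reconstruction in `readCohortEntries` above is plain set algebra over DataFrames: base plus additions, minus deletions. A toy sketch of the same behavior (data invented, and the null-safe handling of `deleted` elided):

```scala
import org.apache.spark.sql.{SparkSession, functions => F}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val baseCohort = Seq(1L, 2L).toDF("_itemreferenceid")
val diffs      = Seq((3L, false), (1L, true)).toDF("_itemreferenceid", "deleted")

val members = baseCohort
  .union(diffs.filter(!F.col("deleted")).select("_itemreferenceid"))  // apply additions
  .except(diffs.filter(F.col("deleted")).select("_itemreferenceid"))  // then deletions
// members: 2 and 3
```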

private def uploadRelationship(cohortDefinitionId: Long,
sourcePopulation: SourcePopulation): Unit = {
if (sourcePopulation.cohortList.isDefined) {
@@ -162,14 +209,14 @@ class PGCohortCreation(pg: PGTool) extends CohortCreation with LazyLogging {
pg.sqlExec(stmt)
}

private def uploadCohortTableToPG(df: DataFrame, withDeleteField: Boolean = false): Unit = {
require(
(List(
"_listid",
"item__reference",
"_provider",
"_itemreferenceid"
) ++ (if (withDeleteField) List("deleted") else List())).toSet == df.columns.toSet,
"cohort dataframe shall have _listid, _provider, _provider and item__reference"
)
pg.outputBulk(cohort_item_table_rw,
@@ -29,6 +29,7 @@ object JobType extends Enumeration {
val countAll = "count_all"
val countWithDetails = "count_with_details"
val create = "create"
val createDiff = "create_diff"
val purgeCache = "purge_cache"
}
