Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,58 +148,20 @@ internal fun AggregationSpec.getFeatureId(): String {
}
}

/**
 * Expands this list of [AggregationSpec]s into the flat list of [MetricDefinition]s they request.
 *
 * [Count] and [PrivacyIdCount] contribute exactly one metric each because they do not aggregate a
 * specific value; [ValueAggregations] and [VectorAggregations] contribute one metric per contained
 * aggregation spec. Per-metric budgets, when present, are converted to internal budget specs.
 */
internal fun List<AggregationSpec>.metrics(): List<MetricDefinition> = buildList {
  this@metrics.forEach { spec ->
    when (spec) {
      // Count and PrivacyIdCount do not aggregate any specific value, therefore they are handled
      // differently.
      is PrivacyIdCount ->
        add(
          MetricDefinition(
            MetricType.PRIVACY_ID_COUNT,
            spec.budget?.toInternalBudgetPerOpSpec(),
          )
        )
      is Count -> add(MetricDefinition(MetricType.COUNT, spec.budget?.toInternalBudgetPerOpSpec()))
      is ValueAggregations<*> ->
        addAll(
          spec.valueAggregationSpecs.map {
            MetricDefinition(it.metricType, it.budget?.toInternalBudgetPerOpSpec())
          }
        )
      is VectorAggregations<*> ->
        addAll(
          spec.vectorAggregationSpecs.map {
            MetricDefinition(it.metricType, it.budget?.toInternalBudgetPerOpSpec())
          }
        )
    }
  }
}

internal fun List<AggregationSpec>.outputColumnNamesWithMetricTypes():
List<Pair<String, MetricType>> = buildList {
for (aggregation in this@outputColumnNamesWithMetricTypes) {
when (aggregation) {
is PrivacyIdCount -> add(aggregation.outputColumnName to MetricType.PRIVACY_ID_COUNT)
is Count -> add(aggregation.outputColumnName to MetricType.COUNT)
is PrivacyIdCount -> add(Pair(aggregation.outputColumnName, MetricType.PRIVACY_ID_COUNT))
is Count -> add(Pair(aggregation.outputColumnName, MetricType.COUNT))
is ValueAggregations<*> -> {
for (valueAggregationSpec in aggregation.valueAggregationSpecs) {
add(valueAggregationSpec.outputColumnName to valueAggregationSpec.metricType)
add(Pair(valueAggregationSpec.outputColumnName, valueAggregationSpec.metricType))
}
}
is VectorAggregations<*> -> {
for (vectorAggregationSpec in aggregation.vectorAggregationSpecs) {
add(vectorAggregationSpec.outputColumnName to vectorAggregationSpec.metricType)
add(Pair(vectorAggregationSpec.outputColumnName, vectorAggregationSpec.metricType))
}
}
}
Expand Down Expand Up @@ -227,3 +189,22 @@ internal fun List<AggregationSpec>.outputColumnNameToFeatureIdMap(): Map<String,

/** Returns the output column names of all aggregations in this list, in declaration order. */
internal fun List<AggregationSpec>.outputColumnNames(): List<String> =
  outputColumnNamesWithMetricTypes().map { (columnName, _) -> columnName }

/**
 * Converts a non-feature [AggregationSpec] ([Count] or [PrivacyIdCount]) into an internal
 * [MetricDefinition], carrying over its optional per-op budget.
 *
 * @throws IllegalArgumentException if this spec is a feature aggregation (value or vector).
 */
internal fun AggregationSpec.toNonFeatureMetricDefinition(): MetricDefinition =
  when (this) {
    is Count -> MetricDefinition(MetricType.COUNT, budget?.toInternalBudgetPerOpSpec())
    is PrivacyIdCount ->
      MetricDefinition(MetricType.PRIVACY_ID_COUNT, budget?.toInternalBudgetPerOpSpec())
    else ->
      throw IllegalArgumentException("Unsupported AggregationSpec type for non feature metrics")
  }

/** Converts this [ValueAggregationSpec] into an internal [MetricDefinition] with its budget. */
internal fun ValueAggregationSpec.toMetricDefinition(): MetricDefinition =
  MetricDefinition(metricType, budget?.toInternalBudgetPerOpSpec())

/** Converts this [VectorAggregationSpec] into an internal [MetricDefinition] with its budget. */
internal fun VectorAggregationSpec.toMetricDefinition(): MetricDefinition =
  MetricDefinition(metricType, budget?.toInternalBudgetPerOpSpec())
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.FeatureValuesExt
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollection
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkTable
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType
import com.google.privacy.differentialprivacy.pipelinedp4j.core.ScalarFeatureSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.core.SelectPartitionsParams
import com.google.privacy.differentialprivacy.pipelinedp4j.core.VectorFeatureSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.DpAggregates
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.PerFeature
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.copy
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.dpAggregates
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.perFeature

sealed interface Query<ReturnT> {
/** Executes the query (in production mode). */
Expand Down Expand Up @@ -81,8 +79,7 @@ protected constructor(

// 1. If aggregation is empty then we do partition selection.
if (aggregations.isEmpty()) {
val extractors =
createDataExtractors(valueExtractor = null, vectorExtractor = null, featureId = null)
val extractors = createDataExtractors(aggregations)
val result = dpEngine.selectPartitions(data, createSelectPartitionsParams(), extractors)
dpEngine.done()

Expand All @@ -94,81 +91,10 @@ protected constructor(
)
}

val isValueOrVectorAgg: (AggregationSpec) -> Boolean = {
it is ValueAggregations<*> || it is VectorAggregations<*>
}
val valueAndVectorAggs: List<AggregationSpec> = aggregations.filter(isValueOrVectorAgg)

// Count/PidCount aggregations
val countAggs: List<AggregationSpec> = aggregations.filterNot(isValueOrVectorAgg)

var partitions: FrameworkCollection<GroupKeysT>? = groupsType.getPublicGroups()
// 2. If aggregations are not empty, split them into runs.
// The first run contains all aggregations that do not relate to specific values or vectors
// (e.g. COUNT), plus the first value or vector aggregation (if any).
// The subsequent runs contain one value or vector aggregation each.
val firstFeatureAggregation = valueAndVectorAggs.firstOrNull()
val firstRun = buildList {
if (firstFeatureAggregation != null) {
add(firstFeatureAggregation)
}
addAll(countAggs)
}
val otherOneFeatureRuns = valueAndVectorAggs.drop(1)

val aggResults = mutableListOf<FrameworkTable<GroupKeysT, DpAggregates>>()

// 3. Run the first aggregation. If public partitions are not provided,
// this run performs partition selection, and the result partitions are used
// in subsequent runs.
val result = aggregateWithDpEngine(dpEngine, firstFeatureAggregation, firstRun, partitions)
aggResults.add(result)
if (partitions == null) {
partitions = result.keys("GetPartitions")
}

// 4. Run all subsequent aggregations using partitions from the first run.
for (featureAggregation in otherOneFeatureRuns) {
val result =
aggregateWithDpEngine(dpEngine, featureAggregation, listOf(featureAggregation), partitions)
aggResults.add(result)
}
val partitions: FrameworkCollection<GroupKeysT>? = groupsType.getPublicGroups()
val result = aggregateWithDpEngine(dpEngine, aggregations, partitions)
dpEngine.done()

val featureIdPerRun =
if (valueAndVectorAggs.isEmpty()) {
listOf(null)
} else {
valueAndVectorAggs.map { it.getFeatureId() }
}
return aggResults
.zip(featureIdPerRun)
.map { (table, featureId) ->
table.mapValues("TagWithFeatureId", encoderFactory.protos(DpAggregates::class)) { _, agg ->
if (featureId == null) {
agg
} else {
val perFeature = constructPerFeature(agg, featureId)
dpAggregates {
count = agg.count
privacyIdCount = agg.privacyIdCount
this.perFeature += perFeature
}
}
}
}
.reduce {
acc: FrameworkTable<GroupKeysT, DpAggregates>,
table: FrameworkTable<GroupKeysT, DpAggregates> ->
acc.flattenWith("FlattenResultsFromMultipleRuns", table)
}
.groupAndCombineValues("MergeDpAggregates") { acc, dpAggregatesFromSingleRun ->
acc.copy {
count += dpAggregatesFromSingleRun.count
privacyIdCount += dpAggregatesFromSingleRun.privacyIdCount
perFeature += dpAggregatesFromSingleRun.perFeatureList
}
}
return result
}

private fun validate() {
Expand Down Expand Up @@ -413,73 +339,52 @@ protected constructor(

private fun aggregateWithDpEngine(
dpEngine: DpEngine,
featureAggregation: AggregationSpec?,
aggregationSpecs: List<AggregationSpec>,
partitions: FrameworkCollection<GroupKeysT>?,
): FrameworkTable<GroupKeysT, DpAggregates> {
@Suppress("UNCHECKED_CAST") val va = featureAggregation as? ValueAggregations<DataRowT>
@Suppress("UNCHECKED_CAST") val vea = featureAggregation as? VectorAggregations<DataRowT>
val extractors =
createDataExtractors(
va?.valueExtractor,
vea?.vectorExtractor,
featureAggregation?.getFeatureId(),
)
val params = createAggregationParams(aggregationSpecs, va, vea)
val extractors = createDataExtractors(aggregationSpecs)
val params = createAggregationParams(aggregationSpecs)
return dpEngine.aggregate(data, params, extractors, partitions)
}

private fun createDataExtractors(
valueExtractor: ((DataRowT) -> Double)?,
vectorExtractor: ((DataRowT) -> List<Double>)?,
featureId: String?,
) =
when {
valueExtractor == null && vectorExtractor == null ->
DataExtractors.from(
privacyUnitExtractor,
privacyUnitEncoder,
groupKeyExtractor,
groupKeyEncoder,
)
valueExtractor != null && vectorExtractor == null ->
DataExtractors.from(
privacyUnitExtractor,
privacyUnitEncoder,
groupKeyExtractor,
groupKeyEncoder,
valuesExtractors =
listOf(
FeatureValuesExtractor(
checkNotNull(featureId) {
"featureId must not be null when a value extractor is provided."
}
) {
listOf(valueExtractor(it))
}
),
)
valueExtractor == null && vectorExtractor != null ->
DataExtractors.from(
privacyUnitExtractor,
privacyUnitEncoder,
groupKeyExtractor,
groupKeyEncoder,
valuesExtractors =
listOf(
FeatureValuesExtractor(
checkNotNull(featureId) {
"featureId must not be null when a vector extractor is provided."
},
vectorExtractor,
)
),
)
else ->
throw IllegalArgumentException(
"Only one of valueExtractor and vectorExtractor can be specified, but both were specified."
)
aggregations: List<AggregationSpec>
): DataExtractors<DataRowT, PrivacyUnitT, GroupKeysT> {
val featureValueExtractors: List<FeatureValuesExtractor<DataRowT>> =
aggregations.mapNotNull {
when (it) {
is ValueAggregations<*> -> {
@Suppress("UNCHECKED_CAST") val va = it as ValueAggregations<DataRowT>
val featureId = va.getFeatureId()
val valueExtractor = va.valueExtractor
FeatureValuesExtractor<DataRowT>(featureId, { row -> listOf(valueExtractor(row)) })
}
is VectorAggregations<*> -> {
@Suppress("UNCHECKED_CAST") val vea = it as VectorAggregations<DataRowT>
val featureId = vea.getFeatureId()
val vectorExtractor = vea.vectorExtractor
FeatureValuesExtractor<DataRowT>(featureId, { row -> vectorExtractor(row) })
}
else -> null
}
}
return if (featureValueExtractors.isEmpty()) {
DataExtractors.from(
privacyUnitExtractor,
privacyUnitEncoder,
groupKeyExtractor,
groupKeyEncoder,
)
} else {
DataExtractors.from(
privacyUnitExtractor,
privacyUnitEncoder,
groupKeyExtractor,
groupKeyEncoder,
valuesExtractors = featureValueExtractors,
)
}
}

private fun createSelectPartitionsParams() =
SelectPartitionsParams(
Expand All @@ -489,48 +394,53 @@ protected constructor(
contributionBoundingLevel = contributionBoundingLevel.toInternalContributionBoundingLevel(),
)

private fun createAggregationParams(
aggregationSpecs: List<AggregationSpec>,
valueAggregations: ValueAggregations<*>?,
vectorAggregations: VectorAggregations<*>?,
): AggregationParams {
val valueContributionBounds = valueAggregations?.contributionBounds
val vectorContributionBounds = vectorAggregations?.vectorContributionBounds
private fun createAggregationParams(aggregationSpecs: List<AggregationSpec>): AggregationParams {
val nonFeatureMetrics =
aggregationSpecs
.filter { it is Count || it is PrivacyIdCount }
.map { it.toNonFeatureMetricDefinition() }
val features =
aggregationSpecs.mapNotNull {
when (it) {
is ValueAggregations<*> -> {
val valueContributionBounds = it.contributionBounds
ScalarFeatureSpec(
featureId = it.getFeatureId(),
metrics = it.valueAggregationSpecs.map { it.toMetricDefinition() }.toImmutableList(),
minValue = valueContributionBounds.valueBounds?.minValue,
maxValue = valueContributionBounds.valueBounds?.maxValue,
minTotalValue = valueContributionBounds.totalValueBounds?.minValue,
maxTotalValue = valueContributionBounds.totalValueBounds?.maxValue,
)
}
is VectorAggregations<*> -> {
val vectorContributionBounds = it.vectorContributionBounds
VectorFeatureSpec(
featureId = it.getFeatureId(),
metrics = it.vectorAggregationSpecs.map { it.toMetricDefinition() }.toImmutableList(),
vectorSize = it.vectorSize,
normKind = vectorContributionBounds.maxVectorTotalNorm.normKind.toInternalNormKind(),
vectorMaxTotalNorm = vectorContributionBounds.maxVectorTotalNorm.value,
)
}
else -> null
}
}

return AggregationParams(
metrics = ImmutableList.copyOf(aggregationSpecs.metrics()),
nonFeatureMetrics = nonFeatureMetrics.toImmutableList(),
features = features.toImmutableList(),
noiseKind =
checkNotNull(noiseKind) { "noiseKind cannot be null if there are aggregations." }
.toInternalNoiseKind(),
maxPartitionsContributed = contributionBoundingLevel.getMaxPartitionsContributed(),
maxContributionsPerPartition = contributionBoundingLevel.getMaxContributionsPerPartition(),
minValue = valueContributionBounds?.valueBounds?.minValue,
maxValue = valueContributionBounds?.valueBounds?.maxValue,
minTotalValue = valueContributionBounds?.totalValueBounds?.minValue,
maxTotalValue = valueContributionBounds?.totalValueBounds?.maxValue,
vectorNormKind = vectorContributionBounds?.maxVectorTotalNorm?.normKind?.toInternalNormKind(),
vectorMaxTotalNorm = vectorContributionBounds?.maxVectorTotalNorm?.value,
vectorSize = vectorAggregations?.vectorSize,
partitionSelectionBudget = groupsType.getBudget()?.toInternalBudgetPerOpSpec(),
preThreshold = groupsType.getPreThreshold(),
contributionBoundingLevel = contributionBoundingLevel.toInternalContributionBoundingLevel(),
partitionsBalance = groupByAdditionalParameters.groupsBalance.toPartitionsBalance(),
)
}

companion object {
  /**
   * Wraps the feature-specific fields of [dpAggregates] (sum, mean, variance, quantiles, vector
   * sum) into a [PerFeature] message tagged with [featureId].
   *
   * Repeated fields are only copied when non-empty so that an absent aggregate does not add an
   * empty repeated field to the output.
   */
  private fun constructPerFeature(dpAggregates: DpAggregates, featureId: String): PerFeature {
    val aggregatedQuantiles = dpAggregates.quantilesList
    val aggregatedVectorSum = dpAggregates.vectorSumList
    return perFeature {
      this.featureId = featureId
      sum = dpAggregates.sum
      mean = dpAggregates.mean
      variance = dpAggregates.variance
      if (aggregatedQuantiles.isNotEmpty()) {
        quantiles += aggregatedQuantiles
      }
      if (aggregatedVectorSum.isNotEmpty()) {
        vectorSum += aggregatedVectorSum
      }
    }
  }
}
}

/** Copies this iterable into a Guava [ImmutableList] (no-op copy if it already is one). */
private fun <T : Any> Iterable<T>.toImmutableList(): ImmutableList<T> {
  return ImmutableList.copyOf(this)
}
Loading
Loading