Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ botorch/qmc/sobol.c*
# Sphinx documentation
sphinx/build/

# Docusaurus
# Docusaurus and diagnostic tools
website/build/
website/i18n/
website/node_modules/
node_modules

# Tutorials
docs/overview/tutorials/*/*.mdx
Expand Down
2 changes: 2 additions & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import experimental
from .diagnostics import Diagnostics
from .diagnostics.common_statistics import effective_sample_size, r_hat, split_r_hat
from .diagnostics.tools import viz
from .inference import (
CompositionalInference,
empirical,
Expand Down Expand Up @@ -60,4 +61,5 @@
"random_variable",
"simulate",
"split_r_hat",
"viz",
]
22 changes: 22 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa

"""Visual diagnostic tools for Bean Machine models."""

import sys
from pathlib import Path


if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict


TOOLS_DIR = Path(__file__).parent.resolve()
JS_DIR = TOOLS_DIR.joinpath("js")
JS_DIST_DIR = JS_DIR.joinpath("dist")
75 changes: 75 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

const OFF = 0;
const WARNING = 1;
const ERROR = 2;

module.exports = {
root: true,
env: {
browser: true,
commonjs: true,
jest: true,
node: true,
},
parser: '@typescript-eslint/parser',
parserOptions: {
allowImportExportEverywhere: true,
},
extends: ['airbnb', 'prettier', 'plugin:import/typescript'],
plugins: ['prefer-arrow'],
rules: {
// Allow more than 1 class per file.
'max-classes-per-file': ['error', {ignoreExpressions: true, max: 2}],
// Allow snake_case.
camelcase: [
OFF,
{
properties: 'never',
ignoreDestructuring: true,
ignoreImports: true,
ignoreGlobals: true,
},
],
'no-underscore-dangle': OFF,
// Arrow function rules.
'prefer-arrow/prefer-arrow-functions': [
ERROR,
{
disallowPrototype: true,
singleReturnOnly: false,
classPropertiesAllowed: false,
},
],
'prefer-arrow-callback': [ERROR, {allowNamedFunctions: true}],
'arrow-parens': [ERROR, 'always'],
'arrow-body-style': [ERROR, 'always'],
'func-style': [ERROR, 'declaration', {allowArrowFunctions: true}],
'react/function-component-definition': [
ERROR,
{
namedComponents: 'arrow-function',
unnamedComponents: 'arrow-function',
},
],
// Ignore the global require, since some required packages are BrowserOnly.
'global-require': 0,
// We reassign several parameter objects since Bokeh is just updating values in the
// them.
'no-param-reassign': 0,
// Ignore certain webpack alias because it can't be resolved
'import/no-unresolved': [
ERROR,
{ignore: ['^@theme', '^@docusaurus', '^@generated', '^@bokeh']},
],
'import/extensions': OFF,
'object-shorthand': [ERROR, 'never'],
'prefer-destructuring': [WARNING, {object: true, array: true}],
'no-nested-ternary': 0,
},
};
8 changes: 8 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/.prettierrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"arrowParens": "always",
"bracketSpacing": false,
"printWidth": 88,
"proseWrap": "never",
"singleQuote": true,
"trailingComma": "all"
}
55 changes: 55 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"name": "visual-diagnostic-tools",
"version": "0.1.0",
"description": "",
"license": "MIT",
"keywords": [],
"repository": {},
"scripts": {
"build": "webpack"
},
"dependencies": {
"@bokeh/bokehjs": "^2.4.3",
"fast-kde": "^0.2.1"
},
"devDependencies": {
"@types/node": "^18.0.4",
"@typescript-eslint/eslint-plugin": "^5.30.5",
"@typescript-eslint/parser": "^5.30.5",
"eslint": "^8.19.0",
"eslint-config-airbnb": "^19.0.4",
"eslint-config-prettier": "^8.5.0",
"eslint-plugin-import": "^2.26.0",
"eslint-plugin-jsx-a11y": "^6.5.1",
"eslint-plugin-prefer-arrow": "^1.2.3",
"eslint-plugin-react": "^7.28.0",
"eslint-plugin-react-hooks": "^4.3.0",
"prettier": "^2.7.1",
"ts-loader": "^9.3.1",
"ts-node": "^10.9.1",
"typescript": "^4.7.4",
"webpack": "^5.74.0",
"webpack-cli": "^4.10.0"
},
"overrides": {
"cwise": "$cwise",
"minimist": "$minimist",
"quote-stream": "$quote-stream",
"static-eval": "$static-eval",
"static-module": "$static-module",
"typedarray-pool": "$typedarray-pool"
},
"peerDependencies": {
"@types/cwise": "^1.0.4",
"@types/minimist": "^1.2.2",
"@types/static-eval": "^0.2.31",
"@types/typedarray-pool": "^1.1.2",
"buffer": "^6.0.3",
"cwise": "^1.0.10",
"minimist": "^1.2.6",
"quote-stream": "^1.0.2",
"static-eval": "2.1.0",
"static-module": "^3.0.4",
"typedarray-pool": "^1.2.0"
}
}
190 changes: 190 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {Axis} from '@bokehjs/models/axes/axis';
import {cumulativeSum} from '../stats/array';
import {scaleToOne} from '../stats/dataTransformation';
import {
interval as hdiInterval,
data as hdiData,
} from '../stats/highestDensityInterval';
import {oneD} from '../stats/marginal';
import {mean as computeMean} from '../stats/pointStatistic';
import {interpolatePoints} from '../stats/utils';
import * as interfaces from './interfaces';

// Define the names of the figures used for this Bokeh application.
const figureNames = ['marginal', 'cumulative'];

/**
* Update the given Bokeh Axis object with the new label string. You must use this
* method to update axis strings using TypeScript, otherwise the ts compiler will throw
* a type check error.
*
* @param {Axis} axis - The Bokeh Axis object needing a new label.
* @param {string | null} label - The new label for the Bokeh Axis object.
*/
export const updateAxisLabel = (axis: Axis, label: string | null): void => {
// Type check requirement.
if ('axis_label' in axis) {
axis.axis_label = label;
}
};

/**
* Compute the following statistics for the given random variable data
*
* - lower bound for the highest density interval calculated from the marginalX;
* - mean of the rawData;
* - upper bound for the highest density interval calculated from the marginalY.
*
* @param {number[]} rawData - Raw random variable data from the model.
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the
* random variable.
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable.
* @param {number | null} [hdiProb=null] - The highest density interval probability
* value. If the default value is not overwritten, then the default HDI probability
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this
* value is the default.
* @param {string[]} [text_align=['right', 'center', 'left']] - How to align the text
* shown in the figure for the point statistics.
* @param {number[]} [x_offset=[-5, 0, 5]] - Offset values for the text along the
* x-axis.
* @param {number[]} [y_offset=[0, 10, 0]] - Offset values for the text along the
* y-axis
* @returns {interfaces.LabelsData} Object containing the computed stats.
*/
export const computeStats = (
rawData: number[],
marginalX: number[],
marginalY: number[],
hdiProb: number | null = null,
text_align: string[] = ['right', 'center', 'left'],
x_offset: number[] = [-5, 0, 5],
y_offset: number[] = [0, 10, 0],
): interfaces.LabelsData => {
// Set the default value to 0.89 if no default value has been given.
const hdiProbability = hdiProb ?? 0.89;

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(marginalX);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
const text = [
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`,
`Mean: ${mean.toFixed(3)}`,
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`,
];

return {
x: x,
y: y,
text: text,
text_align: text_align,
x_offset: x_offset,
y_offset: y_offset,
};
};

/**
* Compute data for the one-dimensional marginal diagnostic tool.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the Kernel Density Estimate (KDE).
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @returns {interfaces.Data} The marginal distribution and cumulative
* distribution calculated from the given random variable data. Point statistics are
* also calculated.
*/
export const computeData = (
data: number[],
bwFactor: number,
hdiProbability: number,
): interfaces.Data => {
const output = {} as interfaces.Data;
for (let i = 0; i < figureNames.length; i += 1) {
const figureName = figureNames[i];
output[figureName] = {} as interfaces.GlyphData;

// Compute the one-dimensional KDE and its cumulative distribution.
const distribution = oneD(data, bwFactor);
switch (figureName) {
case 'cumulative':
distribution.y = scaleToOne(cumulativeSum(distribution.y));
break;
default:
break;
}

// Compute the point statistics for the given data.
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability);

output[figureName] = {
distribution: distribution,
hdi: hdiData(data, distribution.x, distribution.y, hdiProbability),
stats: {x: stats.x, y: stats.y, text: stats.text},
labels: stats,
};
}
return output;
};

/**
* Callback used to update the Bokeh application in the notebook.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {string} rvName - The name of the random variable from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the kernel density estimate.
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @param {interfaces.Sources} sources - Bokeh sources used to render glyphs in the
* application.
* @param {interfaces.Figures} figures - Bokeh figures shown in the application.
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs.
* @returns {number} We display the value of the bandwidth used for computing the Kernel
* Density Estimate in a div, and must return that value here in order to update the
* value displayed to the user.
*/
export const update = (
data: number[],
rvName: string,
bwFactor: number,
hdiProbability: number,
sources: interfaces.Sources,
figures: interfaces.Figures,
tooltips: interfaces.Tooltips,
): number => {
const computedData = computeData(data, bwFactor, hdiProbability);
for (let i = 0; i < figureNames.length; i += 1) {
// Update all sources with new data calculated above.
const figureName = figureNames[i];
sources[figureName].distribution.data = {
x: computedData[figureName].distribution.x,
y: computedData[figureName].distribution.y,
};
sources[figureName].hdi.data = {
base: computedData[figureName].hdi.base,
lower: computedData[figureName].hdi.lower,
upper: computedData[figureName].hdi.upper,
};
sources[figureName].stats.data = computedData[figureName].stats;
sources[figureName].labels.data = computedData[figureName].labels;

// Update the axes labels.
updateAxisLabel(figures[figureName].below[0], rvName);

// Update the tooltips.
tooltips[figureName].stats.tooltips = [['', '@text']];
tooltips[figureName].distribution.tooltips = [[rvName, '@x']];
}
return computedData.marginal.distribution.bandwidth;
};
12 changes: 12 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import * as marginal1d from './callbacks';

// The CustomJS methods used by Bokeh require us to make the JavaScript available in the
// browser, which is done by defining it below.
(window as any).marginal1d = marginal1d;
Loading