diff --git a/changelog/21.docs.md b/changelog/21.docs.md new file mode 100644 index 0000000..d23de37 --- /dev/null +++ b/changelog/21.docs.md @@ -0,0 +1,2 @@ ++ add how-to-guide for cosine weight-decay in documentation ++ guide presents scenarios for both cases: `harmonisation-time < convergence-time` and vice versa. diff --git a/changelog/21.improvement.md b/changelog/21.improvement.md new file mode 100644 index 0000000..a81b9f8 --- /dev/null +++ b/changelog/21.improvement.md @@ -0,0 +1,3 @@ ++ add tests for cosine weight-decay that explicitly checks the function value and gradient value at the harmonisation time and convergence time ++ add test for cosine weight-decay that checks the case `harmonisation time > convergence time` ++ adapt CosineDecaySplineHelper to support the case `harmonisation-time > convergence-time` diff --git a/docs/NAVIGATION.md b/docs/NAVIGATION.md index b263140..ff07bc7 100644 --- a/docs/NAVIGATION.md +++ b/docs/NAVIGATION.md @@ -9,6 +9,7 @@ See https://oprypin.github.io/mkdocs-literate-nav/ - [How-to guides](how-to-guides/index.md) - [Do a basic calculation](how-to-guides/basic-calculation.md) - [Use a cubic spline for harmonisation](how-to-guides/cubic_spline.py) + - [Use a cosine-weight decay for harmonisation](how-to-guides/cosine_decay.py) - [Tutorials](tutorials/index.md) - [Getting Started](tutorials/tutorial.py) - [Further background](further-background/index.md) diff --git a/docs/how-to-guides/cosine_decay.py b/docs/how-to-guides/cosine_decay.py new file mode 100644 index 0000000..979768d --- /dev/null +++ b/docs/how-to-guides/cosine_decay.py @@ -0,0 +1,272 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.6 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # How to use cosine weight-decay as harmonisation of two functions? 
+# In this how-to guide, we present use cases for applying cosine-weight decay +# to harmonise two functions, which we refer to below as +# `diverge_from` and `harmonisee`. +# The `cosine-weight-decay` interpolates between `diverge_from` and `harmonisee`. + + +# %% +# import relevant libraries +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +import scipy.interpolate + +from gradient_aware_harmonisation.convergence import get_cosine_decay_harmonised_spline +from gradient_aware_harmonisation.spline import SplineScipy + +# %% [markdown] + +# We start by defining the spline `diverge_from` as a linear +# function with intercept=1.0 and slope=2.5. + +# %% +diverge_from_gradient = 2.5 +diverge_from_y_intercept = 1.0 + +diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[[diverge_from_gradient], [diverge_from_y_intercept]], + x=[0, 1e8], + ) +) + +# %% [markdown] +# ## Scenarios +# ### Harmonisation time < convergence time +# In the following, we consider nine scenarios in which the +# `harmonisee` spline differs from the `diverge_from` spline +# due to varying shifts in the intercept ([0.0, -1.2, 1.2]) +# and scaling factors for the slope ([1.0, 0.7, 1.4]). +# In all of these scenarios we consider harmonisation time +# (=0) < convergence time (=3.2). 
+ +# %% +harmonisation_time = 0.0 +convergence_time = 3.2 + + +# %% +def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 + ax.plot( + x, + spline(x), + label=label, + ) + + if gradient: + ax.set_title("Gradient") + else: + ax.set_title("Function") + + +# %% +i = 0 +for y_intercept_shift in [0.0, -1.2, 1.2]: + for gradient_factor in [1.0, 0.7, 1.4]: + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient * gradient_factor], + [diverge_from_y_intercept + y_intercept_shift], + ], + x=[0, 1e8], + ) + ) + + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + fig, axes = plt.subplots(ncols=2, figsize=(12, 4)) + + plot_spline( + diverge_from, np.linspace(-1.0, 3.0, 101), ax=axes[0], label="diverge_from" + ) + plot_spline( + harmonisee, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="harmonisee", + ) + plot_spline( + res, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="res", + ) + + plot_spline( + diverge_from.derivative(), + np.linspace(-1.0, 3.0, 101), + ax=axes[1], + label="diverge_from", + gradient=True, + ) + plot_spline( + harmonisee.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="harmonisee", + gradient=True, + ) + plot_spline( + res.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="cosine_weight_decay", + gradient=True, + ) + + for ax in axes: + ax.axvline( + harmonisation_time, + label="harmonisation_time", + color="gray", + linestyle=":", + ) + ax.axvline( + convergence_time, label="convergence_time", color="gray", linestyle="--" + ) + for ax in axes[1::2]: + ax.legend(handlelength=1.1, loc="center right", fontsize="small") + + fig.suptitle( + f"Scenario {i+1} (intercept shift: {y_intercept_shift}," + + f" slope factor: 
{gradient_factor})" + ) + plt.show() + i = i + 1 + +# %% [markdown] +# ### Harmonisation time > convergence time +# In the following, we consider the same nine scenarios as +# above in which the `harmonisee` spline differs +# from the `diverge_from` spline due to varying intercept +# shifts ([0.0, -1.2, 1.2]) and slope factors ([1.0, 0.7, 1.4]). +# However, in all of the following scenarios we now consider +# harmonisation time (=1.0) > convergence time (=-1.0). + +# %% +diverge_from_gradient = 2.5 +diverge_from_y_intercept = 1.0 + +# NOTE: PPoly is defined relative to its left edge (x[0] = -10), hence the offset +diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient], + [diverge_from_y_intercept - 10.0 * diverge_from_gradient], + ], + x=[-10.0, 10.0], + ) +) + +# %% +harmonisation_time = 1.0 +convergence_time = -1.0 + +# %% +# Backwards along x harmonisation +i = 0 +for y_intercept_shift in [0.0, -1.2, 1.2]: + for gradient_factor in [1.0, 0.7, 1.4]: + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient * gradient_factor], + [ + diverge_from_y_intercept + - 10.0 * diverge_from_gradient * gradient_factor + + y_intercept_shift + ], + ], + x=[-10.0, 10.0], + ) + ) + + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + fig, axes = plt.subplots(ncols=2, figsize=(12, 4)) + + plot_spline( + diverge_from, np.linspace(-1.0, 3.0, 101), ax=axes[0], label="diverge_from" + ) + plot_spline( + harmonisee, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="harmonisee", + ) + plot_spline( + res, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="res", + ) + + plot_spline( + diverge_from.derivative(), + np.linspace(-1.0, 3.0, 101), + ax=axes[1], + label="diverge_from", + gradient=True, + ) + plot_spline( + harmonisee.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 
101), + ax=axes[1], + label="harmonisee", + gradient=True, + ) + plot_spline( + res.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="cosine_weight_decay", + gradient=True, + ) + + for ax in axes: + ax.axvline( + harmonisation_time, + label="harmonisation_time", + color="gray", + linestyle=":", + ) + ax.axvline( + convergence_time, label="convergence_time", color="gray", linestyle="--" + ) + for ax in axes[1::2]: + ax.legend(handlelength=1.1, loc="center right", fontsize="small") + + fig.suptitle( + f"Scenario {i+1} (intercept shift: {y_intercept_shift}," + + f" slope factor: {gradient_factor})" + ) + plt.show() + i = i + 1 diff --git a/docs/how-to-guides/index.md b/docs/how-to-guides/index.md index b0e9469..e76823c 100644 --- a/docs/how-to-guides/index.md +++ b/docs/how-to-guides/index.md @@ -5,3 +5,4 @@ focuses on a **problem-oriented** approach. We'll go over how to solve common tasks. + [How can I use a cubic spline for harmonisation?](cubic_spline) ++ [How can I use a cosine-weight-decay for harmonisation?](cosine_decay) diff --git a/src/gradient_aware_harmonisation/convergence.py b/src/gradient_aware_harmonisation/convergence.py index 611a50b..37dc819 100644 --- a/src/gradient_aware_harmonisation/convergence.py +++ b/src/gradient_aware_harmonisation/convergence.py @@ -84,7 +84,9 @@ def calc_gamma( """Get cosine-decay derivative""" # compute weight (here: gamma) according to a cosine-decay angle = ( - np.pi * (x - self.initial_time) / (self.final_time - self.initial_time) + np.pi + * (x - self.initial_time) + / abs(self.final_time - self.initial_time) ) gamma_decaying = 0.5 * (1 + np.cos(angle)) @@ -92,9 +94,16 @@ def calc_gamma( return gamma_decaying if not isinstance(x, np.ndarray): - if x <= self.initial_time: + if self.initial_time <= self.final_time: + if x <= self.initial_time: + gamma: float | NP_FLOAT_OR_INT | NP_ARRAY_OF_FLOAT_OR_INT = 1.0 + elif x >= self.final_time: + gamma = 0.0 + else: + gamma = 
calc_gamma(x) + elif x >= self.initial_time: gamma: float | NP_FLOAT_OR_INT | NP_ARRAY_OF_FLOAT_OR_INT = 1.0 - elif x >= self.final_time: + elif x <= self.final_time: gamma = 0.0 else: gamma = calc_gamma(x) @@ -109,10 +118,14 @@ def calc_gamma( return gamma - # apply decay function only to values that lie between harmonisation - # time and convergence-time - x_gte_final_time = np.where(x >= self.final_time) - x_decay = np.logical_and(x >= self.initial_time, x < self.final_time) + # apply decay function only to values that lie between + # harmonisation time and convergence-time + if self.initial_time <= self.final_time: + x_gte_final_time = np.where(x >= self.final_time) + x_decay = np.logical_and(x >= self.initial_time, x < self.final_time) + else: + x_gte_final_time = np.where(x <= self.final_time) + x_decay = np.logical_and(x <= self.initial_time, x > self.final_time) gamma = np.ones_like(x, dtype=np.floating) gamma[x_gte_final_time] = 0.0 gamma[x_decay] = calc_gamma(x[x_decay]) @@ -210,14 +223,19 @@ def calc_gamma_rising_derivative( """Get cosine-decay derivative""" # compute derivative of gamma according to a cosine-decay angle = ( - np.pi * (x - self.initial_time) / (self.final_time - self.initial_time) + np.pi + * (x - self.initial_time) + / abs(self.final_time - self.initial_time) ) gamma_decaying_derivative = -0.5 * np.sin(angle) return gamma_decaying_derivative if not isinstance(x, np.ndarray): - if x <= self.initial_time or x >= self.final_time: + if self.initial_time <= self.final_time: + if x <= self.initial_time or x >= self.final_time: + return 0.0 + elif x >= self.initial_time or x <= self.final_time: return 0.0 gamma_rising_derivative = calc_gamma_rising_derivative(x) @@ -234,7 +252,15 @@ def calc_gamma_rising_derivative( # apply decay function only to values that lie between harmonisation # time and convergence-time - x_decay = np.where(np.logical_and(x > self.initial_time, x < self.final_time)) + if self.initial_time <= self.final_time: + x_decay 
= np.where( + np.logical_and(x > self.initial_time, x < self.final_time) + ) + else: + x_decay = np.where( + np.logical_and(x < self.initial_time, x > self.final_time) + ) + gamma_rising_derivative = np.zeros_like(x, dtype=np.floating) gamma_rising_derivative[x_decay] = calc_gamma_rising_derivative(x[x_decay]) @@ -272,8 +298,8 @@ def antiderivative(self) -> CosineDecaySplineHelperDerivative: def get_cosine_decay_harmonised_spline( harmonisation_time: Union[int, float], convergence_time: Union[int, float], - harmonised_spline_no_convergence: Spline, - convergence_spline: Spline, + diverge_from: Spline, + harmonisee: Spline, ) -> SumOfSplines: """ Generate the harmonised spline based on a cosine-decay @@ -284,18 +310,18 @@ def get_cosine_decay_harmonised_spline( Harmonisation time This is the time at and before which - the solution should be equal to `harmonised_spline_no_convergence`. + the solution should be equal to `diverge_from`. convergence_time Convergence time This is the time at and after which - the solution should be equal to `convergence_spline`. + the solution should be equal to `harmonisee`. - harmonised_spline_no_convergence + diverge_from Harmonised spline that does not consider convergence - convergence_spline + harmonisee The spline to which the result should converge Returns @@ -308,9 +334,10 @@ def get_cosine_decay_harmonised_spline( # first order derivative). Then we use a decay function to let # the harmonised spline converge to the convergence-spline. # This decay function has the form of a weighted sum: - # weight * harmonised_spline + (1-weight) * convergence_spline + # weight * diverge_from + (1-weight) * harmonisee # With weights decaying from 1 to 0 whereby the decay trajectory # is determined by the cosine decay. 
+ return SumOfSplines( ProductOfSplines( CosineDecaySplineHelper( @@ -318,7 +345,7 @@ def get_cosine_decay_harmonised_spline( final_time=convergence_time, apply_to_convergence=False, ), - harmonised_spline_no_convergence, + diverge_from, ), ProductOfSplines( CosineDecaySplineHelper( @@ -326,6 +353,6 @@ def get_cosine_decay_harmonised_spline( final_time=convergence_time, apply_to_convergence=True, ), - convergence_spline, + harmonisee, ), ) diff --git a/src/gradient_aware_harmonisation/utils.py b/src/gradient_aware_harmonisation/utils.py index c33eca9..bce6805 100644 --- a/src/gradient_aware_harmonisation/utils.py +++ b/src/gradient_aware_harmonisation/utils.py @@ -33,8 +33,8 @@ def __call__( self, harmonisation_time: Union[int, float], convergence_time: Union[int, float], - harmonised_spline_no_convergence: Spline, - convergence_spline: Spline, + diverge_from: Spline, + harmonisee: Spline, ) -> Spline: """ Generate the harmonised spline @@ -45,18 +45,18 @@ def __call__( Harmonisation time This is the time at and before which - the solution should be equal to `harmonised_spline_no_convergence`. + the solution should be equal to `diverge_from`. convergence_time Convergence time This is the time at and after which - the solution should be equal to `convergence_spline`. + the solution should be equal to `harmonisee`. 
- harmonised_spline_no_convergence + diverge_from Harmonised spline that does not consider convergence - convergence_spline + harmonisee The spline to which the result should converge Returns @@ -133,15 +133,15 @@ def harmonise_splines( # noqa: PLR0913 harmonised_spline_first_derivative_only(harmonisation_time), ) - harmonised_spline_no_convergence = add_constant_to_spline( + diverge_from = add_constant_to_spline( in_spline=harmonised_spline_first_derivative_only, constant=diff_spline ) harmonised_spline = get_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, - convergence_spline=converge_to, + diverge_from=diverge_from, + harmonisee=converge_to, ) return harmonised_spline diff --git a/tests/integration/test_convergence_integration.py b/tests/integration/test_convergence_integration.py index 9aabb09..a41024b 100644 --- a/tests/integration/test_convergence_integration.py +++ b/tests/integration/test_convergence_integration.py @@ -184,13 +184,13 @@ def test_get_cosine_decay_harmonised_spline(): x_up_to_harmonisation_time = np.linspace(x_min, harmonisation_time, 50) x_after_convergence_time = np.linspace(convergence_time, x_max, 50) - harmonised_spline_no_convergence = SplineScipy( + diverge_from = SplineScipy( scipy.interpolate.PPoly( x=[x_min, x_max], c=[[1], [0], [0], [0]], # y=x^3 ) ) - convergence_spline = SplineScipy( + harmonisee = SplineScipy( scipy.interpolate.PPoly( x=[x_min, x_max], c=[[-1], [1], [2]], # y=-x^2 + x + 2 @@ -200,17 +200,17 @@ def test_get_cosine_decay_harmonised_spline(): res = get_cosine_decay_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, - convergence_spline=convergence_spline, + diverge_from=diverge_from, + harmonisee=harmonisee, ) np.testing.assert_equal( - 
harmonised_spline_no_convergence(x_up_to_harmonisation_time), + diverge_from(x_up_to_harmonisation_time), res(x_up_to_harmonisation_time), ) np.testing.assert_equal( - convergence_spline(x_after_convergence_time), + harmonisee(x_after_convergence_time), res(x_after_convergence_time), ) @@ -218,18 +218,11 @@ def test_get_cosine_decay_harmonised_spline(): np.testing.assert_equal( np.array( [ - 0.5 - * (1.0 + np.cos(np.pi * 0.5 / 6.0)) - * harmonised_spline_no_convergence(3.0) - + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0))) - * convergence_spline(3.0), - 0.5 * harmonised_spline_no_convergence(5.5) - + 0.5 * convergence_spline(5.5), - 0.5 - * (1.0 + np.cos(np.pi * 3.5 / 6.0)) - * harmonised_spline_no_convergence(6.0) - + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0))) - * convergence_spline(6.0), + 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0)) * diverge_from(3.0) + + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0))) * harmonisee(3.0), + 0.5 * diverge_from(5.5) + 0.5 * harmonisee(5.5), + 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0)) * diverge_from(6.0) + + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0))) * harmonisee(6.0), ] ), res(np.array([3.0, 5.5, 6.0])), diff --git a/tests/integration/test_harmonise_splines_cosine_weight_decay.py b/tests/integration/test_harmonise_splines_cosine_weight_decay.py index e69de29..1e5760a 100644 --- a/tests/integration/test_harmonise_splines_cosine_weight_decay.py +++ b/tests/integration/test_harmonise_splines_cosine_weight_decay.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from gradient_aware_harmonisation.convergence import get_cosine_decay_harmonised_spline +from gradient_aware_harmonisation.spline import Spline, SplineScipy + + +def check_expected_continuity( + solution: Spline, + diverge_from: Spline, + harmonisee: Spline, + harmonisation_time: float, + convergence_time: float, +) -> None: + np.testing.assert_allclose( + solution(harmonisation_time), + diverge_from(harmonisation_time), + err_msg=( + "Difference in absolute value 
of solution and diverge_from " + "at harmonisation_time" + ), + ) + + +@pytest.mark.parametrize( + "harmonisation_time, convergence_time", + ( + pytest.param(0.0, 1.0), + pytest.param(0.0, 1.7), + pytest.param(3.0, 8.0), + pytest.param(-3.0, 0.0), + pytest.param(-3.0, 8.0), + pytest.param(-3.0, -1.0), + pytest.param(3.0, 1.0, id="backwards_harmonisation_positive_times"), + pytest.param( + 3.0, -1.0, id="backwards_harmonisation_positive_and_negative_time" + ), + pytest.param(-10.0, -30.0, id="backwards_harmonisation_negative_times"), + ), +) +def test_harmonisation_convergence_times(harmonisation_time, convergence_time): + """ + Test over a variety of harmonisation and convergence times + """ + scipy = pytest.importorskip("scipy") + + diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[[2.75], [1.2]], + x=[-100, 100], + ) + ) + + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[[2.3], [0.5]], + x=[-100, 100], + ) + ) + + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + check_expected_continuity( + solution=res, + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + +def test_harmonisation_time_greater_than_convergence_time(): + scipy = pytest.importorskip("scipy") + + harmonisation_time = 1.0 + convergence_time = -1.0 + + # y = x + # TODO: from left-edge or something here + diverge_from = SplineScipy( + scipy.interpolate.PPoly( + # These are the constants you need given how PPoly is defined + # (it's basically y = f(x - x_le), + # where x_le is the left-edge of the boundary) + c=[[1.0], [-10.0]], + x=[-10.0, 10.0], + ) + ) + assert diverge_from(harmonisation_time) == 1.0 + + # y = 0.5x - 1 + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[[0.5], [-6.0]], + x=[-10.0, 10.0], + ) + ) + assert harmonisee(convergence_time) == -1.5 
+ + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + check_expected_continuity( + solution=res, + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) diff --git a/tests/unit/test_get_cosine_decay_spline.py b/tests/unit/test_get_cosine_decay_spline.py index 3a70570..1cc3ca1 100644 --- a/tests/unit/test_get_cosine_decay_spline.py +++ b/tests/unit/test_get_cosine_decay_spline.py @@ -54,15 +54,15 @@ def test_get_cosine_decay(harmonisation_time, convergence_time): harmonised_spline_first_derivative_only(harmonisation_time), ) - harmonised_spline_no_convergence = add_constant_to_spline( + diverge_from = add_constant_to_spline( in_spline=harmonised_spline_first_derivative_only, constant=diff_spline ) harmonised_spline_convergence = get_cosine_decay_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, - convergence_spline=harmonisee_spline, + diverge_from=diverge_from, + harmonisee=harmonisee_spline, ) np.testing.assert_allclose(