|
18 | 18 |
|
19 | 19 | import matplotlib |
20 | 20 | import matplotlib.pyplot as plt |
| 21 | +import numpy as np |
21 | 22 | import pytest |
22 | 23 | from numpy import array_equal, dot, linspace, pi, sin |
23 | 24 | from scipy.optimize import leastsq |
24 | 25 |
|
| 26 | +from diffpy.srfit.fitbase import FitResults |
25 | 27 | from diffpy.srfit.fitbase.fitcontribution import FitContribution |
26 | 28 | from diffpy.srfit.fitbase.fitrecipe import FitRecipe |
27 | 29 | from diffpy.srfit.fitbase.parameter import Parameter |
@@ -462,6 +464,76 @@ def optimize_recipe(recipe): |
462 | 464 | leastsq(residuals, values) |
463 | 465 |
|
464 | 466 |
|
| 467 | +def test_initialize_recipe_from_recipe(build_recipe_two_contributions): |
| 468 | + # Case: User initializes a FitRecipe from a previously optimized fit |
| 469 | + # expected: recipe is initialized with everything: |
| 470 | + # contributions, profiles (contained in contributions), |
| 471 | + # variables, restraints, and constraints |
| 472 | + recipe1 = build_recipe_two_contributions |
| 473 | + optimize_recipe(recipe1) |
| 474 | + expected_parameters_dict = recipe1._parameters |
| 475 | + expected_constraints_dict = recipe1._constraints |
| 476 | + expected_restraints_set = recipe1._restraints |
| 477 | + expected_contributions_dict = recipe1._contributions |
| 478 | + expected_profiles_list = [] |
| 479 | + for con_name, contribution in expected_contributions_dict.items(): |
| 480 | + expected_profile = contribution.profile |
| 481 | + expected_profiles_list.append(expected_profile) |
| 482 | + |
| 483 | + recipe2 = FitRecipe() |
| 484 | + recipe2.initialize_recipe_with_recipe(recipe1) |
| 485 | + actual_parameters_dict = recipe2._parameters |
| 486 | + actual_constraints_dict = recipe2._constraints |
| 487 | + actual_restraints_set = recipe2._restraints |
| 488 | + actual_contributions_dict = recipe2._contributions |
| 489 | + actual_profiles_list = [] |
| 490 | + for con_name, contribution in actual_contributions_dict.items(): |
| 491 | + actual_profile = contribution.profile |
| 492 | + actual_profiles_list.append(actual_profile) |
| 493 | + |
| 494 | + assert expected_parameters_dict == actual_parameters_dict |
| 495 | + assert expected_constraints_dict == actual_constraints_dict |
| 496 | + assert expected_restraints_set == actual_restraints_set |
| 497 | + assert expected_contributions_dict == actual_contributions_dict |
| 498 | + assert expected_profiles_list == actual_profiles_list |
| 499 | + |
| 500 | + # Check to see if the refined values and variable names are |
| 501 | + # the same in the results objects for each recipe |
| 502 | + results1 = FitResults(recipe1) |
| 503 | + # round to account for small numerical differences |
| 504 | + expected_values = np.round(results1.varvals, 7) |
| 505 | + expected_names = results1.varnames |
| 506 | + |
| 507 | + optimize_recipe(recipe2) |
| 508 | + results2 = FitResults(recipe2) |
| 509 | + # round to account for small numerical differences |
| 510 | + actual_values = np.round(results2.varvals, 7) |
| 511 | + actual_names = results2.varnames |
| 512 | + |
| 513 | + assert sorted(expected_names) == sorted(actual_names) |
| 514 | + assert sorted(list(expected_values)) == sorted(list(actual_values)) |
| 515 | + |
| 516 | + |
| 517 | +def test_initialize_recipe_from_recipe_bad(build_recipe_two_contributions): |
| 518 | + # Case: User tries to initialize a FitRecipe from a non recipe object |
| 519 | + # expected: raised ValueError with message |
| 520 | + recipe_bad = 12345 # not a FitRecipe object |
| 521 | + recipe2 = FitRecipe() |
| 522 | + msg = ( |
| 523 | + "The input recipe_object must be a FitRecipe, " |
| 524 | + "but got <class 'int'>." |
| 525 | + ) |
| 526 | + with pytest.raises(ValueError, match=msg): |
| 527 | + recipe2.initialize_recipe_with_recipe(recipe_bad) |
| 528 | + |
| 529 | + |
| 530 | +# def test_initialize_recipe_from_results(build_recipe_one_contribution): |
| 531 | +# # Case: User initializes a FitRecipe from a FitResults object or |
| 532 | +# # results file |
| 533 | +# # expected: recipe is initialized with variables from previous fit |
| 534 | +# assert False |
| 535 | + |
| 536 | + |
465 | 537 | def get_labels_and_linecount(ax): |
466 | 538 | """Helper to get line labels and count from a matplotlib Axes.""" |
467 | 539 | labels = [ |
|
0 commit comments