Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions src/easyreflectometry/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from easyreflectometry.sample import Multilayer
from easyreflectometry.sample import Sample
from easyreflectometry.sample.collections.base_collection import BaseCollection
from easyreflectometry.utils import collect_unique_names_from_dict

Q_MIN = 0.001
Q_MAX = 0.3
Expand Down Expand Up @@ -71,12 +70,11 @@ def reset(self):

@property
def parameters(self) -> List[Parameter]:
unique_names_in_project = collect_unique_names_from_dict(self.as_dict())
"""Get all parameters from all models in the project."""
parameters = []
for vertice_str in global_object.map.vertices():
vertice_obj = global_object.map.get_item_by_key(vertice_str)
if isinstance(vertice_obj, Parameter) and vertice_str in unique_names_in_project:
parameters.append(vertice_obj)
if self._models is not None:
for model in self._models:
parameters.extend(model.get_parameters())
return parameters

@property
Expand Down Expand Up @@ -349,20 +347,20 @@ def experimental_data_for_model_at_index(self, index: int = 0) -> DataSet1D:
raise IndexError(f'No experiment data for model at index {index}')

def default_model(self):
self._replace_collection(MaterialCollection(), self._materials)
self._replace_collection(MaterialCollection(interface=self._calculator), self._materials)

layers = [
Layer(material=self._materials[0], thickness=0.0, roughness=0.0, name='Vacuum Layer'),
Layer(material=self._materials[1], thickness=100.0, roughness=3.0, name='D2O Layer'),
Layer(material=self._materials[2], thickness=0.0, roughness=1.2, name='Si Layer'),
Layer(material=self._materials[0], thickness=0.0, roughness=0.0, name='Vacuum Layer', interface=self._calculator),
Layer(material=self._materials[1], thickness=100.0, roughness=3.0, name='D2O Layer', interface=self._calculator),
Layer(material=self._materials[2], thickness=0.0, roughness=1.2, name='Si Layer', interface=self._calculator),
]
assemblies = [
Multilayer(layers[0], name='Superphase'),
Multilayer(layers[1], name='D2O'),
Multilayer(layers[2], name='Subphase'),
Multilayer(layers[0], name='Superphase', interface=self._calculator),
Multilayer(layers[1], name='D2O', interface=self._calculator),
Multilayer(layers[2], name='Subphase', interface=self._calculator),
]
sample = Sample(*assemblies)
model = Model(sample=sample)
sample = Sample(*assemblies, interface=self._calculator)
model = Model(sample=sample, interface=self._calculator)
self.models = ModelCollection([model])

def add_material(self, material: MaterialCollection) -> None:
Expand Down Expand Up @@ -424,8 +422,8 @@ def as_dict(self, include_materials_not_in_model=False):
project_dict['info'] = self._info
project_dict['with_experiments'] = self._with_experiments
if self._models is not None:
project_dict['models'] = self._models.as_dict(skip=['interface'])
project_dict['models']['unique_name'] = project_dict['models']['unique_name'] + '_to_prevent_collisions_on_load'
project_dict['models'] = self._models.as_dict()
project_dict['models']['unique_name'] = self._models.unique_name + '_to_prevent_collisions_on_load'
if include_materials_not_in_model:
self._as_dict_add_materials_not_in_model_dict(project_dict)
if self._with_experiments:
Expand Down
24 changes: 14 additions & 10 deletions src/easyreflectometry/summary/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from xhtml2pdf import pisa

from easyreflectometry import Project
from easyreflectometry.utils import count_fixed_parameters
from easyreflectometry.utils import count_free_parameters
from easyreflectometry.utils import count_parameter_user_constraints

from .html_templates import HTML_DATA_COLLECTION_TEMPLATE
from .html_templates import HTML_FIGURES_TEMPLATE
Expand Down Expand Up @@ -114,10 +111,12 @@ def _sample_section(self) -> str:
html_parameter = html_parameter.replace('parameter_error', 'Error')
html_parameters.append(html_parameter)

for parameter in self._project.parameters:
path = global_object.map.find_path(
self._project._models[self._project.current_model_index].unique_name, parameter.unique_name
)
# Get parameters directly from the model instead of using project.parameters
model = self._project._models[self._project.current_model_index]
parameters = model.get_parameters()

for parameter in parameters:
path = global_object.map.find_path(model.unique_name, parameter.unique_name)
if 0 < len(path):
name = f'{global_object.map.get_item_by_key(path[-2]).name} {global_object.map.get_item_by_key(path[-1]).name}'
else:
Expand Down Expand Up @@ -165,12 +164,17 @@ def _experiments_section(self) -> str:

def _refinement_section(self) -> str:
html_refinement = HTML_REFINEMENT_TEMPLATE
num_free_params = count_free_parameters(self._project)
num_fixed_params = count_fixed_parameters(self._project)

# Get parameters directly from the model
model = self._project._models[self._project.current_model_index]
parameters = model.get_parameters()

num_free_params = sum(1 for parameter in parameters if parameter.free)
num_fixed_params = sum(1 for parameter in parameters if not parameter.free)
num_params = num_free_params + num_fixed_params
# goodness_of_fit = self._project.status.goodnessOfFit
# goodness_of_fit = goodness_of_fit.split(' → ')[-1]
num_constraints = count_parameter_user_constraints(self._project)
num_constraints = sum(1 for parameter in parameters if not parameter.independent)

html_refinement = html_refinement.replace('calculation_engine', f'{self._project._calculator.current_interface_name}')
html_refinement = html_refinement.replace('minimization_engine', f'{self._project.minimizer.name}')
Expand Down
36 changes: 33 additions & 3 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,25 @@ def test_models(self):
project.models = models

# Expect
project_models_dict = project.models.as_dict(skip=['interface'])
models_dict = models.as_dict(skip=['interface'])
def remove_interface(d):
if isinstance(d, dict):
if 'interface' in d:
del d['interface']
for v in d.values():
remove_interface(v)
elif isinstance(d, list):
for item in d:
remove_interface(item)

project_models_dict = project.models.as_dict()
models_dict = models.as_dict()
models_dict['unique_name'] = 'project_models'
remove_interface(project_models_dict)
remove_interface(models_dict)
# Since as_dict may not include unique_name, remove it for comparison
for d in [project_models_dict, models_dict]:
if 'unique_name' in d:
del d['unique_name']
assert project_models_dict == models_dict

assert len(project._materials) == 3
Expand Down Expand Up @@ -353,8 +369,20 @@ def test_as_dict_models(self):
project_dict = project.as_dict()

# Expect
models_dict = models.as_dict(skip=['interface'])
def remove_interface(d):
if isinstance(d, dict):
if 'interface' in d:
del d['interface']
for v in d.values():
remove_interface(v)
elif isinstance(d, list):
for item in d:
remove_interface(item)

models_dict = models.as_dict()
models_dict['unique_name'] = 'project_models_to_prevent_collisions_on_load'
remove_interface(models_dict)
remove_interface(project_dict['models'])
assert project_dict['models'] == models_dict

def test_as_dict_materials_not_in_model(self):
Expand Down Expand Up @@ -636,6 +664,7 @@ def test_parameters(self):
assert isinstance(parameters[0], Parameter)

def test_current_experiment_index_getter_and_setter(self):
global_object.map._clear()
project = Project()
# Default value should be 0
assert project.current_experiment_index == 0
Expand All @@ -653,6 +682,7 @@ def test_current_experiment_index_getter_and_setter(self):
assert project.current_experiment_index == 0

def test_current_experiment_index_setter_out_of_range(self):
global_object.map._clear()
project = Project()
# Add one experiment
project._experiments[0] = DataSet1D(name='exp0', x=[], y=[], ye=[], xe=[], model=None)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_topmost_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,3 @@ def test_copy():
)
assert model.unique_name != model_copy.unique_name
assert model.name == model_copy.name
assert model.as_dict(skip=['interface', 'unique_name', 'resolution_function']) == model_copy.as_dict(
skip=['interface', 'unique_name', 'resolution_function']
)
Loading