diff --git a/AUTHORS b/AUTHORS index 26a9fbb..a921893 100644 --- a/AUTHORS +++ b/AUTHORS @@ -27,3 +27,4 @@ Contributors: * Zhiyi Wu * Olivier Languin-Cattoën * Andrés Montoya (logo) +* Pradyumn Prasad diff --git a/CHANGELOG b/CHANGELOG index 7741a9e..88775c5 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -13,7 +13,8 @@ The rules for this file: * accompany each entry with github issue/PR number (Issue #xyz) ------------------------------------------------------------------------------ -??/??/???? IAlibay, ollyfutur, conradolandia, orbeckst + +??/??/???? IAlibay, ollyfutur, conradolandia, orbeckst, Pradyumn-cloud * 1.1.0 Changes @@ -28,6 +29,7 @@ The rules for this file: * `Grid` now accepts binary operations with any operand that can be broadcasted to the grid's shape according to `numpy` broadcasting rules (PR #142) + * Added MRC file writing support (Issue #108) Fixes diff --git a/gridData/core.py b/gridData/core.py index 3395d62..c8f0d97 100644 --- a/gridData/core.py +++ b/gridData/core.py @@ -203,6 +203,7 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None, 'PKL': self._export_python, 'PICKLE': self._export_python, # compatibility 'PYTHON': self._export_python, # compatibility + 'MRC': self._export_mrc, } self._loaders = { 'CCP4': self._load_mrc, @@ -590,6 +591,8 @@ def export(self, filename, file_format=None, type=None, typequote='"'): dx :mod:`OpenDX` + mrc + :mod:`mrc` MRC/CCP4 format pickle pickle (use :meth:`Grid.load` to restore); :meth:`Grid.save` is simpler than ``export(format='python')``. @@ -599,7 +602,7 @@ def export(self, filename, file_format=None, type=None, typequote='"'): filename : str name of the output file - file_format : {'dx', 'pickle', None} (optional) + file_format : {'dx', 'pickle', 'mrc', None} (optional) output file format, the default is "dx" type : str (optional) @@ -676,6 +679,41 @@ def _export_dx(self, filename, type=None, typequote='"', **kwargs): if ext == '.gz': filename = root + ext dx.write(filename) + + def _export_mrc(self, filename, **kwargs): + """Export the density grid to an MRC/CCP4 file. + + The MRC2014 file format is used via the mrcfile library. + + Parameters + ---------- + filename : str + Output filename + **kwargs + Additional keyword arguments (currently ignored) + + Notes + ----- + * Only orthorhombic unit cells are supported + * If the Grid was loaded from an MRC file, the original header + information (including axis ordering) is preserved + * For new grids, standard ordering (mapc=1, mapr=2, maps=3) is used + + .. versionadded:: 1.1.0 + """ + # Create MRC object and populate with Grid data + mrc_file = mrc.MRC() + mrc_file.array = self.grid + mrc_file.delta = numpy.diag(self.delta) + mrc_file.origin = self.origin + mrc_file.rank = 3 + + # Transfer header if it exists (preserves axis ordering and other metadata) + if hasattr(self, '_mrc_header'): + mrc_file.header = self._mrc_header + + # Write to file + mrc_file.write(filename) def save(self, filename): """Save a grid object to `filename` and add ".pickle" extension. diff --git a/gridData/mrc.py b/gridData/mrc.py index 329f50a..b3c9bf9 100644 --- a/gridData/mrc.py +++ b/gridData/mrc.py @@ -79,7 +79,7 @@ class MRC(object): ----- * Only volumetric (3D) densities are read. * Only orthorhombic unitcells supported (other raise :exc:`ValueError`) - * Only reading is currently supported. + * Reading and writing are supported. .. versionadded:: 0.7.0 @@ -148,5 +148,106 @@ def histogramdd(self): """Return array data as (edges,grid), i.e. a numpy nD histogram.""" return (self.array, self.edges) - - + def write(self, filename): + """Write grid data to MRC/CCP4 file format. + + Parameters + ---------- + filename : str + Output filename for the MRC file + + Notes + ----- + The data array should be in xyz order (axis 0=X, axis 1=Y, axis 2=Z). + + If the MRC object was created by reading an existing file, the original + header information (including mapc, mapr, maps ordering) is preserved. + Otherwise, standard ordering (mapc=1, mapr=2, maps=3) is used. + + .. versionadded:: 1.1.0 + """ + if filename is not None: + self.filename = filename + + # Preserve header if it exists, otherwise use defaults + if hasattr(self, 'header'): + # File was read - preserve original ordering + h = self.header + axes_order = np.hstack([h.mapc, h.mapr, h.maps]) + mapc, mapr, maps = int(h.mapc), int(h.mapr), int(h.maps) + else: + # New file - use standard ordering + axes_order = np.array([1, 2, 3]) + mapc, mapr, maps = 1, 2, 3 + h = None + + # Reverse the transformation done in read() + transpose_order = np.argsort(axes_order[::-1]) + inverse_transpose_order = np.argsort(transpose_order) + + # Transform our xyz array back to the file's native ordering + data_for_file = np.transpose(self.array, axes=inverse_transpose_order) + + # Ensure proper data type (float32 is standard for mode 2) + data_for_file = data_for_file.astype(np.float32) + + # Create new MRC file + with mrcfile.new(filename, overwrite=True) as mrc: + mrc.set_data(data_for_file) + + # Set voxel size from delta (diagonal elements) + voxel_size = np.diag(self.delta).astype(np.float32) + mrc.voxel_size = tuple(voxel_size) + + # Set map ordering + mrc.header.mapc = mapc + mrc.header.mapr = mapr + mrc.header.maps = maps + + # Handle nstart and origin + if h is not None: + # Preserve original header values + nxstart = int(h.nxstart) + nystart = int(h.nystart) + nzstart = int(h.nzstart) + header_origin_xyz = np.array([h.origin.x, h.origin.y, h.origin.z], dtype=np.float32) + + mrc.header.mx = int(h.mx) + mrc.header.my = int(h.my) + mrc.header.mz = int(h.mz) + + # Preserve cell dimensions + if hasattr(h, 'cella'): + mrc.header.cella.x = float(h.cella.x) + mrc.header.cella.y = float(h.cella.y) + mrc.header.cella.z = float(h.cella.z) + if hasattr(h, 'cellb'): + mrc.header.cellb.alpha = float(h.cellb.alpha) + mrc.header.cellb.beta = float(h.cellb.beta) + mrc.header.cellb.gamma = float(h.cellb.gamma) + # Copy space group if available + if hasattr(h, 'ispg'): + mrc.header.ispg = int(h.ispg) + else: + # For new files, calculate nstart from origin + if np.any(voxel_size <= 0): + raise ValueError(f"Voxel size must be positive, got {voxel_size}") + + # Set header.origin = 0 and encode everything in nstart + header_origin_xyz = np.zeros(3, dtype=np.float32) + nxstart = int(np.round(self.origin[0] / voxel_size[0])) + nystart = int(np.round(self.origin[1] / voxel_size[1])) + nzstart = int(np.round(self.origin[2] / voxel_size[2])) + + # Set the start positions + mrc.header.nxstart = nxstart + mrc.header.nystart = nystart + mrc.header.nzstart = nzstart + + # Set explicit origin + mrc.header.origin.x = float(header_origin_xyz[0]) + mrc.header.origin.y = float(header_origin_xyz[1]) + mrc.header.origin.z = float(header_origin_xyz[2]) + + # Update statistics only + mrc.update_header_stats() diff --git a/gridData/tests/test_mrc.py b/gridData/tests/test_mrc.py index 93caf21..8111986 100644 --- a/gridData/tests/test_mrc.py +++ b/gridData/tests/test_mrc.py @@ -133,3 +133,170 @@ def test_origin(self, grid, ccp4data): def test_data(self, grid, ccp4data): assert_allclose(grid.grid, ccp4data.array) + +class TestMRCWrite: + """Tests for MRC write functionality""" + + def test_mrc_write_roundtrip(self, ccp4data, tmpdir): + """Test writing and reading back preserves data""" + outfile = str(tmpdir / "roundtrip.mrc") + + # Write the file + ccp4data.write(outfile) + + # Read it back + mrc_read = mrc.MRC(outfile) + + # Check data matches + assert_allclose(mrc_read.array, ccp4data.array) + assert_allclose(mrc_read.origin, ccp4data.origin, rtol=1e-5, atol=1e-3) + assert_allclose(mrc_read.delta, ccp4data.delta, rtol=1e-5) + + def test_mrc_write_header_preserved(self, ccp4data, tmpdir): + """Test that header fields are preserved""" + outfile = str(tmpdir / "header.mrc") + + ccp4data.write(outfile) + mrc_read = mrc.MRC(outfile) + + # Check axis ordering preserved + assert mrc_read.header.mapc == ccp4data.header.mapc + assert mrc_read.header.mapr == ccp4data.header.mapr + assert mrc_read.header.maps == ccp4data.header.maps + + # Check offsets preserved + assert mrc_read.header.nxstart == ccp4data.header.nxstart + assert mrc_read.header.nystart == ccp4data.header.nystart + assert mrc_read.header.nzstart == ccp4data.header.nzstart + + def test_mrc_write_new_file(self, tmpdir): + """Test creating new MRC file from scratch""" + outfile = str(tmpdir / "new.mrc") + + # Create new MRC object + mrc_new = mrc.MRC() + mrc_new.array = np.arange(24).reshape(2, 3, 4).astype(np.float32) + mrc_new.delta = np.diag([1.0, 2.0, 3.0]) + mrc_new.origin = np.array([5.0, 10.0, 15.0]) + mrc_new.rank = 3 + + # Write and read back + mrc_new.write(outfile) + mrc_read = mrc.MRC(outfile) + + # Verify + assert_allclose(mrc_read.array, mrc_new.array, rtol=1e-5) + assert_allclose(mrc_read.origin, mrc_new.origin, rtol=1e-4) + assert_allclose(np.diag(mrc_read.delta), np.diag(mrc_new.delta), rtol=1e-5) + + def test_mrc_write_zero_voxel_raises(self, tmpdir): + """Test that zero voxel size raises ValueError""" + outfile = str(tmpdir / "invalid.mrc") + + mrc_obj = mrc.MRC() + mrc_obj.array = np.ones((2, 2, 2), dtype=np.float32) + mrc_obj.delta = np.diag([0.0, 1.0, 1.0]) + mrc_obj.origin = np.array([0.0, 0.0, 0.0]) + mrc_obj.rank = 3 + + with pytest.raises(ValueError, match="Voxel size must be positive"): + mrc_obj.write(outfile) + + +class TestGridMRCWrite: + """Tests for Grid.export() with MRC format""" + + def test_grid_export_mrc(self, tmpdir): + """Test Grid.export() with file_format='mrc'""" + outfile = str(tmpdir / "grid.mrc") + + # Create simple grid + data = np.arange(60).reshape(3, 4, 5).astype(np.float32) + g = Grid(data, origin=[0, 0, 0], delta=[1.0, 1.0, 1.0]) + + # Export and read back + g.export(outfile, file_format='mrc') + g_read = Grid(outfile) + + # Verify + assert_allclose(g_read.grid, g.grid, rtol=1e-5) + assert_allclose(g_read.origin, g.origin, rtol=1e-4) + assert_allclose(g_read.delta, g.delta, rtol=1e-5) + + def test_grid_export_mrc_roundtrip(self, tmpdir): + """Test MRC → Grid → export → Grid preserves data""" + outfile = str(tmpdir / "roundtrip_grid.mrc") + + # Load original + g_orig = Grid(datafiles.CCP4_1JZV) + + # Export and reload + g_orig.export(outfile, file_format='mrc') + g_read = Grid(outfile) + + # Verify + assert_allclose(g_read.grid, g_orig.grid, rtol=1e-5) + assert_allclose(g_read.origin, g_orig.origin, rtol=1e-4) + assert_allclose(g_read.delta, g_orig.delta, rtol=1e-5) + assert_equal(g_read.grid.shape, g_orig.grid.shape) + + def test_grid_export_mrc_preserves_header(self, tmpdir): + """Test that Grid preserves MRC header through export""" + outfile = str(tmpdir / "header_grid.mrc") + + g_orig = Grid(datafiles.CCP4_1JZV) + orig_mapc = g_orig._mrc_header.mapc + orig_mapr = g_orig._mrc_header.mapr + orig_maps = g_orig._mrc_header.maps + + # Export and check + g_orig.export(outfile, file_format='mrc') + g_read = Grid(outfile) + + assert g_read._mrc_header.mapc == orig_mapc + assert g_read._mrc_header.mapr == orig_mapr + assert g_read._mrc_header.maps == orig_maps + + def test_mrc_write_4x4x4_with_header(self, tmpdir): + """Test writing 4x4x4 MRC file with custom header values.""" + + # Create 4x4x4 data + data = np.arange(64, dtype=np.float32).reshape((4, 4, 4)) + outfile = str(tmpdir / "test_with_header.mrc") + + # Create and write MRC + m = mrc.MRC() + m.array = data + m.delta = np.diag([1.5, 2.0, 2.5]) + m.origin = np.array([10.0, 20.0, 30.0]) + m.rank = 3 + m.write(outfile) + + # Read back and verify + m_read = mrc.MRC(outfile) + assert_allclose(m_read.array, data) + assert_allclose(np.diag(m_read.delta), [1.5, 2.0, 2.5]) + assert_allclose(m_read.origin, [10.0, 20.0, 30.0], rtol=1e-4, atol=1.0) + + + def test_mrc_write_4x4x4_without_header(self, tmpdir): + """Test writing 4x4x4 MRC file with default header.""" + + # Create 4x4x4 random data + np.random.seed(42) + data = np.random.rand(4, 4, 4).astype(np.float32) + outfile = str(tmpdir / "test_without_header.mrc") + + # Create and write MRC + m = mrc.MRC() + m.array = data + m.delta = np.diag([1.0, 1.0, 1.0]) + m.origin = np.array([0.0, 0.0, 0.0]) + m.rank = 3 + m.write(outfile) + + # Read back and verify + m_read = mrc.MRC(outfile) + assert_allclose(m_read.array, data) + assert_allclose(np.diag(m_read.delta), [1.0, 1.0, 1.0]) + assert_allclose(m_read.origin, [0.0, 0.0, 0.0], rtol=1e-4, atol=1.0)