Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions python/MRzeroCore/phantom/voxel_grid_phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def dephasing_func(t, _): return identity(t)
recover_func=lambda data: recover(mask, data),
phantom_motion=self.phantom_motion,
voxel_motion=self.voxel_motion,
tissue_masks=self.tissue_masks
tissue_masks=mask
)

@classmethod
Expand Down Expand Up @@ -312,11 +312,6 @@ def select(tensor: torch.Tensor):
return tensor[..., slices].view(
*list(self.PD.shape[:2]), len(slices)
)
def select_multicoil(tensor: torch.Tensor):
coils = tensor.shape[0]
return tensor[..., slices].view(
coils, *list(self.PD.shape[:2]), len(slices)
)

return VoxelGridPhantom(
select(self.PD),
Expand All @@ -325,8 +320,8 @@ def select_multicoil(tensor: torch.Tensor):
select(self.T2dash),
select(self.D),
select(self.B0),
select_multicoil(self.B1),
select_multicoil(self.coil_sens),
select(self.B1).unsqueeze(0),
select(self.coil_sens).unsqueeze(0),
self.size.clone(),
tissue_masks={
key: mask[..., slices] for key, mask in self.tissue_masks.items()
Expand Down Expand Up @@ -438,7 +433,7 @@ def resample_masks(tensors: Dict) -> Optional[Dict]:
tissue_masks=resample_masks(self.tissue_masks)
)

def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
def plot(self, plot_masks=False, plot_slice="center") -> None:
"""
Print and plot all data stored in this phantom.

Expand All @@ -449,8 +444,6 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
slice : str | int
If int, the specified slice is plotted. "center" plots the center
slice and "all" plots all slices as a grid.
time_unit : str
Time unit to use for T1, T2, and T2' maps (default: 's'). Supported 's' and 'ms'.
"""
print("VoxelGridPhantom")
print(f"size = {self.size}")
Expand All @@ -459,7 +452,7 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
s = self.PD.shape[2] // 2
elif plot_slice == "all":
s = slice(None)
elif isinstance(plot_slice, int):
elif plot_slice is int:
s = plot_slice
else:
raise ValueError("expected plot_slice to be 'all', 'center' or an integer")
Expand All @@ -471,10 +464,6 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
if self.PD.shape[2] > 1:
print(f"Plotting slice {s} / {self.PD.shape[2]}")


# Get time unit scaling factor
time_factor = 1000 if time_unit == 'ms' else 1

# Determine the number of subplots needed
num_plots = 9 # Base number of plots without masks
if plot_masks:
Expand All @@ -494,18 +483,18 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None:
plt.colorbar()

plt.subplot(rows, cols, 2)
plt.title("T1 (%s)" % time_unit)
imshow(self.T1[:, :, s]*time_factor, vmin=0)
plt.title("T1")
imshow(self.T1[:, :, s], vmin=0)
plt.colorbar()

plt.subplot(rows, cols, 3)
plt.title("T2 (%s)" % time_unit)
imshow(self.T2[:, :, s]*time_factor, vmin=0)
plt.title("T2")
imshow(self.T2[:, :, s], vmin=0)
plt.colorbar()

plt.subplot(rows, cols, 4)
plt.title("T2' (%s)" % time_unit)
imshow(self.T2dash[:, :, s]*time_factor, vmin=0)
plt.title("T2'")
imshow(self.T2dash[:, :, s], vmin=0)
plt.colorbar()

plt.subplot(rows, cols, 5)
Expand Down Expand Up @@ -564,7 +553,7 @@ def plot3D(self, data2print: int = 0) -> None:

def recover(mask, sim_data: SimData) -> VoxelGridPhantom:
"""Provided to :class:`SimData` to reverse the ``build()``"""

mask = mask.to(sim_data.device)

def to_full(sparse):
Expand Down
Loading