Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/config/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Inference(Mode):
start_index: int = 0
summary_iteration: int = 1
logging_iteration: int = 1
torch_jit: bool = False

@dataclass
class IOTest(Mode):
Expand Down
115 changes: 80 additions & 35 deletions src/networks/torch/uresnet2D.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List, Dict, Tuple, Optional, Union
import torch
import torch.nn as nn

Expand Down Expand Up @@ -61,30 +62,40 @@ def __init__(self, *,

if params.normalization == Norm.batch:
self._do_normalization = True
self.norm_type = "batch"
self.norm = nn.BatchNorm2d(outplanes)
elif params.normalization == Norm.group:
self._do_normalization = True
self.norm_type = "group"
self.norm = nn.GroupNorm(num_groups=4, num_channels=outplanes)
elif params.normalization == Norm.layer:
self._do_normalization = True
self.norm = "layer"
self.norm_type = "layer"
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
# h_out = (h_in+2*padding[0]-dilation[0]*(kernel[0]-1)-1)/stride[0]+1
# w_out = (w_in+2*padding[1]-dilation[1]*(kernel[1]-1)-1)/stride[1]+1
# h_out = (h_in+2*padding[0]-1*(kernel[0]-1)-1)/stride[0]+1
# w_out = (w_in+2*padding[1]-1*(kernel[1]-1)-1)/stride[1]+1
self.norm = nn.LayerNorm(normalized_shape=[outplanes,44,64])
elif params.normalization == Norm.instance:
self._do_normalization = True
self.norm_type = "instance"
self.norm = nn.InstanceNorm2d(outplanes)
else:
self._do_normalization = False

self.norm_type = "none"
self.norm = nn.Module()

self.activation = activation

def forward(self, x):
out = self.conv(x)
if self._do_normalization:
if self.norm == "layer":
norm_shape = out.shape[1:]
if self.norm_type == "layer":
#norm_shape = out.shape[1:]
# norm_shape = torch.tensor([8,] + list(norm_shape)).to(x.device)
self.norm = torch.nn.LayerNorm(normalized_shape=norm_shape)
self.norm.to(out.device)
#self.norm = torch.nn.LayerNorm(normalized_shape=norm_shape)
#self.norm.to(out.device)
# out = torch.nn.functional.layer_norm(out, norm_shape)
out = self.norm(out)
else:
Expand Down Expand Up @@ -273,8 +284,10 @@ def __init__(self, *, inplanes, n_blocks, params):


def forward(self, x):
    """Apply each block in ``self.blocks`` to ``x`` in sequence.

    The running result of one block is fed as the input of the next,
    and the final result is returned.

    Args:
        x: input tensor (shape requirements are those of the
            contained blocks — not visible from this method).

    Returns:
        The output of the last block.
    """
    # Iterate the module list directly; the previous enumerate()
    # produced an index that was never used.
    for block in self.blocks:
        x = block(x)

    return x

Expand Down Expand Up @@ -348,7 +361,7 @@ def __init__(self, *, inplanes, params):



def forward(self, x):
def forward(self, x: List[torch.Tensor]):

# This isn't really a recommended setting to use, but we can control whether or not to connect here:
# if FLAGS.BLOCK_CONCAT:
Expand All @@ -361,15 +374,14 @@ def forward(self, x):
classification_head = torch.cat(x, dim=1)
else:

x = torch.cat(x, dim=1)
x = self.bottleneck(x)
x = self.blocks(x)
x = self.unbottleneck(x)
classification_head = x
x = torch.chunk(x, chunks=3, dim=1)

x1 = torch.cat(x, dim=1)
x1 = self.bottleneck(x1)
x1 = self.blocks(x1)
x1 = self.unbottleneck(x1)
classification_head = x1
x = torch.chunk(x1, chunks=3, dim=1)

return x, classification_head, None # The none is a placeholder for vertex ID YOLO
return x, classification_head, x # The none is a placeholder for vertex ID YOLO

class NoConnection(nn.Module):

Expand Down Expand Up @@ -454,6 +466,25 @@ def forward(self, x):
return self.bottleneck(x)


class DoNothing(torch.nn.Module):
    """Identity placeholder module: ``forward(x)`` returns ``x`` unchanged.

    Used where an ``nn.Module`` attribute must always exist on a parent
    module (so every code path has the same set of submodules) but no
    operation should be applied.
    """

    def __init__(self):
        # Idiomatic initialization of the nn.Module base class.
        super().__init__()

    def forward(self, x):
        # Pass the input through untouched.
        return x

class DoNothing2(torch.nn.Module):
    """Two-argument identity placeholder: ``forward(x, y)`` returns ``x``.

    Matches the call signature of connection modules (which take an
    input and a residual) while applying no operation; the second
    argument is deliberately ignored.
    """

    def __init__(self):
        # Idiomatic initialization of the nn.Module base class.
        super().__init__()

    def forward(self, x, y):
        # Return the first input unchanged; ``y`` is intentionally unused.
        return x


class UNetCore(nn.Module):

def __init__(self, *, depth, inplanes, params):
Expand All @@ -469,6 +500,11 @@ def __init__(self, *, depth, inplanes, params):
if depth == 0:
self.main_module = DeepestBlock(inplanes = inplanes,
params = params)
self.down_blocks = DoNothing()
self.downsample = DoNothing()
self.upsample = DoNothing()
self.connection = DoNothing2()
self.up_blocks = DoNothing()
else:
# Residual or convolutional blocks, applied in series:
self.down_blocks = BlockSeries(inplanes = inplanes,
Expand Down Expand Up @@ -530,46 +566,50 @@ def __init__(self, *, depth, inplanes, params):
self.connection = NoConnection()


def forward(self, x):
def forward(self, x: List[torch.Tensor]):


# Take the input and apply the downward pass convolutions. Save the residual
# at the correct time.
residual = x
if self.depth != 0:

residual = x
#residual = x

x = tuple( self.down_blocks(_x) for _x in x )
#x = tuple( self.down_blocks(_x) for _x in x )
x = [self.down_blocks(_x) for _x in x ]

# perform the downsampling operation:
x = tuple( self.downsample(_x) for _x in x )
#x = tuple( self.downsample(_x) for _x in x )
x = [self.downsample(_x) for _x in x]
#
# if FLAGS.VERBOSITY >1:
# for p in range(len(x)):
# print("plane {} Depth {}, shape: ".format(p, self.depth), x[p].shape)


# Apply the main module:
x, classification_head, vertex_head = self.main_module(x)

# The vertex_head is None after the DEEPEST layer. But, if we're returning it, do it here:


if self.depth != 0:

# perform the upsampling step:
# perform the upsampling operation:
x = tuple( self.upsample(_x) for _x in x )
#x = tuple( self.upsample(_x) for _x in x )
x = [self.upsample(_x) for _x in x]

# Connect with the residual if necessary:
# for i in range(len(x)):
# x[i] = self.connection(x[i], residual=residual[i])

x = tuple( self.connection(_x, _r) for _x, _r in zip(x, residual))
#x = tuple( self.connection(_x, _r) for _x, _r in zip(x, residual))
x = [self.connection(_x, _r) for _x, _r in zip(x, residual)]


# Apply the convolutional steps:
x = tuple( self.up_blocks(_x) for _x in x )
#x = tuple( self.up_blocks(_x) for _x in x )
x = [self.up_blocks(_x) for _x in x]

if self.depth == self.vertex_depth: vertex_head = x

Expand Down Expand Up @@ -619,7 +659,7 @@ def __init__(self, params, spatial_size):

# The image size here is going to be the orignal / 2**depth
# We need to know it for the pooling layer
self.pool_size = [d // 2**params.depth for d in spatial_size]
self.pool_size = [int(d // 2**params.depth) for d in spatial_size]

n_filters = params.n_initial_filters
for i in range(params.depth):
Expand Down Expand Up @@ -652,7 +692,7 @@ def __init__(self, params, spatial_size):
stride = 1,
padding = 0,
bias = params.bias)

self.pool = torch.nn.AvgPool2d(self.pool_size)

if params.vertex.active:
Expand Down Expand Up @@ -709,10 +749,12 @@ def forward(self, input_tensor):

batch_size = input_tensor.shape[0]

return_dict = {
"event_label" : None,
"vertex" : None,
}
return_dict: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = {}
#return_dict = {
# "event_label" : None,
# "vertex" : None,
# "segmentation" : None,
#}


# Reshape this tensor into the right shape to apply this multiplane network.
Expand All @@ -722,16 +764,19 @@ def forward(self, input_tensor):
x = torch.chunk(x, chunks=3, dim=1)

# Apply the initial convolutions:
x = tuple( self.initial_convolution(_x) for _x in x )
#x = tuple( self.initial_convolution(_x) for _x in x )
x = [self.initial_convolution(_x) for _x in x]



# Apply the main unet architecture:
seg_labels, classification_head, vertex_head = self.net_core(x)

# Apply the final residual block to each plane:
seg_labels = tuple( self.final_layer(_x) for _x in seg_labels )
seg_labels = tuple( self.bottleneck(_x) for _x in seg_labels )
#seg_labels = tuple( self.final_layer(_x) for _x in seg_labels )
seg_labels = [self.final_layer(_x) for _x in seg_labels]
#seg_labels = tuple( self.bottleneck(_x) for _x in seg_labels )
seg_labels = [self.bottleneck(_x) for _x in seg_labels]

# Always return the segmentation
return_dict["segmentation"] = seg_labels
Expand Down
25 changes: 19 additions & 6 deletions src/utils/torch/distributed_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,11 @@ def restore_model(self):
else:
state = None

# Restore the weights on rank 0:
if state is not None and self.rank == 0:
self.restore_state(state)


# Broadcast from rank 0 to sync weights
if self.args.framework.distributed_mode == DistributedMode.horovod:
# Restore the weights on rank 0:
if state is not None and self.rank == 0:
self.restore_state(state)

# Broadcast the global step:
self._global_step = hvd.broadcast_object(self._global_step, root_rank = 0)
Expand All @@ -209,6 +207,18 @@ def restore_model(self):
state_dict = hvd.broadcast_object(self.lr_scheduler.state_dict(), root_rank = 0)

elif self.args.framework.distributed_mode == DistributedMode.DDP:
# Broadcast and restore model state
state = MPI.COMM_WORLD.bcast(state, root=0)
self.restore_state(state)

# Compare this rank state to the one on rank 0 to make sure
# broadcast and restore were successful
#state_b = str(self._net.cpu().state_dict())
#state_0 = MPI.COMM_WORLD.bcast(state_b, root=0)
#if state_b != state_0:
# print(f"rank {self.rank} state differs from rank 0", flush=True)
#else:
# print(f"rank {self.rank} state is same as state from rank 0", flush=True)

devices = None
if self.args.run.compute_mode == ComputeMode.XPU:
Expand All @@ -219,7 +229,10 @@ def restore_model(self):

# print(self._net.parameters)

self._net = torch.nn.parallel.DistributedDataParallel(self._net, device_ids=devices, broadcast_buffers=self.args.run.broadcast_buffers, find_unused_parameters=False)
if self.is_training():
self._net = torch.nn.parallel.DistributedDataParallel(self._net, device_ids=devices,
broadcast_buffers=self.args.run.broadcast_buffers,
find_unused_parameters=False)

# print(self._net.parameters)

Expand Down
27 changes: 24 additions & 3 deletions src/utils/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def initialize(self, datasets):
if self.is_training() and self.args.mode.optimizer.gradient_accumulation > 1:
raise Exception("Can not accumulate gradients in half precision.")

# example_batch = next(iter(example_ds))
example_batch = next(iter(example_ds))

# self._net = torch.compile(self._net)

Expand All @@ -158,6 +158,11 @@ def initialize(self, datasets):
if self.args.mode.name == ModeKind.inference:
self.inference_metrics = {}
self.inference_metrics['n'] = 0

# JIT trace the model for inference
if self.args.mode.torch_jit:
self.jit_trace_model(example_batch["image"].shape)


# Turn the datasets into torch dataloaders
for key in datasets.keys():
Expand Down Expand Up @@ -330,6 +335,7 @@ def save_model(self):

torch.save(state_dict, current_file_path)


# Parse the checkpoint file to see what the last checkpoints were:

# Keep only the last 5 checkpoints
Expand Down Expand Up @@ -364,6 +370,21 @@ def save_model(self):
for key in past_checkpoint_files:
_chkpt.write('{}: {}\n'.format(key, past_checkpoint_files[key]))

def jit_trace_model(self, input_shape):
    '''JIT trace the model

    Compiles ``self._net`` with TorchScript for inference and rebinds
    the compiled, frozen, inference-optimized module to ``self._net``.
    The model is switched to eval mode as part of scripting.

    Parameters:
        input_shape: shape of an input batch, forwarded to
            ``torch.jit.script`` via ``example_inputs``.
            NOTE(review): ``example_inputs`` normally expects example
            input *values* (tensors), not a shape — confirm a shape is
            accepted here, or pass a dummy tensor of this shape.
    '''
    # Save jit-traced version of the model
    # Allow up to 2 recompilations each for static- and dynamic-shape
    # specialization before the JIT stops re-specializing.
    torch.jit.set_fusion_strategy([("STATIC",2),("DYNAMIC",2)])
    with torch.no_grad():
        if self.args.run.compute_mode == ComputeMode.XPU:
            # Intel-specific path: optimize with IPEX and enable
            # oneDNN fusion before scripting.
            # NOTE(review): the diff rendering loses indentation — confirm
            # enable_onednn_fusion is meant to be inside this XPU branch.
            self._net = ipex.optimize(self._net)
            torch.jit.enable_onednn_fusion(True)
        # Script in eval mode, then freeze (inline weights/attributes)
        # and apply inference-only graph optimizations.
        self._net = torch.jit.script(self._net.eval(),example_inputs=[input_shape])
        self._net = torch.jit.freeze(self._net)
        self._net = torch.jit.optimize_for_inference(self._net)



def get_model_filepath(self):
'''Helper function to build the filepath of a model for saving and restoring:
Expand Down Expand Up @@ -808,7 +829,8 @@ def ana_step(self, batch):
# perform a validation step

# Set network to eval mode
self._net.eval()
if self.args.mode.name == ModeKind.inference and not self.args.mode.torch_jit:
self._net.eval()
# self._net.train()


Expand All @@ -821,7 +843,6 @@ def ana_step(self, batch):
logits_dict, labels_dict = self.forward_pass(batch)



# If the input data has labels available, compute the metrics:
if 'label' in batch:
# Compute the loss
Expand Down