Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions zonopy/contset/zonotope/batch_zono.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,51 +231,65 @@ def __mul__(self,other):
def __len__(self):
return self.Z.shape[0]

def slice(self,slice_dim,slice_pt):
def slice(self, slice_dim, slice_pt, return_grads = False):
'''
slice zonotope on specified point in a certain dimension
self: <zonotope>
slice_dim: <torch.Tensor> or <list> or <int>
, shape []
slice_pt: <torch.Tensor> or <list> or <float> or <int>
, shape []
return <zonotope>
Slice zonotope on specified point in a given dimension and return gradient.
Returns:
newzono: batchZonotope after slicing
dNewCenter_dSlicePt: gradient of the new center with respect to slice_pt
'''
if isinstance(slice_dim, list):
slice_dim = torch.tensor(slice_dim,dtype=torch.long,device=self.device)
elif isinstance(slice_dim, int) or (isinstance(slice_dim, torch.Tensor) and len(slice_dim.shape)==0):
slice_dim = torch.tensor([slice_dim],dtype=torch.long,device=self.device)
slice_dim = torch.tensor(slice_dim, dtype=torch.long, device=self.device)
elif isinstance(slice_dim, int) or (isinstance(slice_dim, torch.Tensor) and len(slice_dim.shape) == 0):
slice_dim = torch.tensor([slice_dim], dtype=torch.long, device=self.device)

if isinstance(slice_pt, list):
slice_pt = torch.tensor(slice_pt,dtype=self.dtype,device=self.device)
elif isinstance(slice_pt, int) or isinstance(slice_pt, float) or (isinstance(slice_pt, torch.Tensor) and len(slice_pt.shape)==0):
slice_pt = torch.tensor([slice_pt],dtype=self.dtype,device=self.device)
slice_pt = torch.tensor(slice_pt, dtype=self.dtype, device=self.device)
elif isinstance(slice_pt, int) or isinstance(slice_pt, float) or (isinstance(slice_pt, torch.Tensor) and len(slice_pt.shape) == 0):
slice_pt = torch.tensor([slice_pt], dtype=self.dtype, device=self.device)

assert isinstance(slice_dim, torch.Tensor) and isinstance(slice_pt, torch.Tensor), 'Invalid type of input'
assert len(slice_dim.shape) ==1, 'slicing dimension should be 1-dim component.'
#assert slice_pt.shape[:-1] ==self.batch_shape, 'slicing point should be (batch+1)-dim component.'
assert len(slice_dim) == slice_pt.shape[-1], f'The number of slicing dimension ({len(slice_dim)}) and the number of slicing point ({slice_pt.shape[-1]}) should be the same.'
assert len(slice_dim.shape) == 1, 'slicing dimension should be 1-dim component.'
assert len(slice_dim) == slice_pt.shape[-1], (
f'The number of slicing dimensions ({len(slice_dim)}) and the number of slicing points '
f'({slice_pt.shape[-1]}) should be the same.'
)

N = len(slice_dim)
slice_dim, ind = torch.sort(slice_dim)
slice_pt = slice_pt[(slice(None),)*(len(slice_pt.shape)-1)+(ind,)]
slice_pt = slice_pt[(slice(None),) * (len(slice_pt.shape) - 1) + (ind,)]

c = self.center
G = self.generators
G_dim = G[self.batch_idx_all+(slice(None),slice_dim)]
G_dim = G[self.batch_idx_all + (slice(None), slice_dim)]
non_zero_idx = G_dim != 0
assert torch.all(torch.sum(non_zero_idx,-2)==1), 'There should be one generator for each slice index.'
slice_idx = non_zero_idx.transpose(-2,-1).nonzero()

assert torch.all(torch.sum(non_zero_idx, -2) == 1), 'There should be one generator for each slice index.'
slice_idx = non_zero_idx.transpose(-2, -1).nonzero()

slice_c = c[self.batch_idx_all + (slice_dim,)]
ind = tuple(slice_idx[:, :-2].T)
slice_g = G_dim[ind + (slice_idx[:, -1], slice_idx[:, -2])].reshape(self.batch_shape + (N,))
slice_lambda = (slice_pt - slice_c) / slice_g
assert not (abs(slice_lambda) > 1).any(), 'Slice point is outside bounds of reach set'

# Compute new center and its gradient
G_slice = G[ind + (slice_idx[:, -1],)].reshape(self.batch_shape + (N, self.dimension))
newc = c.unsqueeze(-2) + slice_lambda.unsqueeze(-2) @ G_slice
newc = newc.squeeze(-2)

# Form new zonotope by removing the sliced generators
remaining = ~non_zero_idx.any(-1)
newG = G[remaining].reshape(self.batch_shape + (-1, self.dimension))
Z = torch.cat((newc.unsqueeze(-2), newG), -2)
output_zono = batchZonotope(Z)

if return_grads:
dNewCenter_dSlicePt = G_slice.transpose(-2, -1) / slice_g.unsqueeze(-2)
dNewGenerators_dSlicePt = torch.zeros(newG.shape + (slice_pt.shape[-1],), dtype=self.dtype, device=self.device)
return output_zono, dNewCenter_dSlicePt, dNewGenerators_dSlicePt

return output_zono

#slice_idx = torch.any(non_zero_idx,-1)
slice_c = c[self.batch_idx_all+(slice_dim,)]
ind = tuple(slice_idx[:,:-2].T)
slice_g = G_dim[ind+(slice_idx[:,-1],slice_idx[:,-2])].reshape(self.batch_shape+(N,))
slice_lambda = (slice_pt-slice_c)/slice_g
assert not (abs(slice_lambda)>1).any(), 'slice point is ouside bounds of reach set, and therefore is not verified'
Z = torch.cat((c.unsqueeze(-2) + slice_lambda.unsqueeze(-2)@G[ind+(slice_idx[:,-1],)].reshape(self.batch_shape+(N,self.dimension)),G[~non_zero_idx.any(-1)].reshape(self.batch_shape+(-1,self.dimension))),-2)
return batchZonotope(Z)
def project(self,dim=[0,1]):
'''
The projection of a batch zonotope onto the specified dimensions
Expand Down
2 changes: 1 addition & 1 deletion zonopy/util/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def remove_dependence_and_compress(
ful_slc_idx = torch.logical_and(has_val, dn_has_val)

if isinstance(Set,(zp.polyZonotope,zp.batchPolyZonotope)):
if zpi.__debug_extra__: assert torch.count_nonzero(ful_slc_idx) <= np.count_nonzero(id_idx)
# if zpi.__debug_extra__: assert torch.count_nonzero(ful_slc_idx) <= np.count_nonzero(id_idx)
c = Set.c
G = Set.G[...,ful_slc_idx,:]
ExpMat = Set.expMat[ful_slc_idx][:,id_idx]
Expand Down