@@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
260260
261261
262262def sliced_wasserstein_sphere (X_s , X_t , a = None , b = None , n_projections = 50 ,
263- p = 2 , seed = None , log = False ):
263+ p = 2 , projections = None , seed = None , log = False ):
264264 r"""
265265 Compute the spherical sliced-Wasserstein discrepancy.
266266
@@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
287287 Number of projections used for the Monte-Carlo approximation
288288 p: float, optional (default=2)
289289 Power p used for computing the spherical sliced Wasserstein
290+ projections: shape (n_projections, dim, 2), optional
291+ Projection matrix (n_projections and seed are not used in this case)
290292 seed: int or RandomState or None, optional
291293 Seed used for random number generator
292294 log: bool, optional
@@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
326328 if nx .any (nx .abs (nx .sum (X_s ** 2 , axis = - 1 ) - 1 ) > 10 ** (- 4 )):
327329 raise ValueError ("X_s is not on the sphere." )
328330 if nx .any (nx .abs (nx .sum (X_t ** 2 , axis = - 1 ) - 1 ) > 10 ** (- 4 )):
329- raise ValueError ("Xt is not on the sphere." )
331+ raise ValueError ("X_t is not on the sphere." )
330332
331- # Uniforms and independent samples on the Stiefel manifold V_{d,2}
332- if isinstance (seed , np .random .RandomState ) and str (nx ) == 'numpy' :
333- Z = seed .randn (n_projections , d , 2 )
333+ if projections is None :
334+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
335+ if isinstance (seed , np .random .RandomState ) and str (nx ) == 'numpy' :
336+ Z = seed .randn (n_projections , d , 2 )
337+ else :
338+ if seed is not None :
339+ nx .seed (seed )
340+ Z = nx .randn (n_projections , d , 2 , type_as = X_s )
341+
342+ projections , _ = nx .qr (Z )
334343 else :
335- if seed is not None :
336- nx .seed (seed )
337- Z = nx .randn (n_projections , d , 2 , type_as = X_s )
338-
339- projections , _ = nx .qr (Z )
344+ n_projections = projections .shape [0 ]
340345
341346 # Projection on S^1
342347 # Projection on plane
343- Xps = nx .transpose ( nx . reshape ( nx . dot ( nx . transpose ( projections , ( 0 , 2 , 1 ))[:, None ] , X_s [:, :, None ]), ( n_projections , 2 , n )), ( 0 , 2 , 1 ) )
344- Xpt = nx .transpose ( nx . reshape ( nx . dot ( nx . transpose ( projections , ( 0 , 2 , 1 ))[:, None ] , X_t [:, :, None ]), ( n_projections , 2 , m )), ( 0 , 2 , 1 ) )
348+ Xps = nx .einsum ( "ikj, lk -> ilj" , projections , X_s )
349+ Xpt = nx .einsum ( "ikj, lk -> ilj" , projections , X_t )
345350
346351 # Projection on sphere
347352 Xps = Xps / nx .sqrt (nx .sum (Xps ** 2 , - 1 , keepdims = True ))
@@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
425430
426431 # Projection on S^1
427432 # Projection on plane
428- Xps = nx .transpose (nx .reshape (nx .dot (nx .transpose (projections , (0 , 2 , 1 ))[:, None ], X_s [:, :, None ]), (n_projections , 2 , n )), (0 , 2 , 1 ))
433+ Xps = nx .einsum ("ikj, lk -> ilj" , projections , X_s )
434+
429435 # Projection on sphere
430436 Xps = Xps / nx .sqrt (nx .sum (Xps ** 2 , - 1 , keepdims = True ))
437+
431438 # Get coordinates on [0,1[
432439 Xps_coords = nx .reshape (get_coordinate_circle (nx .reshape (Xps , (- 1 , 2 ))), (n_projections , n ))
433440
0 commit comments