Skip to content
40 changes: 27 additions & 13 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,19 +2742,33 @@ def round(x, decimals=0):

def tile(x, repeats):
x = convert_to_tensor(x)
repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
repeats_size = tf.size(repeats)
repeats = tf.pad(
repeats,
[[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
constant_values=1,
)
x_shape = tf.pad(
tf.shape(x),
[[tf.maximum(repeats_size - x.shape.rank, 0), 0]],
constant_values=1,
)
x = tf.reshape(x, x_shape)

# Convert repeats to a list (works for both sequences and 1D tensors)
if isinstance(repeats, int):
repeats = [repeats]
else:
repeats = [v for v in repeats]

# Process list elements: convert concrete scalar tensors to Python ints
processed_repeats = []
for r in repeats:
if hasattr(r, "numpy") and r.shape == ():
processed_repeats.append(int(r.numpy()))
else:
processed_repeats.append(r)
repeats = processed_repeats

# Get x rank
x_rank = x.shape.rank

# Pad repeats if needed
if len(repeats) < x_rank:
repeats = [1] * (x_rank - len(repeats)) + repeats

# Add dimensions to x if needed using tf.expand_dims
while len(repeats) > x.shape.rank:
x = tf.expand_dims(x, 0)

return tf.tile(x, repeats)


Expand Down
9 changes: 6 additions & 3 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6411,17 +6411,20 @@ def compute_output_spec(self, x):
repeats = self.repeats
if isinstance(repeats, int):
repeats = [repeats]
else:
repeats = list(repeats)

if len(x_shape) > len(repeats):
repeats = [1] * (len(x_shape) - len(repeats)) + repeats
else:
x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape

output_shape = []
for x_size, repeat in zip(x_shape, repeats):
if x_size is None:
output_shape.append(None)
else:
if isinstance(x_size, int):
output_shape.append(x_size * repeat)
else:
output_shape.append(None)
return KerasTensor(output_shape, dtype=x.dtype)


Expand Down
21 changes: 21 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,6 +1821,10 @@ def test_tile(self):
self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6))
self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6))

# Test with multi-dimensional input
x = KerasTensor((None, 3, 2, 2))
self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2))

def test_trace(self):
x = KerasTensor((None, 3, None, 5))
self.assertEqual(knp.trace(x).shape, (None, 5))
Expand Down Expand Up @@ -9507,3 +9511,20 @@ def call(self, x):
model.compile(jit_compile=jit_compile)

model.predict(np.random.randn(1, 8))


class TileTest(testing.TestCase):
@pytest.mark.skipif(
keras.config.backend() == "openvino",
reason="`tile` is not supported with openvino backend",
)
def test_tile_shape_inference_in_layer(self):
class TileLayer(keras.layers.Layer):
def call(self, x):
repeats = [1, 2, 1, 1]
return knp.tile(x, repeats)

inputs = keras.Input(shape=(3, 2, 2))
output = TileLayer()(inputs)

self.assertEqual(output.shape, (None, 6, 2, 2))