|
def batch_invert_permutation(permutations):
  """Returns batched `tf.invert_permutation` for every row in `permutations`.

  Rather than unstacking the batch and inverting each row separately, every
  row is offset by `row_index * dim` so the whole batch flattens into one
  large permutation of `[0, batch_size * dim)`, which is inverted with a
  single `tf.invert_permutation` call.

  Args:
    permutations: Integer tensor of shape `[batch_size, dim]`; each row is
      assumed to be a permutation of `[0, dim)`.

  Returns:
    An `int32` tensor of shape `[batch_size, dim]` holding the row-wise
    inverse permutations.
  """
  with tf.name_scope('batch_invert_permutation', values=[permutations]):
    perm = tf.cast(permutations, tf.int32)
    # Static last dimension is required to build the tiled offsets.
    dim = int(perm.get_shape()[-1])
    size = tf.shape(perm)[0]
    # Per-row offsets [0, dim, 2*dim, ...], kept entirely in integer
    # arithmetic: the former float32 round-trip loses exact integer
    # representation once flattened indices exceed 2**24.
    rg = tf.range(0, size * dim, dim, dtype=tf.int32)
    rg = tf.expand_dims(rg, 1)
    rg = tf.tile(rg, [1, dim])
    flat = tf.reshape(perm + rg, [-1])
    inverted = tf.invert_permutation(flat)
    # Undo the offsets to recover per-row inverse permutations.
    return tf.reshape(inverted, [-1, dim]) - rg
31 | 40 |
|
32 | 41 |
|
def batch_gather(values, indices):
  """Returns batched `tf.gather` for every row in the input."""
  with tf.name_scope('batch_gather', values=[values, indices]):
    # Pair every index with the row it belongs to, then gather once.
    num_cols = int(indices.get_shape()[-1])
    batch_size = tf.shape(indices)[0]
    row_ids = tf.reshape(tf.range(batch_size, dtype=tf.int32), [-1, 1, 1])
    row_ids = tf.tile(row_ids, [1, num_cols, 1])
    coords = tf.concat([row_ids, tf.expand_dims(indices, -1)], -1)
    return tf.gather_nd(values, coords)
39 | 53 |
|
40 | 54 |
|
def one_hot(length, index, dtype=None):
  """Return an nd array of given `length` filled with 0s and a 1 at `index`.

  Args:
    length: Size of the returned vector.
    index: Position set to one; any valid numpy index (including negative)
      is accepted.
    dtype: Optional dtype of the result. Defaults to `None`, which keeps the
      original behavior (numpy's default float dtype).

  Returns:
    A numpy array of shape `[length]` with a single one at `index`.
  """
  result = np.zeros(length, dtype=dtype)
  result[index] = 1
  return result
| 60 | + |
def reduce_prod(x, axis, name=None):
  """Efficient reduce product over axis.

  Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance
  of calculating tf.reduce_prod's gradient on CPU.

  NOTE(review): the gather below reads element `[i, 0]` of the reversed
  cumulative product, so this effectively assumes `x` is 2-D and
  `axis == 1` — confirm against callers before reusing elsewhere.

  Args:
    x: Tensor to reduce.
    axis: Axis along which to take the product.
    name: Optional name for the op scope.

  Returns:
    A rank-1 tensor holding the per-row product of `x` along `axis`.
  """
  with tf.name_scope(name, 'util_reduce_prod', values=[x]):
    # Reverse cumulative product: entry [i, 0] is the full product of row i.
    cp = tf.cumprod(x, axis, reverse=True)
    size = tf.shape(cp)[0]
    # Build the [i, 0] gather indices directly in int32. The previous code
    # constructed them in float32 and cast afterwards, which is lossy for
    # sizes above 2**24 and adds needless casts.
    idx1 = tf.range(size, dtype=tf.int32)
    idx2 = tf.zeros([size], tf.int32)
    return tf.gather_nd(cp, tf.stack([idx1, idx2], 1))
0 commit comments