Tensors

diag_part_axis(tensor, axis[, offset])

Extracts the batched diagonal part of a batched tensor over the specified axis.
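The following is a minimal pure-Python sketch of the behaviour described above, not the library's TensorFlow implementation. It assumes the diagonal is taken over dimensions `axis` and `axis + 1`, so `axis=-2` reproduces a diagonal over the last two dimensions. It also assumes `offset` follows the `k` convention of `tf.linalg.diag_part`: positive values select a super-diagonal and negative values a sub-diagonal. The name `diag_part_axis_sketch` is hypothetical.

```python
def _rank(t):
    """Rank of a rectangular nested list (0 for a scalar)."""
    r = 0
    while isinstance(t, list):
        r += 1
        t = t[0]
    return r


def diag_part_axis_sketch(t, axis, offset=0):
    """Diagonal of nested list `t` over dimensions `axis` and `axis + 1`.

    The diagonal dimension replaces the two reduced ones at position `axis`.
    Assumption: offset > 0 selects a super-diagonal, offset < 0 a sub-diagonal.
    """
    if axis < 0:
        axis += _rank(t)  # Python-style negative indexing
    if axis > 0:
        # Recurse through the leading (batch) dimensions.
        return [diag_part_axis_sketch(sub, axis - 1, offset) for sub in t]
    rows, cols = len(t), len(t[0])
    if offset >= 0:
        return [t[i][i + offset] for i in range(min(rows, cols - offset))]
    return [t[i - offset][i] for i in range(min(rows + offset, cols))]


# Equivalent of np.arange(27).reshape(3, 3, 3) as nested lists.
a = [[[9 * i + 3 * j + k for k in range(3)] for j in range(3)] for i in range(3)]
print(diag_part_axis_sketch(a, axis=0))  # [[0, 1, 2], [12, 13, 14], [24, 25, 26]]
```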

enumerate_indices(bounds)

Enumerates all indices between 0 (included) and bounds (excluded) in lexicographic order.
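Lexicographic enumeration is exactly what a Cartesian product of ranges produces. The following is a minimal sketch that returns nested lists rather than a tensor; the helper name is hypothetical. The result has `prod(bounds)` rows of length `len(bounds)`.

```python
from itertools import product


def enumerate_indices_sketch(bounds):
    """All index lists with 0 <= idx[n] < bounds[n], in lexicographic order."""
    # itertools.product varies the last position fastest, which is
    # lexicographic order over the index tuples.
    return [list(idx) for idx in product(*(range(b) for b in bounds))]


print(enumerate_indices_sketch([2, 3]))
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
```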

expand_to_rank(tensor, target_rank[, axis])

Inserts as many axes into a tensor as needed to reach the desired rank.
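The shape bookkeeping can be sketched as follows. Several assumptions apply: the default `axis=-1`, negative `axis` counting from the end as in `tf.expand_dims` (so `-1` appends after the last dimension), and a tensor that already has at least `target_rank` dimensions being returned unchanged. The helper `expand_to_rank_shape` is hypothetical and only computes the resulting shape.

```python
def expand_to_rank_shape(shape, target_rank, axis=-1):
    """Shape after padding `shape` with length-one dims up to `target_rank`."""
    rank = len(shape)
    num_new = max(target_rank - rank, 0)  # never removes dimensions
    if axis < 0:
        axis += rank + 1  # expand_dims-style: -1 means "after the last dim"
    return tuple(shape[:axis]) + (1,) * num_new + tuple(shape[axis:])


print(expand_to_rank_shape((4, 5), 4))          # (4, 5, 1, 1)
print(expand_to_rank_shape((4, 5), 4, axis=0))  # (1, 1, 4, 5)
print(expand_to_rank_shape((2, 3, 4), 2))       # (2, 3, 4)
```

A typical use is broadcasting. For example, a per-batch scalar of shape `(batch,)` expanded to rank 3 becomes `(batch, 1, 1)`, which broadcasts against `(batch, a, b)`.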

find_true_position(bool_tensor[, side, axis])

Finds the index of the first or last True value along the specified axis.
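The following is a minimal sketch restricted to the last axis of a 2-D nested list of booleans. Two details are assumptions rather than documented behaviour: the sentinel `-1` for rows without any `True`, and the default `side="last"`. The helper name is hypothetical.

```python
def find_true_position_sketch(bool_rows, side="last"):
    """Index of the first/last True in each row (last axis), -1 if none."""
    if side not in ("first", "last"):
        raise ValueError("side must be 'first' or 'last'")
    result = []
    for row in bool_rows:
        positions = [i for i, v in enumerate(row) if v]
        if not positions:
            result.append(-1)  # assumed sentinel when no True is present
        else:
            result.append(positions[0] if side == "first" else positions[-1])
    return result


mask = [[False, True, True], [False, False, False], [True, False, True]]
print(find_true_position_sketch(mask, side="first"))  # [1, -1, 0]
print(find_true_position_sketch(mask, side="last"))   # [2, -1, 2]
```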

flatten_dims(tensor, num_dims, axis)

Flattens a specified set of dimensions of a tensor.
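In terms of shapes, flattening `num_dims` dimensions starting at `axis` replaces them with their product. The following is a minimal sketch, assuming a nonnegative `axis`; the helper `flatten_dims_shape` is hypothetical and only computes the new shape.

```python
from math import prod


def flatten_dims_shape(shape, num_dims, axis):
    """Shape after merging dims axis, ..., axis + num_dims - 1 into one dim.

    Applying this shape with a row-major reshape leaves the element order
    unchanged; only the indexing changes.
    """
    if num_dims < 1 or axis < 0 or axis + num_dims > len(shape):
        raise ValueError("the flattened dims must lie inside the shape")
    shape = tuple(shape)
    merged = prod(shape[axis:axis + num_dims])
    return shape[:axis] + (merged,) + shape[axis + num_dims:]


print(flatten_dims_shape((2, 3, 4, 5), num_dims=2, axis=1))  # (2, 12, 5)
```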

flatten_last_dims(tensor[, num_dims])

Flattens the last num_dims dimensions of a tensor.
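To show what flattening does to the data, not just the shape, the following minimal sketch operates on nested lists and concatenates the trailing dimensions in row-major order. The default `num_dims=2` and the helper names are assumptions for illustration.

```python
def _rank(t):
    """Rank of a rectangular nested list (0 for a scalar)."""
    r = 0
    while isinstance(t, list):
        r += 1
        t = t[0]
    return r


def _ravel(t, depth):
    """Row-major list of the entries `depth` levels below `t`."""
    if depth == 0:
        return [t]
    return [x for sub in t for x in _ravel(sub, depth - 1)]


def flatten_last_dims_sketch(t, num_dims=2):
    """Merge the last `num_dims` dimensions of nested list `t` into one."""
    lead = _rank(t) - num_dims  # number of dimensions kept as-is

    def go(sub, level):
        if level == lead:
            return _ravel(sub, num_dims)
        return [go(s, level + 1) for s in sub]

    return go(t, 0)


x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (2, 2, 2)
print(flatten_last_dims_sketch(x))     # [[1, 2, 3, 4], [5, 6, 7, 8]]
print(flatten_last_dims_sketch(x, 3))  # [1, 2, 3, 4, 5, 6, 7, 8]
```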

flatten_multi_index(indices, shape)

Converts a tensor of index arrays into a tensor of flat indices.
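The following is a minimal sketch, assuming row-major (C-order) flattening, which is the same convention as the default of NumPy's `ravel_multi_index`. Each innermost list of `indices` is one multi-index with one entry per dimension of `shape`. The helper name is hypothetical.

```python
def flatten_multi_index_sketch(indices, shape):
    """Flat row-major index for each multi-index in (nested) `indices`."""
    if indices and isinstance(indices[0], list):
        # Batch of multi-indices: recurse until a single index list is reached.
        return [flatten_multi_index_sketch(sub, shape) for sub in indices]
    # Horner scheme: flat = ((i0 * d1 + i1) * d2 + i2) ...
    flat = 0
    for idx, dim in zip(indices, shape):
        flat = flat * dim + idx
    return flat


print(flatten_multi_index_sketch([2, 3], [5, 6]))  # 15, i.e. 2 * 6 + 3
```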

gather_from_batched_indices(params, indices)

Gathers the values of a tensor params according to batch-specific indices.
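The following is a minimal sketch under an assumed reading of the interface: `indices` has shape `[..., N]`, where `N` is the rank of `params`. Each innermost list is then a full multi-index into `params`, and the output has the batch shape `[...]`, similar to `tf.gather_nd`. Indices are assumed to satisfy `0 <= i < params` dimension, since Python would silently wrap negative values. The helper name is hypothetical.

```python
def gather_from_batched_indices_sketch(params, indices):
    """Gather params[i_0][i_1]...[i_{N-1}] for every multi-index in `indices`."""
    if isinstance(indices[0], list):
        # Still inside the batch dimensions: recurse.
        return [gather_from_batched_indices_sketch(params, sub) for sub in indices]
    value = params
    for i in indices:  # walk down one dimension per index entry
        value = value[i]
    return value


params = [[10, 20, 30], [40, 50, 60], [70, 80, 90]]
indices = [[[0, 1], [1, 2], [2, 0], [0, 0]],
           [[0, 0], [2, 2], [2, 1], [0, 1]]]  # batch shape (2, 4), N = 2
print(gather_from_batched_indices_sketch(params, indices))
# [[20, 60, 70, 10], [10, 90, 80, 20]]
```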

insert_dims(tensor, num_dims[, axis])

Adds multiple length-one dimensions to a tensor.
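This generalizes inserting a single length-one axis to `num_dims` of them. The following minimal shape-level sketch assumes the `tf.expand_dims` axis convention: negative `axis` counts from the end, and the valid range is `[-(rank + 1), rank]`. The default `axis=-1`, which appends the new dimensions at the end, is also an assumption. The helper name is hypothetical.

```python
def insert_dims_shape(shape, num_dims, axis=-1):
    """Shape after inserting `num_dims` length-one dims starting at `axis`."""
    if num_dims < 0:
        raise ValueError("num_dims must be nonnegative")
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError("axis must lie in [-(rank + 1), rank]")
    if axis < 0:
        axis += rank + 1
    return tuple(shape[:axis]) + (1,) * num_dims + tuple(shape[axis:])


print(insert_dims_shape((3, 4), 2, axis=1))   # (3, 1, 1, 4)
print(insert_dims_shape((3, 4), 2))           # (3, 4, 1, 1)
print(insert_dims_shape((3, 4), 1, axis=-3))  # (1, 3, 4)
```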

random_tensor_from_values(values, shape[, dtype])

Generates a tensor of the specified shape, with elements randomly sampled from the provided set of values.
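The following minimal sketch builds a nested list and draws every entry independently and uniformly from `values`. Uniform sampling with replacement is an assumption. The `seed` parameter is a hypothetical addition for reproducibility and stands in place of the documented `dtype`. The helper name is also hypothetical.

```python
import random


def random_tensor_from_values_sketch(values, shape, seed=None):
    """Nested list of the given shape, each entry drawn uniformly from `values`."""
    rng = random.Random(seed)  # hypothetical seed for reproducible draws
    values = list(values)

    def build(dims):
        if not dims:
            return rng.choice(values)  # one independent draw per entry
        return [build(dims[1:]) for _ in range(dims[0])]

    return build(tuple(shape))


x = random_tensor_from_values_sketch([-1, 1], [2, 3], seed=0)
print(x)  # a 2 x 3 nested list whose entries are each -1 or 1
```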

split_dim(tensor, shape, axis)

Reshapes a dimension of a tensor into multiple dimensions.
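Splitting a dimension is the inverse of flattening it: the product of the new dimensions must equal the size of the dimension being split. The following minimal shape-level sketch accepts only a nonnegative `axis`, and the helper name is hypothetical.

```python
from math import prod


def split_dim_shape(shape, new_dims, axis):
    """Shape after replacing dimension `axis` with the dims in `new_dims`."""
    shape = tuple(shape)
    if not 0 <= axis < len(shape):
        raise ValueError("axis must be a valid nonnegative dimension index")
    if prod(new_dims) != shape[axis]:
        raise ValueError("prod(new_dims) must equal shape[axis]")
    return shape[:axis] + tuple(new_dims) + shape[axis + 1:]


print(split_dim_shape((8, 12, 5), (3, 4), axis=1))  # (8, 3, 4, 5)
```

Under a row-major reshape, splitting dimension 1 of `(8, 12, 5)` into `(3, 4)` exactly undoes flattening dimensions 1 and 2 of `(8, 3, 4, 5)`.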

tensor_values_are_in_set(tensor, admissible_set)

Checks whether all values of the input tensor are contained in admissible_set.
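The following minimal sketch walks a nested list of any depth and returns a single boolean. It is useful, for example, to check that a tensor contains only bits. The helper name is hypothetical, and the check is pure Python rather than a tensor operation.

```python
def tensor_values_are_in_set_sketch(tensor, admissible_set):
    """True if every scalar entry of nested list `tensor` is in admissible_set."""
    admissible = set(admissible_set)

    def entries(t):
        # Yield scalar entries depth-first, whatever the nesting depth.
        if isinstance(t, list):
            for sub in t:
                yield from entries(sub)
        else:
            yield t

    return all(v in admissible for v in entries(tensor))


print(tensor_values_are_in_set_sketch([[0, 1], [1, 0]], [0, 1]))  # True
print(tensor_values_are_in_set_sketch([[0, 2], [1, 0]], [0, 1]))  # False
```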