Сортировка и слияние двух тензоров с тензорным потоком

Felipe Moser спросил: 31 июля 2018 в 09:46 в: python-3.x

У меня есть два тензора A и B, каждый из которых равен Nx3. У меня также есть булевский тензор C = Nx1. Я хочу использовать это логическое значение для объединения A и B, следующих за значения C. Например:

A = [[ a1, a2, a3],
     [ a4, a5, a6],
     [ a7, a8, a9]]B = [[ b1, b2, b3],
     [ b4, b5, b6],
     [ b7, b8, b9]]C = [True, True, False]

Я хочу получить что-то вроде этого:

D = [[[ a1, a2, a3],
      [ b1, b2, b3]],     [[ a4, a5, a6],
      [ b4, b5, b6]],     [[ b7, b8, b9],
      [ a7, a8, a9]]]

Как вы можете видеть, оба тензора были объединены и порядок первой строки определяется значениями C.

Два тензора Nx3 A и B объединяются в один тензор D = Nx2x3, объединяя строки исходных тензоров. Порядок, в котором они добавляются к тензору Nx2x3, зависит от булевого тензора C, т. Е.

, если C [i] = True, D [i, 0 ,: ] = A [i ,:] и D [i, 1,:] = B [i,:]. Если C [i] = False, то D [i, 0,:] = B [i,:] и D [i, 1,:] = A [i,:]

Я уверен, что есть простой подход к этому, но я не смог понять это.


2 ответа

Felipe Moser ответил: 31 июля 2018 в 11:45

Поэтому я нашел решение, хотя оно может быть неэффективным. В случае, если у кого-то еще есть такая же проблема, я заставил ее работать следующим образом (используя те же A, B, C и D, как указано выше):

row_num = tf.cast(tf.reshape(tf.range(A.shape[0]), [-1, 1]), tf.int64)
AB = tf.concat([tf.expand_dims(A, 1), tf.expand_dims(B, 1)], axis=1)
filt_top = tf.reshape(tf.cast(C, tf.int64), [-1, 1])
filt_bottom = tf.reshape(tf.cast(tf.logical_not(C), tf.int64), [-1, 1])
filt = tf.concat([row_num, filt_top, filt_bottom], axis=1)
D_top = tf.map_fn(lambda x: AB[x[0], x[1], :], filt, dtype=tf.float32)
D_bottom = tf.map_fn(lambda x: AB[x[0], x[2], :], filt, dtype=tf.float32)
D = tf.concat([D_top, D_bottom], axis=1)
yann ответил: 03 августа 2018 в 07:41

Если A, B и C являются массивом Numpy, вы можете объединить их таким образом, используя технологию индексирования маскирования Numpy:

D = np.zeros((N,2,3))
D[C,0,:] = A[C]
D[~C,0,:] = B[~C]
D[~C,1,:] = A[~C]
D[C,1,:] = B[C]   

На самом деле, C не обязательно должен быть массивом Numpy. Список хорошо подходит для индексации массива в Python.