I have 3 tensors:
X: shape (1, c, h, w), e.g. (1, 20, 40, 50)
Fx: shape (num, w, N), e.g. (1000, 50, 10)
Fy: shape (num, N, h), e.g. (1000, 10, 40)
What I want to compute is Fy * (X * Fx), where * means matmul.
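Written element-wise, the two matmuls are the contractions below. The indices are n over num, k over the c channels, r over h, s over w, and i, j over N; T is just a name for the intermediate X * Fx. Because X has batch size 1, the same X[0] is reused for every n, and tiling makes that reuse explicit by copying it.

```latex
T[n, k, r, j] = \sum_{s=1}^{w} X[0, k, r, s]\, F_x[n, s, j]
  \qquad (T = X * F_x,\ \text{shape } (\mathit{num}, c, h, N))

\mathrm{res}[n, k, i, j] = \sum_{r=1}^{h} F_y[n, i, r]\, T[n, k, r, j]
  \qquad (\text{shape } (\mathit{num}, c, N, N))
```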
X * Fx: shape (num, c, h, N), e.g. (1000, 20, 40, 10)
Fy * (X * Fx): shape (num, c, N, N), e.g. (1000, 20, 10, 10)
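The two steps and their shapes map directly onto einsum subscripts, with no tiling. Below is a minimal NumPy sketch that uses the example sizes with num reduced to 8 so it runs quickly. As far as I know, `tf.einsum` accepts the same subscript strings, but the TensorFlow version is untested here.

```python
import numpy as np

# Example sizes from the question; num is reduced from 1000 to 8 for the demo.
c, h, w, num, N = 20, 40, 50, 8, 10
rng = np.random.default_rng(0)
X = rng.standard_normal((1, c, h, w)).astype(np.float32)   # (1, c, h, w)
Fx = rng.standard_normal((num, w, N)).astype(np.float32)   # (num, w, N)
Fy = rng.standard_normal((num, N, h)).astype(np.float32)   # (num, N, h)

# Step 1: X * Fx, contracting over w. X[0] drops the size-1 batch axis,
# so the same (c, h, w) array is used for every n without copying it.
tmp = np.einsum("chw,nwk->nchk", X[0], Fx)   # (num, c, h, N)

# Step 2: Fy * tmp, contracting over h.
res = np.einsum("nih,nchk->ncik", Fy, tmp)   # (num, c, N, N)
```

Each slice res[n, k] equals the plain 2-D product Fy[n] @ X[0, k] @ Fx[n].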
I currently do this with tf.tile and tf.expand_dims, but I think it uses a lot of memory (tf.tile copies the data, correct?). I'm looking for a way that is faster and needs less memory:
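`tf.tile` returns a new tensor of the full tiled shape, so each tiled input is a real allocation. The back-of-the-envelope count below uses the example sizes and assumes float32 (4 bytes per element). It compares each input before and after tiling X to batch size num and Fx, Fy to c channels.

```python
num, c, h, w, N = 1000, 20, 40, 50, 10
BYTES_PER_ELEM = 4  # assumption: float32 tensors

# (elements before tiling, elements after tiling) for each input
sizes = {
    "X":  (1 * c * h * w, num * c * h * w),   # (1, c, h, w) -> (num, c, h, w)
    "Fx": (num * w * N,   num * c * w * N),   # (num, w, N)  -> (num, c, w, N)
    "Fy": (num * N * h,   num * c * N * h),   # (num, N, h)  -> (num, c, N, h)
}
for name, (orig, tiled) in sizes.items():
    print(f"{name}: {orig * BYTES_PER_ELEM / 1e6:.2f} MB -> "
          f"{tiled * BYTES_PER_ELEM / 1e6:.2f} MB after tiling")
# X: 0.16 MB -> 160.00 MB after tiling
# Fx: 2.00 MB -> 40.00 MB after tiling
# Fy: 1.60 MB -> 32.00 MB after tiling
```

The 1000x copy of X dominates. The intermediate X * Fx (num·c·h·N = 8,000,000 elements, 32 MB) is needed by any two-matmul approach, tiled or not.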
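import tensorflow as tf
c = tf.shape(X)[1]  # number of channels
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1])  # (num, c, h, w): num copies of X
Fx_ex = tf.expand_dims(Fx, axis=1)  # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1])  # (num, c, w, N)
tmp = tf.matmul(X, Fx_ex)  # (num, c, h, N)
Fy_ex = tf.expand_dims(Fy, axis=1)  # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1])  # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp)  # (num, c, N, N)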
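An alternative to tiling is to insert a size-1 axis and let matmul broadcast the batch dimensions: 1 against num for X, and 1 against c for Fx and Fy. In TensorFlow this would read `tmp = tf.matmul(X, tf.expand_dims(Fx, 1))` followed by `res = tf.matmul(tf.expand_dims(Fy, 1), tmp)`. That relies on `tf.matmul` broadcasting batch dimensions, which TF 2.x supports as far as I know; check it against your version. The sketch below shows the same idea in NumPy, with num reduced to 8, so the result can be checked against the tiled computation.

```python
import numpy as np

# Same shapes as the question; num is reduced from 1000 to 8 for the demo.
c, h, w, num, N = 20, 40, 50, 8, 10
rng = np.random.default_rng(1)
X = rng.standard_normal((1, c, h, w)).astype(np.float32)   # (1, c, h, w)
Fx = rng.standard_normal((num, w, N)).astype(np.float32)   # (num, w, N)
Fy = rng.standard_normal((num, N, h)).astype(np.float32)   # (num, N, h)

# Add a size-1 channel axis instead of tiling; matmul broadcasts the leading
# (batch) axes, so no full tiled copies of X, Fx or Fy are built up front.
tmp = np.matmul(X, Fx[:, np.newaxis])    # (1,c,h,w) @ (num,1,w,N) -> (num,c,h,N)
res = np.matmul(Fy[:, np.newaxis], tmp)  # (num,1,N,h) @ (num,c,h,N) -> (num,c,N,N)
```

Only tmp and res are allocated at full size. This is the same pair of matmuls as the tiled version, minus the three copies.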