@@ -971,6 +971,49 @@ public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose",
971971 } ) ;
972972 }
973973
974+ /// <summary>
975+ /// Transposes last two dimensions of tensor `a`.
976+ /// For example:
977+ /// <code> python
978+ /// x = tf.constant([[1, 2, 3], [4, 5, 6]])
979+ /// tf.matrix_transpose(x) # [[1, 4],
980+ /// # [2, 5],
981+ /// # [3, 6]]
982+ /// </code>
983+ /// Matrix with two batch dimensions.
984+ /// x.shape is [1, 2, 3, 4]
985+ /// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
986+ /// </summary>
987+ /// <param name="a"></param>
988+ /// <param name="name"></param>
989+ /// <param name="conjugate"></param>
990+ /// <returns></returns>
991+ /// <exception cref="ValueError"></exception>
992+ public static Tensor matrix_transpose ( Tensor a , string name = "matrix_transpose" , bool conjugate = false )
993+ {
994+ return tf_with ( ops . name_scope ( name , "transpose" , new { a } ) , scope =>
995+ {
996+ var a_shape = a . shape ;
997+ var ndims = a . shape . ndim ;
998+ Axis perm ;
999+ if ( ndims != 0 )
1000+ {
1001+ if ( ndims < 2 )
1002+ {
1003+ throw new ValueError ( "Argument `a` should be a (batch) matrix with rank " +
1004+ $ ">= 2. Received `a` = { a } with shape: { a_shape } ") ;
1005+ }
1006+ perm = new Axis ( Enumerable . Range ( 0 , ndims - 2 ) . Concat ( new int [ ] { ndims - 1 , ndims - 2 } ) . ToArray ( ) ) ;
1007+ }
1008+ else
1009+ {
1010+ var a_rank = a . rank ;
1011+ perm = new Axis ( Enumerable . Range ( 0 , a_rank - 2 ) . Concat ( new int [ ] { a_rank - 1 , a_rank - 2 } ) . ToArray ( ) ) ;
1012+ }
1013+ return transpose ( a , perm : perm , conjugate : conjugate ) ;
1014+ } ) ;
1015+ }
1016+
9741017 public static Tensor [ ] split ( Tensor value , Tensor size_splits , int axis , int num = - 1 ,
9751018 string name = "split" )
9761019 {
0 commit comments