1 // Do the inner product of tensors.
6 static void inner_f(void);
29 if (istensor(p1
) && istensor(p2
))
35 tensor_times_scalar();
36 else if (istensor(p2
))
37 scalar_times_tensor();
44 // inner product of tensors p1 and p2
49 int ak
, bk
, i
, j
, k
, n
, ndim
;
52 n
= p1
->u
.tensor
->dim
[p1
->u
.tensor
->ndim
- 1];
54 if (n
!= p2
->u
.tensor
->dim
[0])
55 stop("inner: tensor dimension check");
57 ndim
= p1
->u
.tensor
->ndim
+ p2
->u
.tensor
->ndim
- 2;
60 stop("inner: rank of result exceeds maximum");
62 a
= p1
->u
.tensor
->elem
;
63 b
= p2
->u
.tensor
->elem
;
65 //---------------------------------------------------------------------
67 // ak is the number of rows in tensor A
69 // bk is the number of columns in tensor B
73 // A[3][3][4] B[4][4][3]
77 // 4 3 bk = 4 * 3 = 12
79 //---------------------------------------------------------------------
82 for (i
= 0; i
< p1
->u
.tensor
->ndim
- 1; i
++)
83 ak
*= p1
->u
.tensor
->dim
[i
];
86 for (i
= 1; i
< p2
->u
.tensor
->ndim
; i
++)
87 bk
*= p2
->u
.tensor
->dim
[i
];
89 p3
= alloc_tensor(ak
* bk
);
91 c
= p3
->u
.tensor
->elem
;
93 // new method copied from ginac
95 for (i
= 0; i
< ak
; i
++) {
96 for (j
= 0; j
< n
; j
++) {
97 if (iszero(a
[i
* n
+ j
]))
99 for (k
= 0; k
< bk
; k
++) {
105 c
[i
* bk
+ k
] = pop();
110 for (i
= 0; i
< ak
; i
++) {
111 for (j
= 0; j
< bk
; j
++) {
113 for (k
= 0; k
< n
; k
++) {
119 c
[i
* bk
+ j
] = pop();
123 //---------------------------------------------------------------------
125 // Note on understanding "k * bk + j"
127 // k * bk because each element of a column is bk locations apart
129 // + j because the beginnings of all columns are in the first bk
132 // Example: n = 2, bk = 6
134 // b111 <- 1st element of 1st column
135 // b112 <- 1st element of 2nd column
136 // b113 <- 1st element of 3rd column
137 // b121 <- 1st element of 4th column
138 // b122 <- 1st element of 5th column
139 // b123 <- 1st element of 6th column
141 // b211 <- 2nd element of 1st column
142 // b212 <- 2nd element of 2nd column
143 // b213 <- 2nd element of 3rd column
144 // b221 <- 2nd element of 4th column
145 // b222 <- 2nd element of 5th column
146 // b223 <- 2nd element of 6th column
148 //---------------------------------------------------------------------
151 push(p3
->u
.tensor
->elem
[0]);
153 p3
->u
.tensor
->ndim
= ndim
;
154 for (i
= 0; i
< p1
->u
.tensor
->ndim
- 1; i
++)
155 p3
->u
.tensor
->dim
[i
] = p1
->u
.tensor
->dim
[i
];
157 for (i
= 0; i
< p2
->u
.tensor
->ndim
- 1; i
++)
158 p3
->u
.tensor
->dim
[j
+ i
] = p2
->u
.tensor
->dim
[i
+ 1];
176 "inner(((a11,a12),(a21,a22)),(x1,x2))",
177 "(a11*x1+a12*x2,a21*x1+a22*x2)",
179 "inner((1,2),(3,4))",
182 "inner(inner((1,2),((3,4),(5,6))),(7,8))",
185 "inner((1,2),inner(((3,4),(5,6)),(7,8)))",
188 "inner((1,2),((3,4),(5,6)),(7,8))",
195 test(__FILE__
, s
, sizeof s
/ sizeof (char *));