3 enum { Aligned, RowMajor };  // NOTE(review): every line of this listing starts with a stray integer (apparently an earlier line number) and lines such as closing braces are absent, so it does not compile as-is — TODO confirm the source. Aligned/RowMajor are used as M/J option arguments in TTypes below.
4 enum { ReadOnlyAccessors };  // selects the N<Derived, ReadOnlyAccessors> partial specialization below
5 template <typename> struct K {  // primary template (body not visible here); the K<array<T, N>> specialization below exposes value = N
8 template <typename> struct traits;  // per-type traits; only specializations are defined below
9 template <typename T> struct traits<const T> : traits<T> {};  // const-qualified types reuse the traits of the unqualified type
11 enum { has_write_access, value };  // NOTE(review): the declaration enclosing these enumerators is not visible in this listing
13 template <typename, int n> class array {  // minimal fixed-size array; the element-type parameter is not used by the visible members
15 int operator[](unsigned long p1) { return values[p1]; }  // always yields int and performs no bounds check; `values` is declared on a line not visible here
18 template <typename> struct I;  // default for M's template-template (MakePointer) parameter
19 template <typename, int, template <class> class = I> class M;  // (wrapped plain object, options, MakePointer = I)
20 template <typename, int, int, typename> class J;  // (scalar, number of indices, options, index type)
21 template <typename, int> class N;  // CRTP expression base; the second argument selects the accessor specialization
22 template <typename, typename> class D;  // unary expression (op, argument)
23 template <typename, typename, typename, typename> class TensorContractionOp;  // (indices, lhs, rhs, output kernel)
24 template <long, typename> class TensorChippingOp;  // (dimension id, wrapped expression)
26 template <typename DenseIndex, int NumDims>
27 struct K<array<DenseIndex, NumDims>> {  // K<array<T, N>>::value is the array length N
28 static const long value = NumDims;
30 template <typename Scalar_, int NumIndices_, int Options_, typename IndexType_>
31 struct traits<J<Scalar_, NumIndices_, Options_, IndexType_>> {  // traits of J expose its IndexType_ as Index
32 typedef IndexType_ Index;
34 template <typename PlainObjectType, int Options_,
35 template <class> class MakePointer_>
36 struct traits<M<PlainObjectType, Options_, MakePointer_>>  // M inherits the traits of the object it wraps
37 : traits<PlainObjectType> {};
38 template <typename T> struct B { typedef T type; };  // identity type function; used to form each expression's Nested typedef
39 template <typename Derived> class N<Derived, ReadOnlyAccessors> {  // read-only CRTP base: expression classes derive from N<Self>
41 typedef typename traits<Derived>::Index Index;  // index type taken from the derived expression's traits
42 D<int, Derived> m_fn1();  // returns a unary D<int, Derived> expression over *this
43 template <typename OtherDerived, typename Dimensions>
44 TensorContractionOp<Dimensions, Derived, const OtherDerived, int>
45 m_fn2(OtherDerived, Dimensions);  // contraction expression with *this as lhs and OtherDerived as rhs
46 template <Index> TensorChippingOp<1, Derived> m_fn3(Index);  // the non-type template argument is ignored: the result always uses DimId 1
48 template <typename Derived, int = A::value>  // NOTE(review): `A` is not declared anywhere in the visible lines
49 class N : public N<Derived, ReadOnlyAccessors> {  // primary template layered on the read-only specialization
51 template <typename DeviceType> C m_fn4(DeviceType);  // NOTE(review): return type `C` is not declared in the visible lines
53 template <typename, typename> struct TensorEvaluator;  // (expression, device) evaluator; only specializations are defined
54 template <typename UnaryOp, typename ArgType, typename Device>
55 struct TensorEvaluator<const D<UnaryOp, ArgType>, Device> {  // evaluator for a const unary D expression
56 TensorEvaluator(D<UnaryOp, ArgType>, Device);  // declared only; no definition is visible
58 template <typename, typename> class D {
60 typedef typename B<D>::type Nested;  // Nested is D itself (B is the identity)
62 template <typename Indices_, typename LeftArgType_, typename RightArgType_,
63 typename OutputKernelType_, typename Device_>
65 TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,  // presumably the traits<> specialization read by TensorContractionEvaluatorBase; the line naming it is not visible — TODO confirm
66 RightArgType_, OutputKernelType_>,
68 typedef Indices_ Indices;  // re-export each template argument under the name the evaluator base looks up via traits<Derived>
69 typedef LeftArgType_ LeftArgType;
70 typedef RightArgType_ RightArgType;
71 typedef OutputKernelType_ OutputKernelType;
72 typedef Device_ Device;
74 template <typename, typename LhsXprType, typename RhsXprType, typename>  // first and last parameters are unnamed and unused by the visible members
75 class TensorContractionOp {
77 typedef typename B<TensorContractionOp>::type Nested;  // Nested is this class itself
78 typename LhsXprType::Nested m_fn5();  // returns the lhs operand, as its Nested type
79 typename RhsXprType::Nested m_fn6();  // returns the rhs operand, as its Nested type
81 template <typename Derived> struct TensorContractionEvaluatorBase {  // shared evaluator base; all argument types come from traits<Derived>
82 typedef typename traits<Derived>::LeftArgType LeftArgType;
83 typedef typename traits<Derived>::RightArgType RightArgType;
84 typedef typename traits<Derived>::Device Device;
85 TensorContractionEvaluatorBase(  // p1/p2 are declared on lines not visible; p1 is used as the contraction op and p2 as the device
86 TensorContractionOp<typename traits<Derived>::Indices, LeftArgType,
88 typename traits<Derived>::OutputKernelType>
91 : m_leftImpl(p1.m_fn6(), p2), m_rightImpl(p1.m_fn5(), p2) {  // operands are swapped: m_leftImpl is built from the rhs (m_fn6), m_rightImpl from the lhs (m_fn5)
96 if (nocontract_idx < K<int>::value)  // NOTE(review): nocontract_idx and m_j_size are declared on lines not visible here
97 m_j_size = m_j_strides[nocontract_idx];  // m_j_strides has length 1 and array::operator[] is unchecked, so this relies on K<int>::value <= 1 — TODO confirm
102 array<long, 1> m_j_strides;
104 TensorEvaluator<RightArgType, Device> m_leftImpl;  // typed on RightArgType, matching the swapped initialization above
105 TensorEvaluator<LeftArgType, Device> m_rightImpl;  // typed on LeftArgType
107 template <typename Indices, typename LeftArgType, typename RightArgType,
108 typename OutputKernelType, typename Device>
109 struct TensorEvaluator<  // evaluator for a const contraction expression
110 const TensorContractionOp<Indices, LeftArgType, RightArgType,
113 : TensorContractionEvaluatorBase<TensorEvaluator<  // CRTP: passes itself as Derived to the shared evaluator base
114 const TensorContractionOp<Indices, LeftArgType, RightArgType,
117 typedef TensorEvaluator Self;  // injected-class-name, i.e. this specialization
118 typedef TensorContractionEvaluatorBase<Self> Base;
120 TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>  // presumably the constructor's op parameter; the rest of the declaration is not visible
125 template <long DimId, typename XprType>
126 struct traits<TensorChippingOp<DimId, XprType>> : traits<XprType> {};  // a chipping op reuses the traits of the wrapped expression
127 template <long, typename XprType>
128 class TensorChippingOp : public N<TensorChippingOp<1, XprType>> {  // note: the CRTP base always names TensorChippingOp<1, XprType>, whatever this class's DimId
130 typedef typename B<TensorChippingOp>::type Nested;  // Nested is this class itself
132 template <long DimId, typename ArgType, typename Device>
133 struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
134 static const int NumInputDims = K<typename ArgType::Dimensions>::value;  // when ArgType::Dimensions is array<T, N> (as for M over J), this is N via the K<array> specialization
135 array<long, NumInputDims> m_dimensions;  // one long per input dimension
137 template <long DimId, typename ArgType, typename Device>
138 struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
139 : TensorEvaluator<const TensorChippingOp<1, ArgType>, Device> {  // the non-const evaluator derives from the const one, with DimId fixed to 1
140 TensorEvaluator(TensorChippingOp<DimId, ArgType>, Device);  // declared only; no definition is visible
142 template <typename, typename RhsXprType> class TensorAssignOp {  // assignment node; the first (lhs) parameter is unnamed and the lhs type is hard-wired in the members below
144 TensorAssignOp(TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,  // first parameter is the fixed lhs chip type; remaining parameters are not visible
146 TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_fn7();  // returns the lhs operand (fixed type, independent of the template arguments)
147 typename RhsXprType::Nested m_fn8();  // returns the rhs operand, as its Nested type
149 template <typename LeftArgType, typename RightArgType, typename Device>
150 struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>,  // evaluator for a const assignment expression
152 TensorEvaluator(TensorAssignOp<LeftArgType, RightArgType> p1, Device p2)
153 : m_leftImpl(p1.m_fn7(), p2), m_rightImpl(p1.m_fn8(), p2) {}  // both sides get sub-evaluators constructed with the same device p2
154 TensorEvaluator<LeftArgType, Device> m_leftImpl;  // evaluator of the lhs
155 TensorEvaluator<RightArgType, Device> m_rightImpl;  // evaluator of the rhs
157 template <typename Expression> class F {  // holds a static entry point that builds an evaluator for Expression
159 static void m_fn9(Expression p1) {  // p1: the expression to evaluate
161 TensorEvaluator<Expression, int>(p1, device);  // constructs a temporary evaluator (device type int) that is destroyed at the end of this statement; `device` is declared on a line not visible
167 operator=(TensorContractionOp<array<int, 1>,  // assignment from a contraction expression; the enclosing class head is not visible
168 TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
169 const D<int, M<J<float, 3, 1, int>, 0>>, int>
172 TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,  // template arguments of the TensorAssignOp type of `assign`; the opening of that declaration is not visible
173 const TensorContractionOp<
174 array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
175 const D<int, M<J<float, 3, 1, int>, 0>>, int>>
176 assign(m_expression, p1);  // pairs m_expression (lhs) with p1 (rhs)
177 F<const TensorAssignOp<
178 TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
179 const TensorContractionOp<
180 array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
181 const D<int, M<J<float, 3, 1, int>, 0>>, int>>>::m_fn9(assign);  // F::m_fn9 constructs an evaluator for the assignment
183 TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_expression;  // NOTE(review): lhs wraps J<int, ...> while the rhs operands wrap J<float, ...> — confirm the scalar mismatch is intended
185 template <typename, int NumIndices_, int, typename> class J {  // only the number of indices is used by the visible members
187 typedef array<long, NumIndices_> Dimensions;  // NumIndices_ longs
189 template <typename PlainObjectType, int Options_, template <class> class>
190 class M : public N<M<PlainObjectType, Options_>> {  // CRTP base names M with the default MakePointer (I), not this class's own third argument
192 typedef typename PlainObjectType::Dimensions Dimensions;  // reuses the wrapped object's Dimensions
194 template <int NDIMS> struct TTypes {
195 typedef M<J<float, NDIMS, RowMajor, int>, Aligned> ConstTensor;  // NOTE(review): despite the name, the element type is plain (non-const) float
199 template <typename, long NDIMS> typename TTypes<NDIMS>::ConstTensor m_fn10();  // the first template argument is ignored; the result is always TTypes<NDIMS>::ConstTensor (float)
212 int BatchMatMul_context;
213 O() : H(&BatchMatMul_context) {  // passes the address of BatchMatMul_context to H; H and the rest of the class head are declared on lines not visible
215 auto Tx = in_x.m_fn10<float, 3>(), Ty = in_y.m_fn10<float, 3>(),  // in_x/in_y/out are declared on lines not visible
216 Tz = out.m_fn10<float, 3>(), z = Tz;  // all four deduce the same TTypes<3>::ConstTensor type; z is a copy of Tz
217 array<int, 1> contract_pairs;  // NOTE(review): its contents are never visibly set before being passed to m_fn2
218 auto x = Tx.m_fn3<0>(0);  // chipping expression over Tx; m_fn3 yields TensorChippingOp<1, ...> regardless of the <0>
220 z.m_fn4(Run_d) = x.m_fn2(y, contract_pairs);  // y and Run_d are declared on lines not visible; assigns the contraction of x with y through the object returned by z.m_fn4(Run_d)
223 G registrar__body__0__object([](int *) -> H * { O(); return 0; });  // G is declared outside the visible lines; the lambda constructs and discards a temporary O and always returns a null H*