1 //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_CODEGEN_PBQP_MATH_H
10 #define LLVM_CODEGEN_PBQP_MATH_H
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
/// Element type stored by the PBQP Vector and Matrix classes.
using PBQPNum = float;
24 /// PBQP Vector class.
26 friend hash_code
hash_value(const Vector
&);
29 /// Construct a PBQP vector of the given size.
30 explicit Vector(unsigned Length
)
31 : Length(Length
), Data(std::make_unique
<PBQPNum
[]>(Length
)) {}
33 /// Construct a PBQP vector with initializer.
34 Vector(unsigned Length
, PBQPNum InitVal
)
35 : Length(Length
), Data(std::make_unique
<PBQPNum
[]>(Length
)) {
36 std::fill(Data
.get(), Data
.get() + Length
, InitVal
);
39 /// Copy construct a PBQP vector.
40 Vector(const Vector
&V
)
41 : Length(V
.Length
), Data(std::make_unique
<PBQPNum
[]>(Length
)) {
42 std::copy(V
.Data
.get(), V
.Data
.get() + Length
, Data
.get());
45 /// Move construct a PBQP vector.
47 : Length(V
.Length
), Data(std::move(V
.Data
)) {
51 /// Comparison operator.
52 bool operator==(const Vector
&V
) const {
53 assert(Length
!= 0 && Data
&& "Invalid vector");
54 if (Length
!= V
.Length
)
56 return std::equal(Data
.get(), Data
.get() + Length
, V
.Data
.get());
59 /// Return the length of the vector
60 unsigned getLength() const {
61 assert(Length
!= 0 && Data
&& "Invalid vector");
66 PBQPNum
& operator[](unsigned Index
) {
67 assert(Length
!= 0 && Data
&& "Invalid vector");
68 assert(Index
< Length
&& "Vector element access out of bounds.");
72 /// Const element access.
73 const PBQPNum
& operator[](unsigned Index
) const {
74 assert(Length
!= 0 && Data
&& "Invalid vector");
75 assert(Index
< Length
&& "Vector element access out of bounds.");
79 /// Add another vector to this one.
80 Vector
& operator+=(const Vector
&V
) {
81 assert(Length
!= 0 && Data
&& "Invalid vector");
82 assert(Length
== V
.Length
&& "Vector length mismatch.");
83 std::transform(Data
.get(), Data
.get() + Length
, V
.Data
.get(), Data
.get(),
84 std::plus
<PBQPNum
>());
88 /// Returns the index of the minimum value in this vector
89 unsigned minIndex() const {
90 assert(Length
!= 0 && Data
&& "Invalid vector");
91 return std::min_element(Data
.get(), Data
.get() + Length
) - Data
.get();
96 std::unique_ptr
<PBQPNum
[]> Data
;
99 /// Return a hash_value for the given vector.
100 inline hash_code
hash_value(const Vector
&V
) {
101 unsigned *VBegin
= reinterpret_cast<unsigned*>(V
.Data
.get());
102 unsigned *VEnd
= reinterpret_cast<unsigned*>(V
.Data
.get() + V
.Length
);
103 return hash_combine(V
.Length
, hash_combine_range(VBegin
, VEnd
));
106 /// Output a textual representation of the given vector on the given
108 template <typename OStream
>
109 OStream
& operator<<(OStream
&OS
, const Vector
&V
) {
110 assert((V
.getLength() != 0) && "Zero-length vector badness.");
113 for (unsigned i
= 1; i
< V
.getLength(); ++i
)
120 /// PBQP Matrix class
123 friend hash_code
hash_value(const Matrix
&);
126 /// Construct a PBQP Matrix with the given dimensions.
127 Matrix(unsigned Rows
, unsigned Cols
) :
128 Rows(Rows
), Cols(Cols
), Data(std::make_unique
<PBQPNum
[]>(Rows
* Cols
)) {
131 /// Construct a PBQP Matrix with the given dimensions and initial
133 Matrix(unsigned Rows
, unsigned Cols
, PBQPNum InitVal
)
134 : Rows(Rows
), Cols(Cols
),
135 Data(std::make_unique
<PBQPNum
[]>(Rows
* Cols
)) {
136 std::fill(Data
.get(), Data
.get() + (Rows
* Cols
), InitVal
);
139 /// Copy construct a PBQP matrix.
140 Matrix(const Matrix
&M
)
141 : Rows(M
.Rows
), Cols(M
.Cols
),
142 Data(std::make_unique
<PBQPNum
[]>(Rows
* Cols
)) {
143 std::copy(M
.Data
.get(), M
.Data
.get() + (Rows
* Cols
), Data
.get());
146 /// Move construct a PBQP matrix.
148 : Rows(M
.Rows
), Cols(M
.Cols
), Data(std::move(M
.Data
)) {
152 /// Comparison operator.
153 bool operator==(const Matrix
&M
) const {
154 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
155 if (Rows
!= M
.Rows
|| Cols
!= M
.Cols
)
157 return std::equal(Data
.get(), Data
.get() + (Rows
* Cols
), M
.Data
.get());
160 /// Return the number of rows in this matrix.
161 unsigned getRows() const {
162 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
166 /// Return the number of cols in this matrix.
167 unsigned getCols() const {
168 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
172 /// Matrix element access.
173 PBQPNum
* operator[](unsigned R
) {
174 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
175 assert(R
< Rows
&& "Row out of bounds.");
176 return Data
.get() + (R
* Cols
);
179 /// Matrix element access.
180 const PBQPNum
* operator[](unsigned R
) const {
181 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
182 assert(R
< Rows
&& "Row out of bounds.");
183 return Data
.get() + (R
* Cols
);
186 /// Returns the given row as a vector.
187 Vector
getRowAsVector(unsigned R
) const {
188 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
190 for (unsigned C
= 0; C
< Cols
; ++C
)
191 V
[C
] = (*this)[R
][C
];
195 /// Returns the given column as a vector.
196 Vector
getColAsVector(unsigned C
) const {
197 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
199 for (unsigned R
= 0; R
< Rows
; ++R
)
200 V
[R
] = (*this)[R
][C
];
204 /// Matrix transpose.
205 Matrix
transpose() const {
206 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
207 Matrix
M(Cols
, Rows
);
208 for (unsigned r
= 0; r
< Rows
; ++r
)
209 for (unsigned c
= 0; c
< Cols
; ++c
)
210 M
[c
][r
] = (*this)[r
][c
];
214 /// Add the given matrix to this one.
215 Matrix
& operator+=(const Matrix
&M
) {
216 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
217 assert(Rows
== M
.Rows
&& Cols
== M
.Cols
&&
218 "Matrix dimensions mismatch.");
219 std::transform(Data
.get(), Data
.get() + (Rows
* Cols
), M
.Data
.get(),
220 Data
.get(), std::plus
<PBQPNum
>());
224 Matrix
operator+(const Matrix
&M
) {
225 assert(Rows
!= 0 && Cols
!= 0 && Data
&& "Invalid matrix");
233 std::unique_ptr
<PBQPNum
[]> Data
;
236 /// Return a hash_code for the given matrix.
237 inline hash_code
hash_value(const Matrix
&M
) {
238 unsigned *MBegin
= reinterpret_cast<unsigned*>(M
.Data
.get());
240 reinterpret_cast<unsigned*>(M
.Data
.get() + (M
.Rows
* M
.Cols
));
241 return hash_combine(M
.Rows
, M
.Cols
, hash_combine_range(MBegin
, MEnd
));
244 /// Output a textual representation of the given matrix on the given
246 template <typename OStream
>
247 OStream
& operator<<(OStream
&OS
, const Matrix
&M
) {
248 assert((M
.getRows() != 0) && "Zero-row matrix badness.");
249 for (unsigned i
= 0; i
< M
.getRows(); ++i
)
250 OS
<< M
.getRowAsVector(i
) << "\n";
254 template <typename Metadata
>
255 class MDVector
: public Vector
{
257 MDVector(const Vector
&v
) : Vector(v
), md(*this) {}
258 MDVector(Vector
&&v
) : Vector(std::move(v
)), md(*this) { }
260 const Metadata
& getMetadata() const { return md
; }
266 template <typename Metadata
>
267 inline hash_code
hash_value(const MDVector
<Metadata
> &V
) {
268 return hash_value(static_cast<const Vector
&>(V
));
271 template <typename Metadata
>
272 class MDMatrix
: public Matrix
{
274 MDMatrix(const Matrix
&m
) : Matrix(m
), md(*this) {}
275 MDMatrix(Matrix
&&m
) : Matrix(std::move(m
)), md(*this) { }
277 const Metadata
& getMetadata() const { return md
; }
283 template <typename Metadata
>
284 inline hash_code
hash_value(const MDMatrix
<Metadata
> &M
) {
285 return hash_value(static_cast<const Matrix
&>(M
));
288 } // end namespace PBQP
289 } // end namespace llvm
291 #endif // LLVM_CODEGEN_PBQP_MATH_H