1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines a generic SMT solver API, which serves as the base class
10 // for every solver-specific SMT class.
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_SUPPORT_SMTAPI_H
15 #define LLVM_SUPPORT_SMTAPI_H
17 #include "llvm/ADT/APFloat.h"
18 #include "llvm/ADT/APSInt.h"
19 #include "llvm/ADT/FoldingSet.h"
20 #include "llvm/Support/raw_ostream.h"
25 /// Generic base class for SMT sorts
29 virtual ~SMTSort() = default;
31 /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
32 virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
34 /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
35 virtual bool isFloatSort() const { return isFloatSortImpl(); }
37 /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
38 virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
40 /// Returns the bitvector size, fails if the sort is not a bitvector
41 /// Calls getBitvectorSortSizeImpl().
42 virtual unsigned getBitvectorSortSize() const {
43 assert(isBitvectorSort() && "Not a bitvector sort!");
44 unsigned Size
= getBitvectorSortSizeImpl();
45 assert(Size
&& "Size is zero!");
49 /// Returns the floating-point size, fails if the sort is not a floating-point
50 /// Calls getFloatSortSizeImpl().
51 virtual unsigned getFloatSortSize() const {
52 assert(isFloatSort() && "Not a floating-point sort!");
53 unsigned Size
= getFloatSortSizeImpl();
54 assert(Size
&& "Size is zero!");
/// Profiles this sort into \p ID. Pure virtual, so every solver-specific sort
/// class supplies its own implementation.
/// NOTE(review): presumably \p ID is what operator< compares — the operator's
/// body is not fully visible here, confirm.
58 virtual void Profile(llvm::FoldingSetNodeID
&ID
) const = 0;
60 bool operator<(const SMTSort
&Other
) const {
61 llvm::FoldingSetNodeID ID1
, ID2
;
67 friend bool operator==(SMTSort
const &LHS
, SMTSort
const &RHS
) {
68 return LHS
.equal_to(RHS
);
/// Prints this sort to \p OS. Pure virtual; the output format is left to the
/// implementing class.
71 virtual void print(raw_ostream
&OS
) const = 0;
/// Debugging helper, marked LLVM_DUMP_METHOD. Only declared here; the
/// definition lives outside this header excerpt.
73 LLVM_DUMP_METHOD
void dump() const;
76 /// Queries the SMT solver and returns true if two sorts are equal (same kind
77 /// and bit width). This does not check if the two sorts are the same objects.
/// operator== on SMTSort forwards to this method.
78 virtual bool equal_to(SMTSort
const &other
) const = 0;
80 /// Queries the SMT solver and checks if a sort is bitvector.
/// Backs the public isBitvectorSort() wrapper.
81 virtual bool isBitvectorSortImpl() const = 0;
83 /// Queries the SMT solver and checks if a sort is floating-point.
/// Backs the public isFloatSort() wrapper.
84 virtual bool isFloatSortImpl() const = 0;
86 /// Queries the SMT solver and checks if a sort is boolean.
/// Backs the public isBooleanSort() wrapper.
87 virtual bool isBooleanSortImpl() const = 0;
89 /// Queries the SMT solver and returns the bitvector sort's bit width.
/// Backs getBitvectorSortSize(), which asserts that the sort is a bitvector
/// and that the returned size is non-zero.
90 virtual unsigned getBitvectorSortSizeImpl() const = 0;
92 /// Queries the SMT solver and returns the floating-point sort's bit width.
/// Backs getFloatSortSize(), which asserts that the sort is floating-point
/// and that the returned size is non-zero.
93 virtual unsigned getFloatSortSizeImpl() const = 0;
96 /// Pointer to a const SMTSort, used by the SMTSolver API.
/// NOTE(review): this alias is a raw `const SMTSort *`, not a
/// std::shared_ptr; who owns the pointee is not visible from this header.
97 using SMTSortRef
= const SMTSort
*;
99 /// Generic base class for SMT exprs
103 virtual ~SMTExpr() = default;
105 bool operator<(const SMTExpr
&Other
) const {
106 llvm::FoldingSetNodeID ID1
, ID2
;
/// Profiles this expression into \p ID. Pure virtual, so every
/// solver-specific expression class supplies its own implementation.
/// NOTE(review): presumably \p ID is what operator< compares — the operator's
/// body is not fully visible here, confirm.
112 virtual void Profile(llvm::FoldingSetNodeID
&ID
) const = 0;
114 friend bool operator==(SMTExpr
const &LHS
, SMTExpr
const &RHS
) {
115 return LHS
.equal_to(RHS
);
/// Prints this expression to \p OS. Pure virtual; the output format is left
/// to the implementing class.
118 virtual void print(raw_ostream
&OS
) const = 0;
/// Debugging helper, marked LLVM_DUMP_METHOD. Only declared here; the
/// definition lives outside this header excerpt.
120 LLVM_DUMP_METHOD
void dump() const;
123 /// Queries the SMT solver and returns true if two expressions are equal.
124 /// This does not check if the two expressions are the same objects.
/// operator== on SMTExpr forwards to this method.
125 virtual bool equal_to(SMTExpr
const &other
) const = 0;
128 /// Pointer to a const SMTExpr, used by the SMTSolver API.
/// NOTE(review): this alias is a raw `const SMTExpr *`, not a
/// std::shared_ptr; who owns the pointee is not visible from this header.
129 using SMTExprRef
= const SMTExpr
*;
131 /// Generic base class for SMT Solvers
133 /// This class is responsible for wrapping all sorts and expression generation,
134 /// through the mk* methods. It also provides methods to create SMT expressions
135 /// straight from clang's AST, through the from* methods.
138 SMTSolver() = default;
139 virtual ~SMTSolver() = default;
/// Debugging helper, marked LLVM_DUMP_METHOD. Only declared here; the
/// definition lives outside this header excerpt.
141 LLVM_DUMP_METHOD
void dump() const;
143 // Returns an appropriate floating-point sort for the given bitwidth.
144 SMTSortRef
getFloatSort(unsigned BitWidth
) {
147 return getFloat16Sort();
149 return getFloat32Sort();
151 return getFloat64Sort();
153 return getFloat128Sort();
156 llvm_unreachable("Unsupported floating-point bitwidth!");
159 /// Returns a boolean sort.
160 virtual SMTSortRef
getBoolSort() = 0;
162 /// Returns an appropriate bitvector sort for the given bitwidth.
163 virtual SMTSortRef
getBitvectorSort(const unsigned BitWidth
) = 0;
165 /// Returns a floating-point sort of width 16.
/// One of the sorts getFloatSort() can return.
166 virtual SMTSortRef
getFloat16Sort() = 0;
168 /// Returns a floating-point sort of width 32.
/// One of the sorts getFloatSort() can return.
169 virtual SMTSortRef
getFloat32Sort() = 0;
171 /// Returns a floating-point sort of width 64.
/// One of the sorts getFloatSort() can return.
172 virtual SMTSortRef
getFloat64Sort() = 0;
174 /// Returns a floating-point sort of width 128.
/// One of the sorts getFloatSort() can return.
175 virtual SMTSortRef
getFloat128Sort() = 0;
177 /// Returns an appropriate sort for the given AST.
178 virtual SMTSortRef
getSort(const SMTExprRef
&AST
) = 0;
180 /// Given a constraint, adds it to the solver
181 virtual void addConstraint(const SMTExprRef
&Exp
) const = 0;
183 /// Creates a bitvector addition operation
184 virtual SMTExprRef
mkBVAdd(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
186 /// Creates a bitvector subtraction operation
187 virtual SMTExprRef
mkBVSub(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
189 /// Creates a bitvector multiplication operation
190 virtual SMTExprRef
mkBVMul(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
192 /// Creates a bitvector signed remainder operation
193 virtual SMTExprRef
mkBVSRem(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
195 /// Creates a bitvector unsigned remainder operation
196 virtual SMTExprRef
mkBVURem(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
198 /// Creates a bitvector signed division operation
199 virtual SMTExprRef
mkBVSDiv(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
201 /// Creates a bitvector unsigned division operation
202 virtual SMTExprRef
mkBVUDiv(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
204 /// Creates a bitvector logical shift left operation
205 virtual SMTExprRef
mkBVShl(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
207 /// Creates a bitvector arithmetic shift right operation
208 virtual SMTExprRef
mkBVAshr(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
210 /// Creates a bitvector logical shift right operation
211 virtual SMTExprRef
mkBVLshr(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
213 /// Creates a bitvector negation operation
214 virtual SMTExprRef
mkBVNeg(const SMTExprRef
&Exp
) = 0;
216 /// Creates a bitvector not operation
217 virtual SMTExprRef
mkBVNot(const SMTExprRef
&Exp
) = 0;
219 /// Creates a bitvector xor operation
220 virtual SMTExprRef
mkBVXor(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
222 /// Creates a bitvector or operation
223 virtual SMTExprRef
mkBVOr(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
225 /// Creates a bitvector and operation
226 virtual SMTExprRef
mkBVAnd(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
228 /// Creates a bitvector unsigned less-than operation
229 virtual SMTExprRef
mkBVUlt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
231 /// Creates a bitvector signed less-than operation
232 virtual SMTExprRef
mkBVSlt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
234 /// Creates a bitvector unsigned greater-than operation
235 virtual SMTExprRef
mkBVUgt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
237 /// Creates a bitvector signed greater-than operation
238 virtual SMTExprRef
mkBVSgt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
240 /// Creates a bitvector unsigned less-than-or-equal operation
241 virtual SMTExprRef
mkBVUle(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
243 /// Creates a bitvector signed less-than-or-equal operation
244 virtual SMTExprRef
mkBVSle(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
246 /// Creates a bitvector unsigned greater-than-or-equal operation
247 virtual SMTExprRef
mkBVUge(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
249 /// Creates a bitvector signed greater-than-or-equal operation
250 virtual SMTExprRef
mkBVSge(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
252 /// Creates a boolean not operation
253 virtual SMTExprRef
mkNot(const SMTExprRef
&Exp
) = 0;
255 /// Creates a boolean equality operation
256 virtual SMTExprRef
mkEqual(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
258 /// Creates a boolean and operation
259 virtual SMTExprRef
mkAnd(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
261 /// Creates a boolean or operation
262 virtual SMTExprRef
mkOr(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
264 /// Creates a boolean ite operation
265 virtual SMTExprRef
mkIte(const SMTExprRef
&Cond
, const SMTExprRef
&T
,
266 const SMTExprRef
&F
) = 0;
268 /// Creates a bitvector sign extension operation
/// NOTE(review): \p i is presumably the number of bits \p Exp is extended
/// by — confirm against the solver backend.
269 virtual SMTExprRef
mkBVSignExt(unsigned i
, const SMTExprRef
&Exp
) = 0;
271 /// Creates a bitvector zero extension operation
/// NOTE(review): \p i is presumably the number of bits \p Exp is extended
/// by — confirm against the solver backend.
272 virtual SMTExprRef
mkBVZeroExt(unsigned i
, const SMTExprRef
&Exp
) = 0;
274 /// Creates a bitvector extract operation
/// NOTE(review): \p High and \p Low presumably bound the extracted bit range
/// of \p Exp; whether the bounds are inclusive is not visible here — confirm.
275 virtual SMTExprRef
mkBVExtract(unsigned High
, unsigned Low
,
276 const SMTExprRef
&Exp
) = 0;
278 /// Creates a bitvector concat operation
279 virtual SMTExprRef
mkBVConcat(const SMTExprRef
&LHS
,
280 const SMTExprRef
&RHS
) = 0;
282 /// Creates a predicate that checks for overflow in a bitvector addition
284 virtual SMTExprRef
mkBVAddNoOverflow(const SMTExprRef
&LHS
,
285 const SMTExprRef
&RHS
,
288 /// Creates a predicate that checks for underflow in a signed bitvector
289 /// addition operation
290 virtual SMTExprRef
mkBVAddNoUnderflow(const SMTExprRef
&LHS
,
291 const SMTExprRef
&RHS
) = 0;
293 /// Creates a predicate that checks for overflow in a signed bitvector
294 /// subtraction operation
295 virtual SMTExprRef
mkBVSubNoOverflow(const SMTExprRef
&LHS
,
296 const SMTExprRef
&RHS
) = 0;
298 /// Creates a predicate that checks for underflow in a bitvector subtraction
300 virtual SMTExprRef
mkBVSubNoUnderflow(const SMTExprRef
&LHS
,
301 const SMTExprRef
&RHS
,
304 /// Creates a predicate that checks for overflow in a signed bitvector
305 /// division/modulus operation
306 virtual SMTExprRef
mkBVSDivNoOverflow(const SMTExprRef
&LHS
,
307 const SMTExprRef
&RHS
) = 0;
309 /// Creates a predicate that checks for overflow in a bitvector negation
/// operation
311 virtual SMTExprRef
mkBVNegNoOverflow(const SMTExprRef
&Exp
) = 0;
313 /// Creates a predicate that checks for overflow in a bitvector multiplication
315 virtual SMTExprRef
mkBVMulNoOverflow(const SMTExprRef
&LHS
,
316 const SMTExprRef
&RHS
,
319 /// Creates a predicate that checks for underflow in a signed bitvector
320 /// multiplication operation
321 virtual SMTExprRef
mkBVMulNoUnderflow(const SMTExprRef
&LHS
,
322 const SMTExprRef
&RHS
) = 0;
324 /// Creates a floating-point negation operation
325 virtual SMTExprRef
mkFPNeg(const SMTExprRef
&Exp
) = 0;
327 /// Creates a floating-point isInfinite operation
328 virtual SMTExprRef
mkFPIsInfinite(const SMTExprRef
&Exp
) = 0;
330 /// Creates a floating-point isNaN operation
331 virtual SMTExprRef
mkFPIsNaN(const SMTExprRef
&Exp
) = 0;
333 /// Creates a floating-point isNormal operation
334 virtual SMTExprRef
mkFPIsNormal(const SMTExprRef
&Exp
) = 0;
336 /// Creates a floating-point isZero operation
337 virtual SMTExprRef
mkFPIsZero(const SMTExprRef
&Exp
) = 0;
339 /// Creates a floating-point multiplication operation
340 virtual SMTExprRef
mkFPMul(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
342 /// Creates a floating-point division operation
343 virtual SMTExprRef
mkFPDiv(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
345 /// Creates a floating-point remainder operation
346 virtual SMTExprRef
mkFPRem(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
348 /// Creates a floating-point addition operation
349 virtual SMTExprRef
mkFPAdd(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
351 /// Creates a floating-point subtraction operation
352 virtual SMTExprRef
mkFPSub(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
354 /// Creates a floating-point less-than operation
355 virtual SMTExprRef
mkFPLt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
357 /// Creates a floating-point greater-than operation
358 virtual SMTExprRef
mkFPGt(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
360 /// Creates a floating-point less-than-or-equal operation
361 virtual SMTExprRef
mkFPLe(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
363 /// Creates a floating-point greater-than-or-equal operation
364 virtual SMTExprRef
mkFPGe(const SMTExprRef
&LHS
, const SMTExprRef
&RHS
) = 0;
366 /// Creates a floating-point equality operation
367 virtual SMTExprRef
mkFPEqual(const SMTExprRef
&LHS
,
368 const SMTExprRef
&RHS
) = 0;
370 /// Creates a floating-point conversion from floating-point to floating-point
/// operation
372 virtual SMTExprRef
mkFPtoFP(const SMTExprRef
&From
, const SMTSortRef
&To
) = 0;
374 /// Creates a floating-point conversion from signed bitvector to
375 /// floating-point operation
376 virtual SMTExprRef
mkSBVtoFP(const SMTExprRef
&From
,
377 const SMTSortRef
&To
) = 0;
379 /// Creates a floating-point conversion from unsigned bitvector to
380 /// floating-point operation
381 virtual SMTExprRef
mkUBVtoFP(const SMTExprRef
&From
,
382 const SMTSortRef
&To
) = 0;
384 /// Creates a floating-point conversion from floating-point to signed
385 /// bitvector operation
386 virtual SMTExprRef
mkFPtoSBV(const SMTExprRef
&From
, unsigned ToWidth
) = 0;
388 /// Creates a floating-point conversion from floating-point to unsigned
389 /// bitvector operation
390 virtual SMTExprRef
mkFPtoUBV(const SMTExprRef
&From
, unsigned ToWidth
) = 0;
392 /// Creates a new symbol, given a name and a sort
393 virtual SMTExprRef
mkSymbol(const char *Name
, SMTSortRef Sort
) = 0;
395 /// Returns an appropriate floating-point rounding mode.
396 virtual SMTExprRef
getFloatRoundingMode() = 0;
398 /// If a model is available, returns the value of a given bitvector symbol
399 virtual llvm::APSInt
getBitvector(const SMTExprRef
&Exp
, unsigned BitWidth
,
400 bool isUnsigned
) = 0;
402 /// If a model is available, returns the value of a given boolean symbol
403 virtual bool getBoolean(const SMTExprRef
&Exp
) = 0;
405 /// Constructs an SMTExprRef from a boolean.
406 virtual SMTExprRef
mkBoolean(const bool b
) = 0;
408 /// Constructs an SMTExprRef from a finite APFloat.
409 virtual SMTExprRef
mkFloat(const llvm::APFloat Float
) = 0;
411 /// Constructs an SMTExprRef from an APSInt and its bit width
412 virtual SMTExprRef
mkBitvector(const llvm::APSInt Int
, unsigned BitWidth
) = 0;
414 /// Given an expression, extract the value of this operand in the model.
/// The value is written to \p Int.
/// NOTE(review): the meaning of the bool result is not documented here —
/// presumably true when a value was found; confirm against the backends.
415 virtual bool getInterpretation(const SMTExprRef
&Exp
, llvm::APSInt
&Int
) = 0;
417 /// Given an expression, extract the value of this operand in the model.
/// The value is written to \p Float.
/// NOTE(review): the meaning of the bool result is not documented here —
/// presumably true when a value was found; confirm against the backends.
418 virtual bool getInterpretation(const SMTExprRef
&Exp
,
419 llvm::APFloat
&Float
) = 0;
421 /// Check if the constraints are satisfiable
/// NOTE(review): the result is an Optional<bool>; an empty Optional
/// presumably means the solver could not decide — confirm against the
/// backends.
422 virtual Optional
<bool> check() const = 0;
424 /// Push the current solver state
425 virtual void push() = 0;
427 /// Pop the previous solver state. \p NumStates selects how many states to
/// pop and defaults to one.
428 virtual void pop(unsigned NumStates
= 1) = 0;
430 /// Reset the solver and remove all constraints.
431 virtual void reset() = 0;
433 /// Checks if the solver supports floating-points.
434 virtual bool isFPSupported() = 0;
/// Prints the solver to \p OS. Pure virtual; what exactly is printed is left
/// to the implementing class.
436 virtual void print(raw_ostream
&OS
) const = 0;
439 /// Shared pointer for SMTSolvers.
440 using SMTSolverRef
= std::shared_ptr
<SMTSolver
>;
442 /// Convenience method to create a Z3Solver object, returned as an
/// SMTSolverRef.
443 SMTSolverRef
CreateZ3Solver();