//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/AsyncRuntime.h"

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;
30 //===----------------------------------------------------------------------===//
32 //===----------------------------------------------------------------------===//
38 // Forward declare class defined below.
41 // -------------------------------------------------------------------------- //
42 // AsyncRuntime orchestrates all async operations and Async runtime API is built
43 // on top of the default runtime instance.
44 // -------------------------------------------------------------------------- //
48 AsyncRuntime() : numRefCountedObjects(0) {}
51 threadPool
.wait(); // wait for the completion of all async tasks
52 assert(getNumRefCountedObjects() == 0 &&
53 "all ref counted objects must be destroyed");
56 int64_t getNumRefCountedObjects() {
57 return numRefCountedObjects
.load(std::memory_order_relaxed
);
60 llvm::ThreadPoolInterface
&getThreadPool() { return threadPool
; }
63 friend class RefCounted
;
65 // Count the total number of reference counted objects in this instance
66 // of an AsyncRuntime. For debugging purposes only.
67 void addNumRefCountedObjects() {
68 numRefCountedObjects
.fetch_add(1, std::memory_order_relaxed
);
70 void dropNumRefCountedObjects() {
71 numRefCountedObjects
.fetch_sub(1, std::memory_order_relaxed
);
74 std::atomic
<int64_t> numRefCountedObjects
;
75 llvm::DefaultThreadPool threadPool
;
// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name of the state, for debugging.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // Not reachable for valid enumerators; guarantees the function never falls
    // off the end (undefined behavior) if `state` holds an out-of-range value.
    return "unknown";
  }

private:
  StateEnum state;
};
118 // -------------------------------------------------------------------------- //
119 // A base class for all reference counted objects created by the async runtime.
120 // -------------------------------------------------------------------------- //
124 RefCounted(AsyncRuntime
*runtime
, int64_t refCount
= 1)
125 : runtime(runtime
), refCount(refCount
) {
126 runtime
->addNumRefCountedObjects();
129 virtual ~RefCounted() {
130 assert(refCount
.load() == 0 && "reference count must be zero");
131 runtime
->dropNumRefCountedObjects();
134 RefCounted(const RefCounted
&) = delete;
135 RefCounted
&operator=(const RefCounted
&) = delete;
137 void addRef(int64_t count
= 1) { refCount
.fetch_add(count
); }
139 void dropRef(int64_t count
= 1) {
140 int64_t previous
= refCount
.fetch_sub(count
);
141 assert(previous
>= count
&& "reference count should not go below zero");
142 if (previous
== count
)
147 virtual void destroy() { delete this; }
150 AsyncRuntime
*runtime
;
151 std::atomic
<int64_t> refCount
;
156 // Returns the default per-process instance of an async runtime.
157 static std::unique_ptr
<AsyncRuntime
> &getDefaultAsyncRuntimeInstance() {
158 static auto runtime
= std::make_unique
<AsyncRuntime
>();
162 static void resetDefaultAsyncRuntime() {
163 return getDefaultAsyncRuntimeInstance().reset();
166 static AsyncRuntime
*getDefaultAsyncRuntime() {
167 return getDefaultAsyncRuntimeInstance().get();
170 // Async token provides a mechanism to signal asynchronous operation completion.
171 struct AsyncToken
: public RefCounted
{
172 // AsyncToken created with a reference count of 2 because it will be returned
173 // to the `async.execute` caller and also will be later on emplaced by the
174 // asynchronously executed task. If the caller immediately will drop its
175 // reference we must ensure that the token will be alive until the
176 // asynchronous operation is completed.
177 AsyncToken(AsyncRuntime
*runtime
)
178 : RefCounted(runtime
, /*refCount=*/2), state(State::kUnavailable
) {}
180 std::atomic
<State::StateEnum
> state
;
182 // Pending awaiters are guarded by a mutex.
184 std::condition_variable cv
;
185 std::vector
<std::function
<void()>> awaiters
;
188 // Async value provides a mechanism to access the result of asynchronous
189 // operations. It owns the storage that is used to store/load the value of the
190 // underlying type, and a flag to signal if the value is ready or not.
191 struct AsyncValue
: public RefCounted
{
192 // AsyncValue similar to an AsyncToken created with a reference count of 2.
193 AsyncValue(AsyncRuntime
*runtime
, int64_t size
)
194 : RefCounted(runtime
, /*refCount=*/2), state(State::kUnavailable
),
197 std::atomic
<State::StateEnum
> state
;
199 // Use vector of bytes to store async value payload.
200 std::vector
<std::byte
> storage
;
202 // Pending awaiters are guarded by a mutex.
204 std::condition_variable cv
;
205 std::vector
<std::function
<void()>> awaiters
;
208 // Async group provides a mechanism to group together multiple async tokens or
209 // values to await on all of them together (wait for the completion of all
210 // tokens or values added to the group).
211 struct AsyncGroup
: public RefCounted
{
212 AsyncGroup(AsyncRuntime
*runtime
, int64_t size
)
213 : RefCounted(runtime
), pendingTokens(size
), numErrors(0), rank(0) {}
215 std::atomic
<int> pendingTokens
;
216 std::atomic
<int> numErrors
;
217 std::atomic
<int> rank
;
219 // Pending awaiters are guarded by a mutex.
221 std::condition_variable cv
;
222 std::vector
<std::function
<void()>> awaiters
;
225 // Adds references to reference counted runtime object.
226 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr
, int64_t count
) {
227 RefCounted
*refCounted
= static_cast<RefCounted
*>(ptr
);
228 refCounted
->addRef(count
);
231 // Drops references from reference counted runtime object.
232 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr
, int64_t count
) {
233 RefCounted
*refCounted
= static_cast<RefCounted
*>(ptr
);
234 refCounted
->dropRef(count
);
237 // Creates a new `async.token` in not-ready state.
238 extern "C" AsyncToken
*mlirAsyncRuntimeCreateToken() {
239 AsyncToken
*token
= new AsyncToken(getDefaultAsyncRuntime());
243 // Creates a new `async.value` in not-ready state.
244 extern "C" AsyncValue
*mlirAsyncRuntimeCreateValue(int64_t size
) {
245 AsyncValue
*value
= new AsyncValue(getDefaultAsyncRuntime(), size
);
249 // Create a new `async.group` in empty state.
250 extern "C" AsyncGroup
*mlirAsyncRuntimeCreateGroup(int64_t size
) {
251 AsyncGroup
*group
= new AsyncGroup(getDefaultAsyncRuntime(), size
);
255 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken
*token
,
257 std::unique_lock
<std::mutex
> lockToken(token
->mu
);
258 std::unique_lock
<std::mutex
> lockGroup(group
->mu
);
260 // Get the rank of the token inside the group before we drop the reference.
261 int rank
= group
->rank
.fetch_add(1);
263 auto onTokenReady
= [group
, token
]() {
264 // Increment the number of errors in the group.
265 if (State(token
->state
).isError())
266 group
->numErrors
.fetch_add(1);
268 // If pending tokens go below zero it means that more tokens than the group
269 // size were added to this group.
270 assert(group
->pendingTokens
> 0 && "wrong group size");
272 // Run all group awaiters if it was the last token in the group.
273 if (group
->pendingTokens
.fetch_sub(1) == 1) {
274 group
->cv
.notify_all();
275 for (auto &awaiter
: group
->awaiters
)
280 if (State(token
->state
).isAvailableOrError()) {
281 // Update group pending tokens immediately and maybe run awaiters.
285 // Update group pending tokens when token will become ready. Because this
286 // will happen asynchronously we must ensure that `group` is alive until
287 // then, and re-ackquire the lock.
290 token
->awaiters
.emplace_back([group
, onTokenReady
]() {
291 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
293 std::unique_lock
<std::mutex
> lockGroup(group
->mu
);
303 // Switches `async.token` to available or error state (terminatl state) and runs
305 static void setTokenState(AsyncToken
*token
, State state
) {
306 assert(state
.isAvailableOrError() && "must be terminal state");
307 assert(State(token
->state
).isUnavailable() && "token must be unavailable");
309 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
311 std::unique_lock
<std::mutex
> lock(token
->mu
);
312 token
->state
= state
;
313 token
->cv
.notify_all();
314 for (auto &awaiter
: token
->awaiters
)
318 // Async tokens created with a ref count `2` to keep token alive until the
319 // async task completes. Drop this reference explicitly when token emplaced.
323 static void setValueState(AsyncValue
*value
, State state
) {
324 assert(state
.isAvailableOrError() && "must be terminal state");
325 assert(State(value
->state
).isUnavailable() && "value must be unavailable");
327 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
329 std::unique_lock
<std::mutex
> lock(value
->mu
);
330 value
->state
= state
;
331 value
->cv
.notify_all();
332 for (auto &awaiter
: value
->awaiters
)
336 // Async values created with a ref count `2` to keep value alive until the
337 // async task completes. Drop this reference explicitly when value emplaced.
341 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken
*token
) {
342 setTokenState(token
, State::kAvailable
);
345 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue
*value
) {
346 setValueState(value
, State::kAvailable
);
349 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken
*token
) {
350 setTokenState(token
, State::kError
);
353 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue
*value
) {
354 setValueState(value
, State::kError
);
357 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken
*token
) {
358 return State(token
->state
).isError();
361 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue
*value
) {
362 return State(value
->state
).isError();
365 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup
*group
) {
366 return group
->numErrors
.load() > 0;
369 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken
*token
) {
370 std::unique_lock
<std::mutex
> lock(token
->mu
);
371 if (!State(token
->state
).isAvailableOrError())
373 lock
, [token
] { return State(token
->state
).isAvailableOrError(); });
376 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue
*value
) {
377 std::unique_lock
<std::mutex
> lock(value
->mu
);
378 if (!State(value
->state
).isAvailableOrError())
380 lock
, [value
] { return State(value
->state
).isAvailableOrError(); });
383 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup
*group
) {
384 std::unique_lock
<std::mutex
> lock(group
->mu
);
385 if (group
->pendingTokens
!= 0)
386 group
->cv
.wait(lock
, [group
] { return group
->pendingTokens
== 0; });
389 // Returns a pointer to the storage owned by the async value.
390 extern "C" ValueStorage
mlirAsyncRuntimeGetValueStorage(AsyncValue
*value
) {
391 assert(!State(value
->state
).isError() && "unexpected error state");
392 return value
->storage
.data();
395 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle
, CoroResume resume
) {
396 auto *runtime
= getDefaultAsyncRuntime();
397 runtime
->getThreadPool().async([handle
, resume
]() { (*resume
)(handle
); });
400 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken
*token
,
403 auto execute
= [handle
, resume
]() { (*resume
)(handle
); };
404 std::unique_lock
<std::mutex
> lock(token
->mu
);
405 if (State(token
->state
).isAvailableOrError()) {
409 token
->awaiters
.emplace_back([execute
]() { execute(); });
413 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue
*value
,
416 auto execute
= [handle
, resume
]() { (*resume
)(handle
); };
417 std::unique_lock
<std::mutex
> lock(value
->mu
);
418 if (State(value
->state
).isAvailableOrError()) {
422 value
->awaiters
.emplace_back([execute
]() { execute(); });
426 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup
*group
,
429 auto execute
= [handle
, resume
]() { (*resume
)(handle
); };
430 std::unique_lock
<std::mutex
> lock(group
->mu
);
431 if (group
->pendingTokens
== 0) {
435 group
->awaiters
.emplace_back([execute
]() { execute(); });
439 extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
440 return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

// Prints the id of the calling thread to stdout.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << '\n';
}
452 //===----------------------------------------------------------------------===//
453 // MLIR ExecutionEngine dynamic library integration.
454 //===----------------------------------------------------------------------===//
456 // Visual Studio had a bug that fails to compile nested generic lambdas
457 // inside an `extern "C"` function.
458 // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
459 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
460 // a work around for older versions of Visual Studio.
461 // NOLINTNEXTLINE(*-identifier-naming): externally called.
462 extern "C" MLIR_ASYNC_RUNTIME_EXPORT
void
463 __mlir_execution_engine_init(llvm::StringMap
<void *> &exportSymbols
);
465 // NOLINTNEXTLINE(*-identifier-naming): externally called.
466 void __mlir_execution_engine_init(llvm::StringMap
<void *> &exportSymbols
) {
467 auto exportSymbol
= [&](llvm::StringRef name
, auto ptr
) {
468 assert(exportSymbols
.count(name
) == 0 && "symbol already exists");
469 exportSymbols
[name
] = reinterpret_cast<void *>(ptr
);
472 exportSymbol("mlirAsyncRuntimeAddRef",
473 &mlir::runtime::mlirAsyncRuntimeAddRef
);
474 exportSymbol("mlirAsyncRuntimeDropRef",
475 &mlir::runtime::mlirAsyncRuntimeDropRef
);
476 exportSymbol("mlirAsyncRuntimeExecute",
477 &mlir::runtime::mlirAsyncRuntimeExecute
);
478 exportSymbol("mlirAsyncRuntimeGetValueStorage",
479 &mlir::runtime::mlirAsyncRuntimeGetValueStorage
);
480 exportSymbol("mlirAsyncRuntimeCreateToken",
481 &mlir::runtime::mlirAsyncRuntimeCreateToken
);
482 exportSymbol("mlirAsyncRuntimeCreateValue",
483 &mlir::runtime::mlirAsyncRuntimeCreateValue
);
484 exportSymbol("mlirAsyncRuntimeEmplaceToken",
485 &mlir::runtime::mlirAsyncRuntimeEmplaceToken
);
486 exportSymbol("mlirAsyncRuntimeEmplaceValue",
487 &mlir::runtime::mlirAsyncRuntimeEmplaceValue
);
488 exportSymbol("mlirAsyncRuntimeSetTokenError",
489 &mlir::runtime::mlirAsyncRuntimeSetTokenError
);
490 exportSymbol("mlirAsyncRuntimeSetValueError",
491 &mlir::runtime::mlirAsyncRuntimeSetValueError
);
492 exportSymbol("mlirAsyncRuntimeIsTokenError",
493 &mlir::runtime::mlirAsyncRuntimeIsTokenError
);
494 exportSymbol("mlirAsyncRuntimeIsValueError",
495 &mlir::runtime::mlirAsyncRuntimeIsValueError
);
496 exportSymbol("mlirAsyncRuntimeIsGroupError",
497 &mlir::runtime::mlirAsyncRuntimeIsGroupError
);
498 exportSymbol("mlirAsyncRuntimeAwaitToken",
499 &mlir::runtime::mlirAsyncRuntimeAwaitToken
);
500 exportSymbol("mlirAsyncRuntimeAwaitValue",
501 &mlir::runtime::mlirAsyncRuntimeAwaitValue
);
502 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
503 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute
);
504 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
505 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute
);
506 exportSymbol("mlirAsyncRuntimeCreateGroup",
507 &mlir::runtime::mlirAsyncRuntimeCreateGroup
);
508 exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
509 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup
);
510 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
511 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup
);
512 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
513 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute
);
514 exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
515 &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads
);
516 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
517 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId
);
520 // NOLINTNEXTLINE(*-identifier-naming): externally called.
521 extern "C" MLIR_ASYNC_RUNTIME_EXPORT
void __mlir_execution_engine_destroy() {
522 resetDefaultAsyncRuntime();
525 } // namespace runtime