Add PR check to suggest alternatives to using undef (#118506)
[llvm-project.git] / mlir / lib / ExecutionEngine / AsyncRuntime.cpp
blob9e6f8a7216995617c6ad6e91f809606ab0fffbc3
1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
16 #include <atomic>
17 #include <cassert>
18 #include <condition_variable>
19 #include <functional>
20 #include <iostream>
21 #include <mutex>
22 #include <thread>
23 #include <vector>
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/Support/ThreadPool.h"
28 using namespace mlir::runtime;
30 //===----------------------------------------------------------------------===//
31 // Async runtime API.
32 //===----------------------------------------------------------------------===//
34 namespace mlir {
35 namespace runtime {
36 namespace {
38 // Forward declare class defined below.
39 class RefCounted;
41 // -------------------------------------------------------------------------- //
42 // AsyncRuntime orchestrates all async operations and Async runtime API is built
43 // on top of the default runtime instance.
44 // -------------------------------------------------------------------------- //
46 class AsyncRuntime {
47 public:
48 AsyncRuntime() : numRefCountedObjects(0) {}
50 ~AsyncRuntime() {
51 threadPool.wait(); // wait for the completion of all async tasks
52 assert(getNumRefCountedObjects() == 0 &&
53 "all ref counted objects must be destroyed");
56 int64_t getNumRefCountedObjects() {
57 return numRefCountedObjects.load(std::memory_order_relaxed);
60 llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }
62 private:
63 friend class RefCounted;
65 // Count the total number of reference counted objects in this instance
66 // of an AsyncRuntime. For debugging purposes only.
67 void addNumRefCountedObjects() {
68 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
70 void dropNumRefCountedObjects() {
71 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
74 std::atomic<int64_t> numRefCountedObjects;
75 llvm::DefaultThreadPool threadPool;
// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This is a terminal
    // state: no transition to any other state is allowed.
    kAvailable = 1,
    // The underlying value is available and contains an error. This is a
    // terminal state: no transition to any other state is allowed.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Human-readable state name for debug output.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};
118 // -------------------------------------------------------------------------- //
119 // A base class for all reference counted objects created by the async runtime.
120 // -------------------------------------------------------------------------- //
122 class RefCounted {
123 public:
124 RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125 : runtime(runtime), refCount(refCount) {
126 runtime->addNumRefCountedObjects();
129 virtual ~RefCounted() {
130 assert(refCount.load() == 0 && "reference count must be zero");
131 runtime->dropNumRefCountedObjects();
134 RefCounted(const RefCounted &) = delete;
135 RefCounted &operator=(const RefCounted &) = delete;
137 void addRef(int64_t count = 1) { refCount.fetch_add(count); }
139 void dropRef(int64_t count = 1) {
140 int64_t previous = refCount.fetch_sub(count);
141 assert(previous >= count && "reference count should not go below zero");
142 if (previous == count)
143 destroy();
146 protected:
147 virtual void destroy() { delete this; }
149 private:
150 AsyncRuntime *runtime;
151 std::atomic<int64_t> refCount;
154 } // namespace
156 // Returns the default per-process instance of an async runtime.
157 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
158 static auto runtime = std::make_unique<AsyncRuntime>();
159 return runtime;
162 static void resetDefaultAsyncRuntime() {
163 return getDefaultAsyncRuntimeInstance().reset();
166 static AsyncRuntime *getDefaultAsyncRuntime() {
167 return getDefaultAsyncRuntimeInstance().get();
170 // Async token provides a mechanism to signal asynchronous operation completion.
171 struct AsyncToken : public RefCounted {
172 // AsyncToken created with a reference count of 2 because it will be returned
173 // to the `async.execute` caller and also will be later on emplaced by the
174 // asynchronously executed task. If the caller immediately will drop its
175 // reference we must ensure that the token will be alive until the
176 // asynchronous operation is completed.
177 AsyncToken(AsyncRuntime *runtime)
178 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
180 std::atomic<State::StateEnum> state;
182 // Pending awaiters are guarded by a mutex.
183 std::mutex mu;
184 std::condition_variable cv;
185 std::vector<std::function<void()>> awaiters;
188 // Async value provides a mechanism to access the result of asynchronous
189 // operations. It owns the storage that is used to store/load the value of the
190 // underlying type, and a flag to signal if the value is ready or not.
191 struct AsyncValue : public RefCounted {
192 // AsyncValue similar to an AsyncToken created with a reference count of 2.
193 AsyncValue(AsyncRuntime *runtime, int64_t size)
194 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
195 storage(size) {}
197 std::atomic<State::StateEnum> state;
199 // Use vector of bytes to store async value payload.
200 std::vector<std::byte> storage;
202 // Pending awaiters are guarded by a mutex.
203 std::mutex mu;
204 std::condition_variable cv;
205 std::vector<std::function<void()>> awaiters;
208 // Async group provides a mechanism to group together multiple async tokens or
209 // values to await on all of them together (wait for the completion of all
210 // tokens or values added to the group).
211 struct AsyncGroup : public RefCounted {
212 AsyncGroup(AsyncRuntime *runtime, int64_t size)
213 : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
215 std::atomic<int> pendingTokens;
216 std::atomic<int> numErrors;
217 std::atomic<int> rank;
219 // Pending awaiters are guarded by a mutex.
220 std::mutex mu;
221 std::condition_variable cv;
222 std::vector<std::function<void()>> awaiters;
225 // Adds references to reference counted runtime object.
226 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
227 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
228 refCounted->addRef(count);
231 // Drops references from reference counted runtime object.
232 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
233 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
234 refCounted->dropRef(count);
237 // Creates a new `async.token` in not-ready state.
238 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
239 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
240 return token;
243 // Creates a new `async.value` in not-ready state.
244 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
245 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
246 return value;
249 // Create a new `async.group` in empty state.
250 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
251 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
252 return group;
255 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
256 AsyncGroup *group) {
257 std::unique_lock<std::mutex> lockToken(token->mu);
258 std::unique_lock<std::mutex> lockGroup(group->mu);
260 // Get the rank of the token inside the group before we drop the reference.
261 int rank = group->rank.fetch_add(1);
263 auto onTokenReady = [group, token]() {
264 // Increment the number of errors in the group.
265 if (State(token->state).isError())
266 group->numErrors.fetch_add(1);
268 // If pending tokens go below zero it means that more tokens than the group
269 // size were added to this group.
270 assert(group->pendingTokens > 0 && "wrong group size");
272 // Run all group awaiters if it was the last token in the group.
273 if (group->pendingTokens.fetch_sub(1) == 1) {
274 group->cv.notify_all();
275 for (auto &awaiter : group->awaiters)
276 awaiter();
280 if (State(token->state).isAvailableOrError()) {
281 // Update group pending tokens immediately and maybe run awaiters.
282 onTokenReady();
284 } else {
285 // Update group pending tokens when token will become ready. Because this
286 // will happen asynchronously we must ensure that `group` is alive until
287 // then, and re-ackquire the lock.
288 group->addRef();
290 token->awaiters.emplace_back([group, onTokenReady]() {
291 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
293 std::unique_lock<std::mutex> lockGroup(group->mu);
294 onTokenReady();
296 group->dropRef();
300 return rank;
303 // Switches `async.token` to available or error state (terminatl state) and runs
304 // all awaiters.
305 static void setTokenState(AsyncToken *token, State state) {
306 assert(state.isAvailableOrError() && "must be terminal state");
307 assert(State(token->state).isUnavailable() && "token must be unavailable");
309 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
311 std::unique_lock<std::mutex> lock(token->mu);
312 token->state = state;
313 token->cv.notify_all();
314 for (auto &awaiter : token->awaiters)
315 awaiter();
318 // Async tokens created with a ref count `2` to keep token alive until the
319 // async task completes. Drop this reference explicitly when token emplaced.
320 token->dropRef();
323 static void setValueState(AsyncValue *value, State state) {
324 assert(state.isAvailableOrError() && "must be terminal state");
325 assert(State(value->state).isUnavailable() && "value must be unavailable");
327 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
329 std::unique_lock<std::mutex> lock(value->mu);
330 value->state = state;
331 value->cv.notify_all();
332 for (auto &awaiter : value->awaiters)
333 awaiter();
336 // Async values created with a ref count `2` to keep value alive until the
337 // async task completes. Drop this reference explicitly when value emplaced.
338 value->dropRef();
341 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
342 setTokenState(token, State::kAvailable);
345 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
346 setValueState(value, State::kAvailable);
349 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
350 setTokenState(token, State::kError);
353 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
354 setValueState(value, State::kError);
357 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
358 return State(token->state).isError();
361 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
362 return State(value->state).isError();
365 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
366 return group->numErrors.load() > 0;
369 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
370 std::unique_lock<std::mutex> lock(token->mu);
371 if (!State(token->state).isAvailableOrError())
372 token->cv.wait(
373 lock, [token] { return State(token->state).isAvailableOrError(); });
376 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
377 std::unique_lock<std::mutex> lock(value->mu);
378 if (!State(value->state).isAvailableOrError())
379 value->cv.wait(
380 lock, [value] { return State(value->state).isAvailableOrError(); });
383 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
384 std::unique_lock<std::mutex> lock(group->mu);
385 if (group->pendingTokens != 0)
386 group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
389 // Returns a pointer to the storage owned by the async value.
390 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
391 assert(!State(value->state).isError() && "unexpected error state");
392 return value->storage.data();
395 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
396 auto *runtime = getDefaultAsyncRuntime();
397 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
400 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
401 CoroHandle handle,
402 CoroResume resume) {
403 auto execute = [handle, resume]() { (*resume)(handle); };
404 std::unique_lock<std::mutex> lock(token->mu);
405 if (State(token->state).isAvailableOrError()) {
406 lock.unlock();
407 execute();
408 } else {
409 token->awaiters.emplace_back([execute]() { execute(); });
413 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
414 CoroHandle handle,
415 CoroResume resume) {
416 auto execute = [handle, resume]() { (*resume)(handle); };
417 std::unique_lock<std::mutex> lock(value->mu);
418 if (State(value->state).isAvailableOrError()) {
419 lock.unlock();
420 execute();
421 } else {
422 value->awaiters.emplace_back([execute]() { execute(); });
426 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
427 CoroHandle handle,
428 CoroResume resume) {
429 auto execute = [handle, resume]() { (*resume)(handle); };
430 std::unique_lock<std::mutex> lock(group->mu);
431 if (group->pendingTokens == 0) {
432 lock.unlock();
433 execute();
434 } else {
435 group->awaiters.emplace_back([execute]() { execute(); });
439 extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
440 return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

// Prints the calling thread's id to stdout (testing helper only).
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id tid = std::this_thread::get_id();
  std::cout << "Current thread id: " << tid << '\n';
}
//===----------------------------------------------------------------------===//
// MLIR ExecutionEngine dynamic library integration.
//===----------------------------------------------------------------------===//

// Visual Studio had a bug that fails to compile nested generic lambdas
// inside an `extern "C"` function.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a work around for older versions of Visual Studio.
// Called by the ExecutionEngine when the library is loaded; fills
// `exportSymbols` with the runtime API entry points defined above.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
__mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
465 // NOLINTNEXTLINE(*-identifier-naming): externally called.
466 void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
467 auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
468 assert(exportSymbols.count(name) == 0 && "symbol already exists");
469 exportSymbols[name] = reinterpret_cast<void *>(ptr);
472 exportSymbol("mlirAsyncRuntimeAddRef",
473 &mlir::runtime::mlirAsyncRuntimeAddRef);
474 exportSymbol("mlirAsyncRuntimeDropRef",
475 &mlir::runtime::mlirAsyncRuntimeDropRef);
476 exportSymbol("mlirAsyncRuntimeExecute",
477 &mlir::runtime::mlirAsyncRuntimeExecute);
478 exportSymbol("mlirAsyncRuntimeGetValueStorage",
479 &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
480 exportSymbol("mlirAsyncRuntimeCreateToken",
481 &mlir::runtime::mlirAsyncRuntimeCreateToken);
482 exportSymbol("mlirAsyncRuntimeCreateValue",
483 &mlir::runtime::mlirAsyncRuntimeCreateValue);
484 exportSymbol("mlirAsyncRuntimeEmplaceToken",
485 &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
486 exportSymbol("mlirAsyncRuntimeEmplaceValue",
487 &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
488 exportSymbol("mlirAsyncRuntimeSetTokenError",
489 &mlir::runtime::mlirAsyncRuntimeSetTokenError);
490 exportSymbol("mlirAsyncRuntimeSetValueError",
491 &mlir::runtime::mlirAsyncRuntimeSetValueError);
492 exportSymbol("mlirAsyncRuntimeIsTokenError",
493 &mlir::runtime::mlirAsyncRuntimeIsTokenError);
494 exportSymbol("mlirAsyncRuntimeIsValueError",
495 &mlir::runtime::mlirAsyncRuntimeIsValueError);
496 exportSymbol("mlirAsyncRuntimeIsGroupError",
497 &mlir::runtime::mlirAsyncRuntimeIsGroupError);
498 exportSymbol("mlirAsyncRuntimeAwaitToken",
499 &mlir::runtime::mlirAsyncRuntimeAwaitToken);
500 exportSymbol("mlirAsyncRuntimeAwaitValue",
501 &mlir::runtime::mlirAsyncRuntimeAwaitValue);
502 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
503 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
504 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
505 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
506 exportSymbol("mlirAsyncRuntimeCreateGroup",
507 &mlir::runtime::mlirAsyncRuntimeCreateGroup);
508 exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
509 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
510 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
511 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
512 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
513 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
514 exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
515 &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
516 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
517 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
520 // NOLINTNEXTLINE(*-identifier-naming): externally called.
521 extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
522 resetDefaultAsyncRuntime();
525 } // namespace runtime
526 } // namespace mlir