use insert function instead of for loop
[LibreOffice.git] / include / systools / win32 / comtools.hxx
blob9c2f49e8be0a16d5534c442d3672178f0e605c80
1 /* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
2 /*
3 * This file is part of the LibreOffice project.
5 * This Source Code Form is subject to the terms of the Mozilla Public
6 * License, v. 2.0. If a copy of the MPL was not distributed with this
7 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 * This file incorporates work covered by the following license notice:
11 * Licensed to the Apache Software Foundation (ASF) under one or more
12 * contributor license agreements. See the NOTICE file distributed
13 * with this work for additional information regarding copyright
14 * ownership. The ASF licenses this file to you under the Apache
15 * License, Version 2.0 (the "License"); you may not use this file
16 * except in compliance with the License. You may obtain a copy of
17 * the License at http://www.apache.org/licenses/LICENSE-2.0 .
20 #pragma once
22 #include <sal/config.h>
24 #include <source_location>
25 #include <string>
26 #include <string_view>
27 #include <stdexcept>
28 #include <type_traits>
29 #include <utility>
31 #include <prewin.h>
32 #include <objbase.h>
33 #include <postwin.h>
35 namespace sal::systools
37 /* Simple exception class for propagating COM errors */
38 class ComError : public std::runtime_error
40 public:
41 ComError(std::string_view message, HRESULT hr,
42 const std::source_location& loc = std::source_location::current())
43 : std::runtime_error(std::string(message))
44 , hr_(hr)
45 , loc_(loc)
48 HRESULT GetHresult() const { return hr_; }
49 const std::source_location& GetLocation() const { return loc_; }
51 private:
52 HRESULT hr_;
53 std::source_location loc_;
56 /* Convert failed HRESULT to thrown ComError */
57 inline void ThrowIfFailed(HRESULT hr, std::string_view msg,
58 std::source_location loc = std::source_location::current())
60 if (FAILED(hr))
61 throw ComError(msg, hr, loc);
64 /* A guard class to call CoInitializeEx/CoUninitialize in proper pairs
65 * See also: o3tl::safeCoInitializeEx doing dangerous re-initialization
67 class CoInitializeGuard
69 public:
70 enum class WhenFailed
72 NoThrow, // do not throw
73 Throw, // throw on failure
74 Abort, // std::abort on failure
76 explicit CoInitializeGuard(DWORD dwCoInit, bool failChangeMode = false,
77 WhenFailed whenFailed = WhenFailed::Throw)
79 HRESULT hr = ::CoInitializeEx(nullptr, dwCoInit);
80 if (whenFailed != WhenFailed::NoThrow && FAILED(hr)
81 && (failChangeMode || hr != RPC_E_CHANGED_MODE))
83 if (whenFailed == WhenFailed::Throw)
84 throw ComError("CoInitializeEx failed", hr);
85 else // if (whenFailed == Abort)
86 std::abort();
88 mbUninit = SUCCEEDED(hr);
90 CoInitializeGuard(const CoInitializeGuard&) = delete; // non-construction-copyable
91 void operator=(const CoInitializeGuard&) = delete; // non-copyable
92 ~CoInitializeGuard()
94 if (mbUninit)
95 ::CoUninitialize();
98 private:
99 bool mbUninit;
/* Tag types (and ready-made tag objects) selecting how an interface query
 * reports failure: COM_QUERY for the plain variant, COM_QUERY_THROW for the
 * throwing variant. COM_QUERY_THROW_TAG derives from COM_QUERY_TAG so that
 * both can be accepted by a single "is a query tag" constraint.
 *
 * The objects are 'inline constexpr': a plain namespace-scope constexpr
 * variable in a header has internal linkage, giving every translation unit
 * its own distinct object.
 */
struct COM_QUERY_TAG {};
inline constexpr COM_QUERY_TAG COM_QUERY{};
struct COM_QUERY_THROW_TAG : public COM_QUERY_TAG {};
inline constexpr COM_QUERY_THROW_TAG COM_QUERY_THROW{};
105 /* A simple COM smart pointer template */
106 template <typename T>
107 class COMReference
109 public:
110 /* Explicitly controllable whether AddRef will be called or not */
111 COMReference(T* comptr = nullptr, bool bAddRef = true) :
112 com_ptr_(comptr)
114 if (bAddRef)
115 addRef(com_ptr_);
118 COMReference(const COMReference<T>& other) :
119 COMReference(other.com_ptr_)
123 COMReference(COMReference<T>&& other) :
124 COMReference(std::exchange(other.com_ptr_, nullptr), false)
128 // Query from IUnknown*, using COM_QUERY or COM_QUERY_THROW tags
129 template <typename T2, typename TAG>
130 COMReference(const COMReference<T2>& p, TAG t)
131 : COMReference(p.template QueryInterface<T>(t))
135 // Using CoCreateInstance
136 COMReference(REFCLSID clsid, IUnknown* pOuter = nullptr, DWORD nCtx = CLSCTX_ALL)
137 : com_ptr_(nullptr)
139 ThrowIfFailed(CoCreateInstance(clsid, pOuter, nCtx), "CoCreateInstance failed");
142 COMReference<T>& operator=(const COMReference<T>& other)
144 return operator=(other.com_ptr_);
147 COMReference<T>& operator=(COMReference<T>&& other)
149 if (com_ptr_ != other.com_ptr_)
151 clear();
152 std::swap(com_ptr_, other.com_ptr_);
154 return *this;
157 COMReference<T>& operator=(T* comptr)
159 assign(comptr);
160 return *this;
163 ~COMReference() { release(com_ptr_); }
165 template <typename T2, typename TAG>
166 requires std::is_base_of_v<COM_QUERY_TAG, TAG>
167 COMReference<T2> QueryInterface(TAG) const
169 T2* ip = nullptr;
170 HRESULT hr = E_POINTER;
171 if (com_ptr_)
172 hr = com_ptr_->QueryInterface(&ip);
174 if constexpr (std::is_same_v<TAG, COM_QUERY_THROW_TAG>)
175 ThrowIfFailed(hr, "QueryInterface failed");
177 return { ip, false };
180 template <typename T2, typename TAG>
181 COMReference<T>& set(const COMReference<T2>& p, TAG t)
183 return operator=(p.template QueryInterface<T>(t));
186 HRESULT CoCreateInstance(REFCLSID clsid, IUnknown* pOuter = nullptr,
187 DWORD nCtx = CLSCTX_ALL)
189 T* ip;
190 HRESULT hr = ::CoCreateInstance(clsid, pOuter, nCtx, IID_PPV_ARGS(&ip));
191 if (SUCCEEDED(hr))
192 release(std::exchange(com_ptr_, ip));
193 return hr;
196 HRESULT CoGetClassObject(REFCLSID clsid, DWORD nCtx = CLSCTX_ALL)
198 T* ip;
199 HRESULT hr = ::CoGetClassObject(clsid, nCtx, nullptr, IID_PPV_ARGS(&ip));
200 if (SUCCEEDED(hr))
201 release(std::exchange(com_ptr_, ip));
202 return hr;
205 T* operator->() const { return com_ptr_; }
207 T& operator*() const { return *com_ptr_; }
209 /* Necessary for assigning com_ptr_ from functions like
210 CoCreateInstance which require a 'void**' */
211 T** operator&()
213 clear();
214 return &com_ptr_;
217 T* get() const { return com_ptr_; }
218 operator T* () const { return get(); }
220 void clear() { assign(nullptr); }
222 bool is() const { return (com_ptr_ != nullptr); }
223 operator bool() const { return is(); }
225 private:
226 static void addRef(T* ptr)
228 if (ptr)
229 ptr->AddRef();
232 static void release(T* ptr)
234 if (ptr)
235 ptr->Release();
238 void assign(T* ptr)
240 if (com_ptr_ == ptr)
241 return;
242 addRef(ptr);
243 release(std::exchange(com_ptr_, ptr));
246 T* com_ptr_;
// A class to use with functions taking an out pointer argument,
// that needs to be freed with CoTaskMemFree - like SHGetKnownFolderPath
//
// The object owns the pointer it holds: it is passed to CoTaskMemFree in the
// destructor. Because of that single ownership, copying is disabled (a copy
// would lead to the same pointer being freed twice); moving transfers the
// pointer and leaves the source empty.
template <typename T> class CoTaskMemAllocated
{
public:
    CoTaskMemAllocated() = default;

    CoTaskMemAllocated(const CoTaskMemAllocated&) = delete;
    CoTaskMemAllocated& operator=(const CoTaskMemAllocated&) = delete;

    CoTaskMemAllocated(CoTaskMemAllocated&& other) noexcept
        : m_pv(std::exchange(other.m_pv, nullptr))
    {
    }

    CoTaskMemAllocated& operator=(CoTaskMemAllocated&& other) noexcept
    {
        // Detach from other first; this ordering keeps self-move harmless
        // (the pointer is put back and only a null pointer gets "freed").
        // No address comparison is used, since operator& is overloaded.
        T* pTaken = std::exchange(other.m_pv, nullptr);
        CoTaskMemFree(std::exchange(m_pv, pTaken));
        return *this;
    }

    ~CoTaskMemAllocated() { CoTaskMemFree(m_pv); }

    // Address of the internal pointer, for use as an out argument.
    // Whatever is currently held is freed first, so the object can be reused.
    T** operator&()
    {
        CoTaskMemFree(std::exchange(m_pv, nullptr));
        return &m_pv;
    }

    // Access to the held pointer; ownership stays with this object
    operator T*() const { return m_pv; }

private:
    T* m_pv = nullptr;
};
268 } // sal::systools
270 /* Typedefs for some popular COM interfaces */
271 typedef sal::systools::COMReference<IDataObject> IDataObjectPtr;
273 /* vim:set shiftwidth=4 softtabstop=4 expandtab: */