1 //===----------------------------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
12 #include <__algorithm/upper_bound.h>
14 #include <__random/is_valid.h>
15 #include <__random/uniform_real_distribution.h>
21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
22 # pragma GCC system_header
26 #include <__undef_macros>
28 _LIBCPP_BEGIN_NAMESPACE_STD
// discrete_distribution<_IntType>: libc++'s std::discrete_distribution.
// Results start at min() == 0; the selection table is param_type::__p_,
// which the out-of-line operator()(g, p) searches with upper_bound.
// NOTE(review): this copy of the class is incomplete. The opening brace,
// the access specifiers, the declarations of the `__p_` members used below,
// the initializers/bodies of several constructors and the closing `};` are
// not present. Restore the file from upstream libc++ before building.
30 template<class _IntType
= int>
31 class _LIBCPP_TEMPLATE_VIS discrete_distribution
// Reject result types that __libcpp_random_is_valid_inttype does not accept.
33 static_assert(__libcpp_random_is_valid_inttype
<_IntType
>::value
, "IntType must be a supported integer type");
36 typedef _IntType result_type
;
// param_type: the distribution's parameter set, stored in __p_. The
// iterator-range and initializer_list constructors copy the given doubles
// into __p_ and then call __init() to post-process them.
38 class _LIBCPP_TEMPLATE_VIS param_type
42 typedef discrete_distribution distribution_type
;
// NOTE(review): no declaration follows the next visibility macro in this copy.
44 _LIBCPP_INLINE_VISIBILITY
46 template<class _InputIterator
>
47 _LIBCPP_INLINE_VISIBILITY
48 param_type(_InputIterator __f
, _InputIterator __l
)
49 : __p_(__f
, __l
) {__init();}
50 #ifndef _LIBCPP_CXX03_LANG
51 _LIBCPP_INLINE_VISIBILITY
52 param_type(initializer_list
<double> __wl
)
53 : __p_(__wl
.begin(), __wl
.end()) {__init();}
54 #endif // _LIBCPP_CXX03_LANG
// Builds the parameter set from values of fw sampled over [xmin, xmax]
// (defined out of line).
55 template<class _UnaryOperation
>
56 _LIBCPP_HIDE_FROM_ABI
param_type(size_t __nw
, double __xmin
, double __xmax
,
57 _UnaryOperation __fw
);
// One probability per outcome, reconstructed from __p_ (defined out of line).
59 _LIBCPP_HIDE_FROM_ABI vector
<double> probabilities() const;
// Parameter sets compare equal exactly when their __p_ tables are equal.
61 friend _LIBCPP_INLINE_VISIBILITY
62 bool operator==(const param_type
& __x
, const param_type
& __y
)
63 {return __x
.__p_
== __y
.__p_
;}
64 friend _LIBCPP_INLINE_VISIBILITY
65 bool operator!=(const param_type
& __x
, const param_type
& __y
)
66 {return !(__x
== __y
);}
// Post-processes the weights stored in __p_ (defined out of line).
69 _LIBCPP_HIDE_FROM_ABI
void __init();
71 friend class discrete_distribution
;
// Stream operators for the distribution; their out-of-line definitions read
// and write __p_ directly. (Any `friend` specifier on these two declarations
// is not present in this copy.)
73 template <class _CharT
, class _Traits
, class _IT
>
75 basic_ostream
<_CharT
, _Traits
>&
76 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
77 const discrete_distribution
<_IT
>& __x
);
79 template <class _CharT
, class _Traits
, class _IT
>
81 basic_istream
<_CharT
, _Traits
>&
82 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
83 discrete_distribution
<_IT
>& __x
);
90 // constructor and reset functions
91 _LIBCPP_INLINE_VISIBILITY
92 discrete_distribution() {}
// NOTE(review): the iterator-range constructor below has no member-initializer
// or body in this copy.
93 template<class _InputIterator
>
94 _LIBCPP_INLINE_VISIBILITY
95 discrete_distribution(_InputIterator __f
, _InputIterator __l
)
97 #ifndef _LIBCPP_CXX03_LANG
// NOTE(review): the initializer_list constructor below has no member-initializer
// or body in this copy.
98 _LIBCPP_INLINE_VISIBILITY
99 discrete_distribution(initializer_list
<double> __wl
)
101 #endif // _LIBCPP_CXX03_LANG
// Forwards nw, xmin, xmax and fw to the matching param_type constructor.
102 template<class _UnaryOperation
>
103 _LIBCPP_INLINE_VISIBILITY
104 discrete_distribution(size_t __nw
, double __xmin
, double __xmax
,
105 _UnaryOperation __fw
)
106 : __p_(__nw
, __xmin
, __xmax
, __fw
) {}
// NOTE(review): the param_type constructor below has no initializer or body,
// and the _LIBCPP_INLINE_VISIBILITY after it is followed by no declaration,
// in this copy.
107 _LIBCPP_INLINE_VISIBILITY
108 explicit discrete_distribution(const param_type
& __p
)
110 _LIBCPP_INLINE_VISIBILITY
113 // generating functions
// Draws one value using the stored parameters __p_.
114 template<class _URNG
>
115 _LIBCPP_INLINE_VISIBILITY
116 result_type
operator()(_URNG
& __g
)
117 {return (*this)(__g
, __p_
);}
// Draws one value using the supplied parameters __p (defined out of line).
118 template<class _URNG
>
119 _LIBCPP_HIDE_FROM_ABI result_type
operator()(_URNG
& __g
, const param_type
& __p
);
121 // property functions
122 _LIBCPP_INLINE_VISIBILITY
123 vector
<double> probabilities() const {return __p_
.probabilities();}
125 _LIBCPP_INLINE_VISIBILITY
126 param_type
param() const {return __p_
;}
127 _LIBCPP_INLINE_VISIBILITY
128 void param(const param_type
& __p
) {__p_
= __p
;}
// Outcomes lie in [min(), max()]: operator() returns an upper_bound position
// within the __p_.__p_ table, so the largest possible value is its size.
130 _LIBCPP_INLINE_VISIBILITY
131 result_type
min() const {return 0;}
132 _LIBCPP_INLINE_VISIBILITY
133 result_type
max() const {return __p_
.__p_
.size();}
// Distributions compare equal exactly when their parameter sets do.
135 friend _LIBCPP_INLINE_VISIBILITY
136 bool operator==(const discrete_distribution
& __x
,
137 const discrete_distribution
& __y
)
138 {return __x
.__p_
== __y
.__p_
;}
139 friend _LIBCPP_INLINE_VISIBILITY
140 bool operator!=(const discrete_distribution
& __x
,
141 const discrete_distribution
& __y
)
142 {return !(__x
== __y
);}
// Stream insertion/extraction of the distribution (defined out of line).
144 template <class _CharT
, class _Traits
, class _IT
>
146 basic_ostream
<_CharT
, _Traits
>&
147 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
148 const discrete_distribution
<_IT
>& __x
);
150 template <class _CharT
, class _Traits
, class _IT
>
152 basic_istream
<_CharT
, _Traits
>&
153 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
154 discrete_distribution
<_IT
>& __x
);
// param_type(nw, xmin, xmax, fw): fills __p_ by evaluating fw at the
// midpoints xmin + k*d + d/2 (k = 0 .. nw-1) of nw equal-width sub-intervals
// of [xmin, xmax], where d = (xmax - xmin) / nw.
// NOTE(review): in this copy the xmin/xmax parameter lines, the opening
// brace, any guard on __nw and the function's closing lines are not present.
// Restore from upstream libc++.
157 template<class _IntType
>
158 template<class _UnaryOperation
>
159 discrete_distribution
<_IntType
>::param_type::param_type(size_t __nw
,
162 _UnaryOperation __fw
)
// NOTE(review): this reserves __nw - 1 slots, but the loop below push_backs
// __nw values, so the final push_back can reallocate; reserve(__nw) would
// avoid that. Also, if __nw can be 0 here, __nw - 1 wraps to SIZE_MAX and
// reserve throws length_error. No guard is visible in this copy.
166 __p_
.reserve(__nw
- 1);
// Width of each sub-interval, and half of it (offset to the midpoint).
167 double __d
= (__xmax
- __xmin
) / __nw
;
168 double __d2
= __d
/ 2;
169 for (size_t __k
= 0; __k
< __nw
; ++__k
)
170 __p_
.push_back(__fw(__xmin
+ __k
* __d
+ __d2
));
// __init(): post-processes the raw weights held in __p_.
// The visible steps do the following:
//   - sum the weights into __s;
//   - walk every element;
//   - write the running (partial) sums of all but the last element into __t.
// Presumably __p_ ends up holding such a cumulative table:
// probabilities() applies adjacent_difference to it, and operator()
// binary-searches it with upper_bound. Confirm against upstream.
// NOTE(review): in this copy the return type, the opening brace, any size
// checks, the for-loop body, what is done with __t, and the closing lines
// are not present. Restore from upstream libc++.
175 template<class _IntType
>
177 discrete_distribution
<_IntType
>::param_type::__init()
183 double __s
= _VSTD::accumulate(__p_
.begin(), __p_
.end(), 0.0);
184 for (vector
<double>::iterator __i
= __p_
.begin(), __e
= __p_
.end(); __i
< __e
; ++__i
)
// NOTE(review): __p_.size() - 1 wraps to SIZE_MAX if __p_ is empty; an
// emptiness check is not visible in this copy.
186 vector
<double> __t(__p_
.size() - 1);
187 _VSTD::partial_sum(__p_
.begin(), __p_
.end() - 1, __t
.begin());
193 __p_
.shrink_to_fit();
// probabilities(): rebuilds n + 1 per-outcome probabilities from the n-entry
// table __p_. adjacent_difference yields the first n values (the first entry
// as-is, then successive differences). The last value is 1 minus the final
// table entry, which assumes __p_ holds cumulative probabilities that stop
// short of the total of 1.
// NOTE(review): when __p_ is empty, __p_[__n-1] reads out of bounds; no
// guard is visible in this copy. The return type, the braces and the return
// statement are also not present here.
198 template<class _IntType
>
200 discrete_distribution
<_IntType
>::param_type::probabilities() const
202 size_t __n
= __p_
.size();
203 vector
<double> __p(__n
+1);
204 _VSTD::adjacent_difference(__p_
.begin(), __p_
.end(), __p
.begin());
206 __p
[__n
] = 1 - __p_
[__n
-1];
// operator()(g, p): draws u from a default-constructed
// uniform_real_distribution<double> (range [0, 1)). It then uses
// upper_bound to find the first entry of p.__p_ greater than u, and
// returns that position as _IntType, i.e. a value in [0, p.__p_.size()].
// NOTE(review): in this copy the return-type line, the braces and the
// right-hand operand of the subtraction (presumably p.__p_.begin()) are not
// present. Restore from upstream libc++.
212 template<class _IntType
>
213 template<class _URNG
>
215 discrete_distribution
<_IntType
>::operator()(_URNG
& __g
, const param_type
& __p
)
// Reject generator types that __libcpp_random_is_valid_urng does not accept.
217 static_assert(__libcpp_random_is_valid_urng
<_URNG
>::value
, "");
218 uniform_real_distribution
<double> __gen
;
219 return static_cast<_IntType
>(
220 _VSTD::upper_bound(__p
.__p_
.begin(), __p
.__p_
.end(), __gen(__g
)) -
// operator<<: writes the distribution's stored table __x.__p_.__p_ to __os,
// with each entry preceded by a space.
// __lx presumably restores the stream's original flags when it goes out of
// scope (assumed from the __save_flags name; confirm).
// NOTE(review): in this copy the opening brace, any lines between widen()
// and the loop, and the function's closing lines (including its return) are
// not present. Restore from upstream libc++.
224 template <class _CharT
, class _Traits
, class _IT
>
225 _LIBCPP_HIDE_FROM_ABI basic_ostream
<_CharT
, _Traits
>&
226 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
227 const discrete_distribution
<_IT
>& __x
)
229 __save_flags
<_CharT
, _Traits
> __lx(__os
);
230 typedef basic_ostream
<_CharT
, _Traits
> _OStream
;
// Setting fixed and scientific together selects hexadecimal floating-point
// output (%a, as std::hexfloat does, since C++11), so doubles are written
// exactly.
231 __os
.flags(_OStream::dec
| _OStream::left
| _OStream::fixed
|
232 _OStream::scientific
);
233 _CharT __sp
= __os
.widen(' ');
235 size_t __n
= __x
.__p_
.__p_
.size();
237 for (size_t __i
= 0; __i
< __n
; ++__i
)
238 __os
<< __sp
<< __x
.__p_
.__p_
[__i
];
// operator>>: reads __n doubles into the temporary vector __p (in decimal,
// skipping whitespace), then swaps it into __x.__p_.__p_.
// NOTE(review): in this copy the opening brace, the declaration and read of
// __n, the loop body, any stream-state check before the swap, and the
// closing lines (including the return) are not present. Restore from
// upstream libc++.
242 template <class _CharT
, class _Traits
, class _IT
>
243 _LIBCPP_HIDE_FROM_ABI basic_istream
<_CharT
, _Traits
>&
244 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
245 discrete_distribution
<_IT
>& __x
)
247 __save_flags
<_CharT
, _Traits
> __lx(__is
);
248 typedef basic_istream
<_CharT
, _Traits
> _Istream
;
249 __is
.flags(_Istream::dec
| _Istream::skipws
);
// NOTE(review): __n sizes this allocation. If __n comes straight from the
// stream, a corrupt input can request a very large vector.
252 vector
<double> __p(__n
);
253 for (size_t __i
= 0; __i
< __n
; ++__i
)
256 swap(__x
.__p_
.__p_
, __p
);
260 _LIBCPP_END_NAMESPACE_STD
264 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H