1 //===----------------------------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
12 #include <__algorithm/upper_bound.h>
14 #include <__random/uniform_real_distribution.h>
20 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
21 #pragma GCC system_header
25 #include <__undef_macros>
27 _LIBCPP_BEGIN_NAMESPACE_STD
// discrete_distribution<_IntType>: a distribution whose results are integer
// indices. operator() draws a uniform double and maps it to an index with
// upper_bound over the table held by param_type (see the out-of-line
// operator() definition further down), so results lie in [min(), max()].
// NOTE(review): this excerpt appears to have lost lines during extraction (the
// embedded numbers jump, e.g. 30 -> 34 and 36 -> 40), so the class braces,
// access specifiers and the `__p_` data-member declarations are not visible
// here. Comments below describe only what the visible lines show.
29 template<class _IntType
= int>
30 class _LIBCPP_TEMPLATE_VIS discrete_distribution
// Type of the values returned by operator(), min() and max().
34 typedef _IntType result_type
;
// Parameter set: owns the table (`param_type::__p_`) that sampling searches.
36 class _LIBCPP_TEMPLATE_VIS param_type
40 typedef discrete_distribution distribution_type
;
// NOTE(review): the declaration this attribute applies to (original line 43)
// is missing from this excerpt -- presumably the default constructor; confirm.
42 _LIBCPP_INLINE_VISIBILITY
// Build from a range of weights [__f, __l); the copied values are then
// post-processed by __init().
44 template<class _InputIterator
>
45 _LIBCPP_INLINE_VISIBILITY
46 param_type(_InputIterator __f
, _InputIterator __l
)
47 : __p_(__f
, __l
) {__init();}
48 #ifndef _LIBCPP_CXX03_LANG
// Build from a braced list of weights (not available in C++03 mode); the
// copied values are then post-processed by __init().
49 _LIBCPP_INLINE_VISIBILITY
50 param_type(initializer_list
<double> __wl
)
51 : __p_(__wl
.begin(), __wl
.end()) {__init();}
52 #endif // _LIBCPP_CXX03_LANG
// Build __nw weights by evaluating __fw at points spread over
// [__xmin, __xmax]; defined out of line below.
53 template<class _UnaryOperation
>
54 param_type(size_t __nw
, double __xmin
, double __xmax
,
55 _UnaryOperation __fw
);
// Per-index probabilities reconstructed from the stored table; defined out of
// line below.
57 vector
<double> probabilities() const;
// Parameter sets compare equal exactly when their stored tables are equal.
59 friend _LIBCPP_INLINE_VISIBILITY
60 bool operator==(const param_type
& __x
, const param_type
& __y
)
61 {return __x
.__p_
== __y
.__p_
;}
// Inequality is defined as the negation of operator==.
62 friend _LIBCPP_INLINE_VISIBILITY
63 bool operator!=(const param_type
& __x
, const param_type
& __y
)
64 {return !(__x
== __y
);}
// The enclosing distribution reads param_type internals directly (e.g.
// `__p_.__p_` in max()).
// NOTE(review): original lines 65-68 are missing here (presumably the private
// declaration of __init()); confirm.
69 friend class discrete_distribution
;
// Stream-operator declarations. NOTE(review): the line between each template
// header and its return type (original 72 and 78) is missing here --
// presumably `friend`; confirm.
71 template <class _CharT
, class _Traits
, class _IT
>
73 basic_ostream
<_CharT
, _Traits
>&
74 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
75 const discrete_distribution
<_IT
>& __x
);
77 template <class _CharT
, class _Traits
, class _IT
>
79 basic_istream
<_CharT
, _Traits
>&
80 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
81 discrete_distribution
<_IT
>& __x
);
// NOTE(review): original lines 82-87 are missing here; the declaration of the
// distribution's own `__p_` member (a param_type, judging by param() and
// param(const param_type&) below) is not visible.
88 // constructor and reset functions
89 _LIBCPP_INLINE_VISIBILITY
90 discrete_distribution() {}
// Build from a range of weights [__f, __l).
// NOTE(review): the mem-initializer/body (original line 94) is missing here.
91 template<class _InputIterator
>
92 _LIBCPP_INLINE_VISIBILITY
93 discrete_distribution(_InputIterator __f
, _InputIterator __l
)
95 #ifndef _LIBCPP_CXX03_LANG
// Build from a braced list of weights (not available in C++03 mode).
// NOTE(review): the mem-initializer/body (original line 98) is missing here.
96 _LIBCPP_INLINE_VISIBILITY
97 discrete_distribution(initializer_list
<double> __wl
)
99 #endif // _LIBCPP_CXX03_LANG
// Build from a sampled weight function; forwards all four arguments to the
// matching param_type constructor.
100 template<class _UnaryOperation
>
101 _LIBCPP_INLINE_VISIBILITY
102 discrete_distribution(size_t __nw
, double __xmin
, double __xmax
,
103 _UnaryOperation __fw
)
104 : __p_(__nw
, __xmin
, __xmax
, __fw
) {}
// Build from an existing parameter set.
// NOTE(review): the mem-initializer/body (original line 107) is missing here.
105 _LIBCPP_INLINE_VISIBILITY
106 explicit discrete_distribution(const param_type
& __p
)
// NOTE(review): the member this attribute applies to (original lines 109-110)
// is missing from this excerpt.
108 _LIBCPP_INLINE_VISIBILITY
111 // generating functions
// Draw one value using the distribution's own parameter set.
112 template<class _URNG
>
113 _LIBCPP_INLINE_VISIBILITY
114 result_type
operator()(_URNG
& __g
)
115 {return (*this)(__g
, __p_
);}
// Draw one value using the supplied parameter set __p instead of the stored
// one; defined out of line below.
116 template<class _URNG
> result_type
operator()(_URNG
& __g
, const param_type
& __p
);
118 // property functions
// Per-index probabilities of the current parameter set.
119 _LIBCPP_INLINE_VISIBILITY
120 vector
<double> probabilities() const {return __p_
.probabilities();}
// Return a copy of / replace the current parameter set.
122 _LIBCPP_INLINE_VISIBILITY
123 param_type
param() const {return __p_
;}
124 _LIBCPP_INLINE_VISIBILITY
125 void param(const param_type
& __p
) {__p_
= __p
;}
// Smallest value operator() can return.
127 _LIBCPP_INLINE_VISIBILITY
128 result_type
min() const {return 0;}
// Largest value operator() can return: the size of the stored table, because
// upper_bound in operator() may yield the table's end iterator.
129 _LIBCPP_INLINE_VISIBILITY
130 result_type
max() const {return __p_
.__p_
.size();}
// Distributions compare equal exactly when their parameter sets are equal.
132 friend _LIBCPP_INLINE_VISIBILITY
133 bool operator==(const discrete_distribution
& __x
,
134 const discrete_distribution
& __y
)
135 {return __x
.__p_
== __y
.__p_
;}
// Inequality is defined as the negation of operator==.
136 friend _LIBCPP_INLINE_VISIBILITY
137 bool operator!=(const discrete_distribution
& __x
,
138 const discrete_distribution
& __y
)
139 {return !(__x
== __y
);}
// Stream-operator declarations (repeated at distribution scope).
// NOTE(review): the line between each template header and its return type
// (original 142 and 148) is missing here -- presumably `friend`; confirm.
141 template <class _CharT
, class _Traits
, class _IT
>
143 basic_ostream
<_CharT
, _Traits
>&
144 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
145 const discrete_distribution
<_IT
>& __x
);
147 template <class _CharT
, class _Traits
, class _IT
>
149 basic_istream
<_CharT
, _Traits
>&
150 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
151 discrete_distribution
<_IT
>& __x
);
// param_type(__nw, __xmin, __xmax, __fw): builds __nw weights by evaluating
// __fw at the midpoints of __nw equal-width subintervals of [__xmin, __xmax].
// The k-th weight is __fw(__xmin + k*__d + __d/2), where
// __d = (__xmax - __xmin) / __nw.
// NOTE(review): original lines 157-158 (the __xmin/__xmax parameters) and
// 160-162 / 168-171 are missing from this excerpt, so any guard on __nw and any
// post-processing of the collected weights are not visible here.
154 template<class _IntType
>
155 template<class _UnaryOperation
>
156 discrete_distribution
<_IntType
>::param_type::param_type(size_t __nw
,
159 _UnaryOperation __fw
)
// NOTE(review): this reserves __nw - 1 slots, but the loop below pushes __nw
// elements, so the final push_back may reallocate; reserve(__nw) looks
// intended -- confirm. Also, __nw is size_t: if nothing in the missing lines
// rules out __nw == 0, `__nw - 1` wraps to SIZE_MAX and reserve() would throw.
163 __p_
.reserve(__nw
- 1);
// Width of each subinterval, and half of it (offset to the midpoint).
164 double __d
= (__xmax
- __xmin
) / __nw
;
165 double __d2
= __d
/ 2;
// Sample __fw once per subinterval, at that subinterval's midpoint.
166 for (size_t __k
= 0; __k
< __nw
; ++__k
)
167 __p_
.push_back(__fw(__xmin
+ __k
* __d
+ __d2
));
// param_type::__init(): post-processes the raw weights held in __p_.
// Visible steps:
//   1. sum all weights into __s;
//   2. visit every element (loop body not visible);
//   3. build __t holding the running sums of all but the last element;
//   4. shrink __p_'s capacity.
// NOTE(review): original lines 173 (return type), 175-179, 182 (the loop body)
// and 185-189 are missing from this excerpt. The loop presumably divides each
// weight by __s and __t presumably replaces __p_ -- confirm against the full
// source.
172 template<class _IntType
>
174 discrete_distribution
<_IntType
>::param_type::__init()
// Total of all weights.
// NOTE(review): if the loop divides by __s, a zero total would produce NaN/inf
// entries; whether the missing lines reject that case is not visible here.
180 double __s
= _VSTD::accumulate(__p_
.begin(), __p_
.end(), 0.0);
181 for (vector
<double>::iterator __i
= __p_
.begin(), __e
= __p_
.end(); __i
< __e
; ++__i
)
// Running sums of the first size()-1 entries. The last running sum is
// deliberately omitted.
// NOTE(review): size() is unsigned; if __p_ can be empty here (guards are in
// the missing lines 175-179), `size() - 1` would wrap and request a huge
// allocation.
183 vector
<double> __t(__p_
.size() - 1);
184 _VSTD::partial_sum(__p_
.begin(), __p_
.end() - 1, __t
.begin());
// Release any spare capacity held by __p_.
190 __p_
.shrink_to_fit();
// param_type::probabilities(): reconstructs the per-index probabilities from
// the stored table. The result has __p_.size() + 1 entries:
//   - adjacent_difference turns consecutive table values into their
//     differences;
//   - the last entry is 1 minus the final table value.
// This is consistent with __p_ holding running sums that omit the final one,
// as built in __init().
// NOTE(review): original lines 196 (return type), 198 and 202 are missing.
// Line 202 presumably guards the indexing below; without such a guard,
// __n == 0 would make `__n-1` wrap and read out of bounds.
195 template<class _IntType
>
197 discrete_distribution
<_IntType
>::param_type::probabilities() const
199 size_t __n
= __p_
.size();
// One more slot than the table, for the trailing remainder probability.
200 vector
<double> __p(__n
+1);
201 _VSTD::adjacent_difference(__p_
.begin(), __p_
.end(), __p
.begin());
// Remaining mass after the last stored running sum.
203 __p
[__n
] = 1 - __p_
[__n
-1];
// discrete_distribution::operator()(__g, __p): draws one index using __p.
// A default-constructed uniform_real_distribution<double> (range [0, 1) per the
// standard) produces u from __g. The result is the position of the first table
// entry in __p.__p_ that is greater than u, found by binary search
// (upper_bound).
// NOTE(review): original lines 211 (return type), 213 and 217 are missing. Line
// 217 holds the right operand of the trailing `-`, presumably
// __p.__p_.begin(), which turns the iterator into an index -- confirm.
209 template<class _IntType
>
210 template<class _URNG
>
212 discrete_distribution
<_IntType
>::operator()(_URNG
& __g
, const param_type
& __p
)
214 uniform_real_distribution
<double> __gen
;
215 return static_cast<_IntType
>(
216 _VSTD::upper_bound(__p
.__p_
.begin(), __p
.__p_
.end(), __gen(__g
)) -
// Stream insertion for discrete_distribution: writes every entry of the
// internal table __x.__p_.__p_, each preceded by a (widened) space character.
// Note this is the stored table, not necessarily the weights originally
// supplied.
// NOTE(review): original lines 224, 230, 232 and 235-236 are missing. Line 232
// presumably writes __n first, so that operator>> knows how many values to
// read; the `return __os;` is also not visible -- confirm.
220 template <class _CharT
, class _Traits
, class _IT
>
221 basic_ostream
<_CharT
, _Traits
>&
222 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
,
223 const discrete_distribution
<_IT
>& __x
)
// __save_flags is a libc++ internal helper not defined here; presumably it
// restores the stream's original flags when __lx goes out of scope -- confirm.
225 __save_flags
<_CharT
, _Traits
> __lx(__os
);
226 typedef basic_ostream
<_CharT
, _Traits
> _OStream
;
// Setting fixed and scientific together selects hexadecimal floating-point
// output (C++11 floatfield rules), which represents each double exactly.
227 __os
.flags(_OStream::dec
| _OStream::left
| _OStream::fixed
|
228 _OStream::scientific
);
// Separator written before each value, widened to the stream's char type.
229 _CharT __sp
= __os
.widen(' ');
231 size_t __n
= __x
.__p_
.__p_
.size();
233 for (size_t __i
= 0; __i
< __n
; ++__i
)
234 __os
<< __sp
<< __x
.__p_
.__p_
[__i
];
// Stream extraction for discrete_distribution: reads __n doubles into a
// temporary vector __p and then swaps it into __x's internal table. This
// ordering keeps the table untouched until the temporary has been filled.
// NOTE(review): original lines 242, 246-247 (declaration/read of __n), 250-251
// (loop body and whatever follows it) and 253-254 are missing. Whether the
// swap is conditional on a successful read is not visible here.
// NOTE(review): if __n is read from the stream without validation, a hostile
// or corrupt count would size `vector<double> __p(__n)` arbitrarily large --
// worth checking in the full source.
238 template <class _CharT
, class _Traits
, class _IT
>
239 basic_istream
<_CharT
, _Traits
>&
240 operator>>(basic_istream
<_CharT
, _Traits
>& __is
,
241 discrete_distribution
<_IT
>& __x
)
// __save_flags is a libc++ internal helper not defined here; presumably it
// restores the stream's original flags when __lx goes out of scope -- confirm.
243 __save_flags
<_CharT
, _Traits
> __lx(__is
);
244 typedef basic_istream
<_CharT
, _Traits
> _Istream
;
// Decimal integers, and skip leading whitespace before each value.
245 __is
.flags(_Istream::dec
| _Istream::skipws
);
// Temporary buffer for the values being read.
248 vector
<double> __p(__n
);
249 for (size_t __i
= 0; __i
< __n
; ++__i
)
// Install the newly read table into __x.
252 swap(__x
.__p_
.__p_
, __p
);
256 _LIBCPP_END_NAMESPACE_STD
260 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H