1 //===----------------------------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
12 #include <__algorithm/upper_bound.h>
14 #include <__random/is_valid.h>
15 #include <__random/uniform_real_distribution.h>
16 #include <__vector/vector.h>
17 #include <initializer_list>
21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
22 # pragma GCC system_header
26 #include <__undef_macros>
28 _LIBCPP_BEGIN_NAMESPACE_STD
// discrete_distribution: random number distribution class template whose
// result_type is _IntType (int by default). Values are produced by
// operator(), and the reported range is [min(), max()], where min() is 0 and
// max() is the size of the stored parameter container.
// NOTE(review): this listing appears to be missing lines (access specifiers,
// the declarations of the __p_ members and several closing braces are not
// visible), so the comments below describe only what is shown.
30 template <class _IntType
= int>
31 class _LIBCPP_TEMPLATE_VIS discrete_distribution
{
// Compile-time check that _IntType is accepted by
// __libcpp_random_is_valid_inttype.
32 static_assert(__libcpp_random_is_valid_inttype
<_IntType
>::value
, "IntType must be a supported integer type");
// Type of the values returned by operator(), min() and max().
36 typedef _IntType result_type
;
// param_type: parameter object of the distribution. Its __p_ member is
// constructed from the supplied weights. The declaration of __p_ is not
// visible in this listing (presumably vector<double> - TODO confirm).
38 class _LIBCPP_TEMPLATE_VIS param_type
{
// The distribution type these parameters belong to.
42 typedef discrete_distribution distribution_type
;
// Default constructor with an empty body.
44 _LIBCPP_HIDE_FROM_ABI
param_type() {}
// Range constructor: initializes __p_ from the iterator pair (__f, __l).
// NOTE(review): the rest of its body is not visible in this listing.
45 template <class _InputIterator
>
46 _LIBCPP_HIDE_FROM_ABI
param_type(_InputIterator __f
, _InputIterator __l
) : __p_(__f
, __l
) {
// Initializer-list constructor, compiled only when _LIBCPP_CXX03_LANG is not
// defined: copies the listed weights into __p_ and then calls __init().
49 #ifndef _LIBCPP_CXX03_LANG
50 _LIBCPP_HIDE_FROM_ABI
param_type(initializer_list
<double> __wl
) : __p_(__wl
.begin(), __wl
.end()) { __init(); }
51 #endif // _LIBCPP_CXX03_LANG
// Constructor taking a count __nw, the bounds __xmin and __xmax and a unary
// callable __fw. Only declared here.
52 template <class _UnaryOperation
>
53 _LIBCPP_HIDE_FROM_ABI
param_type(size_t __nw
, double __xmin
, double __xmax
, _UnaryOperation __fw
);
// Returns the probabilities as a vector<double>. Only declared here.
55 _LIBCPP_HIDE_FROM_ABI vector
<double> probabilities() const;
// Two parameter objects compare equal when their __p_ members compare equal.
57 friend _LIBCPP_HIDE_FROM_ABI
bool operator==(const param_type
& __x
, const param_type
& __y
) {
58 return __x
.__p_
== __y
.__p_
;
// Inequality is the negation of operator==.
60 friend _LIBCPP_HIDE_FROM_ABI
bool operator!=(const param_type
& __x
, const param_type
& __y
) { return !(__x
== __y
); }
// Internal helper invoked by the initializer-list constructor after __p_ has
// been filled. Only declared here.
63 _LIBCPP_HIDE_FROM_ABI
void __init();
// The enclosing distribution is a friend, so it can use param_type internals.
65 friend class discrete_distribution
;
// The stream insertion and extraction operators are friends as well.
67 template <class _CharT
, class _Traits
, class _IT
>
68 friend basic_ostream
<_CharT
, _Traits
>&
69 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
, const discrete_distribution
<_IT
>& __x
);
71 template <class _CharT
, class _Traits
, class _IT
>
72 friend basic_istream
<_CharT
, _Traits
>&
73 operator>>(basic_istream
<_CharT
, _Traits
>& __is
, discrete_distribution
<_IT
>& __x
);
80 // constructor and reset functions
// Default constructor with an empty body.
81 _LIBCPP_HIDE_FROM_ABI
discrete_distribution() {}
// Each of the following constructors forwards its arguments unchanged to the
// stored parameter object __p_ and has an empty body.
82 template <class _InputIterator
>
83 _LIBCPP_HIDE_FROM_ABI
discrete_distribution(_InputIterator __f
, _InputIterator __l
) : __p_(__f
, __l
) {}
// Initializer-list form, compiled only when _LIBCPP_CXX03_LANG is not defined.
84 #ifndef _LIBCPP_CXX03_LANG
85 _LIBCPP_HIDE_FROM_ABI
discrete_distribution(initializer_list
<double> __wl
) : __p_(__wl
) {}
86 #endif // _LIBCPP_CXX03_LANG
// Count, bounds and callable form.
87 template <class _UnaryOperation
>
88 _LIBCPP_HIDE_FROM_ABI
discrete_distribution(size_t __nw
, double __xmin
, double __xmax
, _UnaryOperation __fw
)
89 : __p_(__nw
, __xmin
, __xmax
, __fw
) {}
// Explicit construction from an existing parameter object, which is copied
// into __p_.
90 _LIBCPP_HIDE_FROM_ABI
explicit discrete_distribution(const param_type
& __p
) : __p_(__p
) {}
// reset() has an empty body: it does nothing.
91 _LIBCPP_HIDE_FROM_ABI
void reset() {}
93 // generating functions
// Draws a value using the stored parameters by delegating to the
// two-argument overload.
94 template <class _URNG
>
95 _LIBCPP_HIDE_FROM_ABI result_type
operator()(_URNG
& __g
) {
96 return (*this)(__g
, __p_
);
// Draws a value using the caller-supplied parameters __p. Only declared here.
98 template <class _URNG
>
99 _LIBCPP_HIDE_FROM_ABI result_type
operator()(_URNG
& __g
, const param_type
& __p
);
101 // property functions
// Forwards to probabilities() of the stored parameter object.
102 _LIBCPP_HIDE_FROM_ABI vector
<double> probabilities() const { return __p_
.probabilities(); }
// Returns the stored parameter object by value.
104 _LIBCPP_HIDE_FROM_ABI param_type
param() const { return __p_
; }
// Replaces the stored parameter object with __p.
105 _LIBCPP_HIDE_FROM_ABI
void param(const param_type
& __p
) { __p_
= __p
; }
// Lower bound of the reported range: always 0.
107 _LIBCPP_HIDE_FROM_ABI result_type
min() const { return 0; }
// Upper bound of the reported range: the size of the container held by the
// stored parameter object, converted to result_type.
108 _LIBCPP_HIDE_FROM_ABI result_type
max() const { return __p_
.__p_
.size(); }
// Two distributions compare equal when their parameter objects compare equal.
110 friend _LIBCPP_HIDE_FROM_ABI
bool operator==(const discrete_distribution
& __x
, const discrete_distribution
& __y
) {
111 return __x
.__p_
== __y
.__p_
;
// Inequality is the negation of operator==.
113 friend _LIBCPP_HIDE_FROM_ABI
bool operator!=(const discrete_distribution
& __x
, const discrete_distribution
& __y
) {
114 return !(__x
== __y
);
// The stream insertion and extraction operators are friends so they can
// reach the __p_ member.
117 template <class _CharT
, class _Traits
, class _IT
>
118 friend basic_ostream
<_CharT
, _Traits
>&
119 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
, const discrete_distribution
<_IT
>& __x
);
121 template <class _CharT
, class _Traits
, class _IT
>
122 friend basic_istream
<_CharT
, _Traits
>&
123 operator>>(basic_istream
<_CharT
, _Traits
>& __is
, discrete_distribution
<_IT
>& __x
);
// Out-of-line definition of the param_type constructor taking
// (__nw, __xmin, __xmax, __fw).
// The visible statements divide the span from __xmin to __xmax into __nw
// steps of equal width __d and append to __p_ the value of __fw at the
// midpoint of each step, that is at __xmin + __k * __d + __d / 2.
// NOTE(review): lines appear to be missing from this listing (the closing
// brace and any check on __nw are not visible). __nw is a size_t, so the
// expression __nw - 1 passed to reserve would wrap around for __nw == 0
// unless a guard exists - TODO confirm against the full file.
126 template <class _IntType
>
127 template <class _UnaryOperation
>
128 discrete_distribution
<_IntType
>::param_type::param_type(
129 size_t __nw
, double __xmin
, double __xmax
, _UnaryOperation __fw
) {
// Capacity is requested for __nw - 1 elements although the loop below
// appends __nw values.
131 __p_
.reserve(__nw
- 1);
// __d is the width of one step, __d2 is half of that width.
132 double __d
= (__xmax
- __xmin
) / __nw
;
133 double __d2
= __d
/ 2;
// Evaluate __fw at the centre of every step and store the result.
134 for (size_t __k
= 0; __k
< __nw
; ++__k
)
135 __p_
.push_back(__fw(__xmin
+ __k
* __d
+ __d2
));
// Out-of-line definition of param_type::__init(), the post-processing step
// applied to __p_ after it has been filled with weights.
// Statements visible in this listing:
//   - a branch taken when __p_ holds more than one element;
//   - __s, the sum of all elements of __p_;
//   - a loop over every element of __p_ (its body is not visible here,
//     presumably it divides each element by __s - TODO confirm);
//   - __t, which receives the running sums of all elements except the last;
//   - a call to shrink_to_fit() on __p_.
// NOTE(review): lines are missing between these statements, so it is not
// possible to tell from here what happens to __t afterwards or which branch
// the shrink_to_fit() call belongs to - TODO confirm against the full file.
140 template <class _IntType
>
141 void discrete_distribution
<_IntType
>::param_type::__init() {
143 if (__p_
.size() > 1) {
// Total of all weights.
144 double __s
= std::accumulate(__p_
.begin(), __p_
.end(), 0.0);
// Walks every element of __p_; the loop body is not part of this listing.
145 for (vector
<double>::iterator __i
= __p_
.begin(), __e
= __p_
.end(); __i
< __e
; ++__i
)
// One slot fewer than __p_: the last element is left out of the running sums.
147 vector
<double> __t(__p_
.size() - 1);
148 std::partial_sum(__p_
.begin(), __p_
.end() - 1, __t
.begin());
// NOTE(review): the lines leading up to this call are not visible.
152 __p_
.shrink_to_fit();
// Out-of-line definition of param_type::probabilities().
// Builds a vector __p with one more element than __p_. The first __n entries
// are the adjacent differences of __p_ (the first entry is a copy of the
// first element of __p_), and the last entry is 1 minus the last element of
// __p_. This matches __p_ holding running sums - presumably the cumulative
// values prepared by __init(), TODO confirm.
// NOTE(review): lines appear to be missing from this listing. The return
// statement is not visible, and neither is any check protecting the
// __p_[__n - 1] access when __p_ is empty - TODO confirm.
157 template <class _IntType
>
158 vector
<double> discrete_distribution
<_IntType
>::param_type::probabilities() const {
159 size_t __n
= __p_
.size();
// Result has room for one extra entry beyond the stored values.
160 vector
<double> __p(__n
+ 1);
161 std::adjacent_difference(__p_
.begin(), __p_
.end(), __p
.begin());
// The final entry is whatever remains to reach 1.
163 __p
[__n
] = 1 - __p_
[__n
- 1];
// Out-of-line definition of operator()(__g, __p): draws one value using the
// supplied parameter object instead of the stored one.
// A default-constructed uniform_real_distribution<double> is sampled with
// __g, and the result is the position (distance from begin) of the first
// element of __p.__p_ that is greater than that sample, as located by
// std::upper_bound, cast to _IntType. If no element is greater, the position
// equals the size of __p.__p_.
// NOTE(review): std::upper_bound requires __p.__p_ to be ordered. This is
// presumably guaranteed by the running sums built in __init() - TODO confirm.
169 template <class _IntType
>
170 template <class _URNG
>
171 _IntType discrete_distribution
<_IntType
>::operator()(_URNG
& __g
, const param_type
& __p
) {
// Reject generator types that __libcpp_random_is_valid_urng does not accept.
172 static_assert(__libcpp_random_is_valid_urng
<_URNG
>::value
, "");
173 uniform_real_distribution
<double> __gen
;
174 return static_cast<_IntType
>(std::upper_bound(__p
.__p_
.begin(), __p
.__p_
.end(), __gen(__g
)) - __p
.__p_
.begin());
// Stream insertion for discrete_distribution.
// Visible behaviour: the stream flags are replaced by
// dec | left | fixed | scientific, and every element of __x.__p_.__p_ is
// written preceded by a widened space character.
// NOTE(review): fixed and scientific are both set. In standard iostreams
// that combination selects hexadecimal floating-point output - confirm that
// this is the intended format.
// NOTE(review): lines appear to be missing from this listing. The return
// statement is not visible, and nothing shown writes the element count __n
// itself - TODO confirm against the full file.
177 template <class _CharT
, class _Traits
, class _IT
>
178 _LIBCPP_HIDE_FROM_ABI basic_ostream
<_CharT
, _Traits
>&
179 operator<<(basic_ostream
<_CharT
, _Traits
>& __os
, const discrete_distribution
<_IT
>& __x
) {
// __lx is constructed from the stream before the flags are changed,
// presumably so the previous formatting state is restored when it goes out
// of scope - TODO confirm against the definition of __save_flags.
180 __save_flags
<_CharT
, _Traits
> __lx(__os
);
181 typedef basic_ostream
<_CharT
, _Traits
> _OStream
;
182 __os
.flags(_OStream::dec
| _OStream::left
| _OStream::fixed
| _OStream::scientific
);
// Separator in the character type of the stream.
183 _CharT __sp
= __os
.widen(' ');
185 size_t __n
= __x
.__p_
.__p_
.size();
// Write each stored value, each one preceded by the separator.
187 for (size_t __i
= 0; __i
< __n
; ++__i
)
188 __os
<< __sp
<< __x
.__p_
.__p_
[__i
];
// Stream extraction for discrete_distribution.
// Visible behaviour: the stream flags are replaced by dec | skipws, a
// temporary vector<double> __p with __n elements is created, a loop runs
// over its indices, and __p is swapped into __x.__p_.__p_.
// NOTE(review): lines appear to be missing from this listing. The
// declaration and reading of __n, the body of the loop (presumably reading
// one value into each element of __p), any stream failure check before the
// swap and the return statement are not visible - TODO confirm against the
// full file.
192 template <class _CharT
, class _Traits
, class _IT
>
193 _LIBCPP_HIDE_FROM_ABI basic_istream
<_CharT
, _Traits
>&
194 operator>>(basic_istream
<_CharT
, _Traits
>& __is
, discrete_distribution
<_IT
>& __x
) {
// __lx is constructed from the stream before the flags are changed,
// presumably so the previous formatting state is restored when it goes out
// of scope - TODO confirm against the definition of __save_flags.
195 __save_flags
<_CharT
, _Traits
> __lx(__is
);
196 typedef basic_istream
<_CharT
, _Traits
> _Istream
;
197 __is
.flags(_Istream::dec
| _Istream::skipws
);
// Values are collected in a temporary first and only then swapped into the
// distribution.
200 vector
<double> __p(__n
);
201 for (size_t __i
= 0; __i
< __n
; ++__i
)
204 swap(__x
.__p_
.__p_
, __p
);
208 _LIBCPP_END_NAMESPACE_STD
212 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H