[X86] Pre-commit test for D157513
[llvm-project.git] / libcxx / include / __random / normal_distribution.h
blobe2bf041b71fe2cdb597998ddc1479d7ca0210447
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef _LIBCPP___RANDOM_NORMAL_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_NORMAL_DISTRIBUTION_H
12 #include <__config>
13 #include <__random/is_valid.h>
14 #include <__random/uniform_real_distribution.h>
15 #include <cmath>
16 #include <iosfwd>
17 #include <limits>
19 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
20 # pragma GCC system_header
21 #endif
23 _LIBCPP_PUSH_MACROS
24 #include <__undef_macros>
26 _LIBCPP_BEGIN_NAMESPACE_STD
28 template<class _RealType = double>
29 class _LIBCPP_TEMPLATE_VIS normal_distribution
31 public:
32 // types
33 typedef _RealType result_type;
35 class _LIBCPP_TEMPLATE_VIS param_type
37 result_type __mean_;
38 result_type __stddev_;
39 public:
40 typedef normal_distribution distribution_type;
42 _LIBCPP_INLINE_VISIBILITY
43 explicit param_type(result_type __mean = 0, result_type __stddev = 1)
44 : __mean_(__mean), __stddev_(__stddev) {}
46 _LIBCPP_INLINE_VISIBILITY
47 result_type mean() const {return __mean_;}
48 _LIBCPP_INLINE_VISIBILITY
49 result_type stddev() const {return __stddev_;}
51 friend _LIBCPP_INLINE_VISIBILITY
52 bool operator==(const param_type& __x, const param_type& __y)
53 {return __x.__mean_ == __y.__mean_ && __x.__stddev_ == __y.__stddev_;}
54 friend _LIBCPP_INLINE_VISIBILITY
55 bool operator!=(const param_type& __x, const param_type& __y)
56 {return !(__x == __y);}
59 private:
60 param_type __p_;
61 result_type __v_;
62 bool __v_hot_;
64 public:
65 // constructors and reset functions
66 #ifndef _LIBCPP_CXX03_LANG
67 _LIBCPP_INLINE_VISIBILITY
68 normal_distribution() : normal_distribution(0) {}
69 _LIBCPP_INLINE_VISIBILITY
70 explicit normal_distribution(result_type __mean, result_type __stddev = 1)
71 : __p_(param_type(__mean, __stddev)), __v_hot_(false) {}
72 #else
73 _LIBCPP_INLINE_VISIBILITY
74 explicit normal_distribution(result_type __mean = 0,
75 result_type __stddev = 1)
76 : __p_(param_type(__mean, __stddev)), __v_hot_(false) {}
77 #endif
78 _LIBCPP_INLINE_VISIBILITY
79 explicit normal_distribution(const param_type& __p)
80 : __p_(__p), __v_hot_(false) {}
81 _LIBCPP_INLINE_VISIBILITY
82 void reset() {__v_hot_ = false;}
84 // generating functions
85 template<class _URNG>
86 _LIBCPP_INLINE_VISIBILITY
87 result_type operator()(_URNG& __g)
88 {return (*this)(__g, __p_);}
89 template<class _URNG>
90 _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);
92 // property functions
93 _LIBCPP_INLINE_VISIBILITY
94 result_type mean() const {return __p_.mean();}
95 _LIBCPP_INLINE_VISIBILITY
96 result_type stddev() const {return __p_.stddev();}
98 _LIBCPP_INLINE_VISIBILITY
99 param_type param() const {return __p_;}
100 _LIBCPP_INLINE_VISIBILITY
101 void param(const param_type& __p) {__p_ = __p;}
103 _LIBCPP_INLINE_VISIBILITY
104 result_type min() const {return -numeric_limits<result_type>::infinity();}
105 _LIBCPP_INLINE_VISIBILITY
106 result_type max() const {return numeric_limits<result_type>::infinity();}
108 friend _LIBCPP_INLINE_VISIBILITY
109 bool operator==(const normal_distribution& __x,
110 const normal_distribution& __y)
111 {return __x.__p_ == __y.__p_ && __x.__v_hot_ == __y.__v_hot_ &&
112 (!__x.__v_hot_ || __x.__v_ == __y.__v_);}
113 friend _LIBCPP_INLINE_VISIBILITY
114 bool operator!=(const normal_distribution& __x,
115 const normal_distribution& __y)
116 {return !(__x == __y);}
118 template <class _CharT, class _Traits, class _RT>
119 friend
120 basic_ostream<_CharT, _Traits>&
121 operator<<(basic_ostream<_CharT, _Traits>& __os,
122 const normal_distribution<_RT>& __x);
124 template <class _CharT, class _Traits, class _RT>
125 friend
126 basic_istream<_CharT, _Traits>&
127 operator>>(basic_istream<_CharT, _Traits>& __is,
128 normal_distribution<_RT>& __x);
131 template <class _RealType>
132 template<class _URNG>
133 _RealType
134 normal_distribution<_RealType>::operator()(_URNG& __g, const param_type& __p)
136 static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");
137 result_type __up;
138 if (__v_hot_)
140 __v_hot_ = false;
141 __up = __v_;
143 else
145 uniform_real_distribution<result_type> __uni(-1, 1);
146 result_type __u;
147 result_type __v;
148 result_type __s;
151 __u = __uni(__g);
152 __v = __uni(__g);
153 __s = __u * __u + __v * __v;
154 } while (__s > 1 || __s == 0);
155 result_type __fp = _VSTD::sqrt(-2 * _VSTD::log(__s) / __s);
156 __v_ = __v * __fp;
157 __v_hot_ = true;
158 __up = __u * __fp;
160 return __up * __p.stddev() + __p.mean();
163 template <class _CharT, class _Traits, class _RT>
164 _LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>&
165 operator<<(basic_ostream<_CharT, _Traits>& __os,
166 const normal_distribution<_RT>& __x)
168 __save_flags<_CharT, _Traits> __lx(__os);
169 typedef basic_ostream<_CharT, _Traits> _OStream;
170 __os.flags(_OStream::dec | _OStream::left | _OStream::fixed |
171 _OStream::scientific);
172 _CharT __sp = __os.widen(' ');
173 __os.fill(__sp);
174 __os << __x.mean() << __sp << __x.stddev() << __sp << __x.__v_hot_;
175 if (__x.__v_hot_)
176 __os << __sp << __x.__v_;
177 return __os;
180 template <class _CharT, class _Traits, class _RT>
181 _LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>&
182 operator>>(basic_istream<_CharT, _Traits>& __is,
183 normal_distribution<_RT>& __x)
185 typedef normal_distribution<_RT> _Eng;
186 typedef typename _Eng::result_type result_type;
187 typedef typename _Eng::param_type param_type;
188 __save_flags<_CharT, _Traits> __lx(__is);
189 typedef basic_istream<_CharT, _Traits> _Istream;
190 __is.flags(_Istream::dec | _Istream::skipws);
191 result_type __mean;
192 result_type __stddev;
193 result_type __vp = 0;
194 bool __v_hot = false;
195 __is >> __mean >> __stddev >> __v_hot;
196 if (__v_hot)
197 __is >> __vp;
198 if (!__is.fail())
200 __x.param(param_type(__mean, __stddev));
201 __x.__v_hot_ = __v_hot;
202 __x.__v_ = __vp;
204 return __is;
207 _LIBCPP_END_NAMESPACE_STD
209 _LIBCPP_POP_MACROS
211 #endif // _LIBCPP___RANDOM_NORMAL_DISTRIBUTION_H