fix: improve modern diag manager performance (#1634)
[FMS.git] / test_fms / diag_manager / test_reduction_methods.F90
blob7a0bb8efc6d02f1866d746c9d810d1564c1b38eb
1 !***********************************************************************
2 !*                   GNU Lesser General Public License
3 !*
4 !* This file is part of the GFDL Flexible Modeling System (FMS).
5 !*
6 !* FMS is free software: you can redistribute it and/or modify it under
7 !* the terms of the GNU Lesser General Public License as published by
8 !* the Free Software Foundation, either version 3 of the License, or (at
9 !* your option) any later version.
11 !* FMS is distributed in the hope that it will be useful, but WITHOUT
12 !* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 !* FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14 !* for more details.
16 !* You should have received a copy of the GNU Lesser General Public
17 !* License along with FMS.  If not, see <http://www.gnu.org/licenses/>.
18 !***********************************************************************
20 !> @brief  General program to test the different possible reduction methods
21 program test_reduction_methods
22   use fms_mod,           only: fms_init, fms_end
23   use testing_utils,     only: allocate_buffer, test_normal, test_openmp, test_halos, no_mask, logical_mask, real_mask
24   use platform_mod,      only: r8_kind
25   use block_control_mod, only: block_control_type, define_blocks
26   use mpp_mod,           only: mpp_sync, FATAL, mpp_error, mpp_npes, mpp_pe, mpp_root_pe, mpp_broadcast, input_nml_file
27   use time_manager_mod,  only: time_type, set_calendar_type, set_date, JULIAN, set_time, OPERATOR(+)
28   use diag_manager_mod,  only: diag_manager_init, diag_manager_end, diag_axis_init, register_diag_field, &
29                                diag_send_complete, diag_manager_set_time_end, send_data
30   use mpp_domains_mod,   only: domain2d, mpp_define_domains, mpp_define_io_domain, mpp_get_compute_domain, &
31                                mpp_get_data_domain, NORTH, EAST
33   implicit none
35   integer                            :: nx              !< Number of points in the x direction
36   integer                            :: ny              !< Number of points in the y direction
37   integer                            :: nz              !< Number of points in the z direction
38   integer                            :: nw              !< Number of points in the 4th dimension
39   integer                            :: layout(2)       !< Layout
40   integer                            :: io_layout(2)    !< Io layout
41   type(domain2d)                     :: Domain          !< 2D domain
42   integer                            :: isc, isd        !< Starting x compute, data domain index
43   integer                            :: iec, ied        !< Ending x compute, data domain index
44   integer                            :: jsc, jsd        !< Starting y compute, data domaine index
45   integer                            :: jec, jed        !< Ending y compute, data domain index
46   integer                            :: nhalox          !< Number of halos in x
47   integer                            :: nhaloy          !< Number of halos in y
48   real(kind=r8_kind), allocatable    :: cdata(:,:,:,:)  !< Data in the compute domain
49   real(kind=r8_kind), allocatable    :: cdata_corner(:,:,:,:)  !< Data in the compute domain
50   real(kind=r8_kind), allocatable    :: ddata(:,:,:,:)  !< Data in the data domain
51   real(kind=r8_kind), allocatable    :: crmask(:,:,:,:) !< Mask in the compute domain
52   real(kind=r8_kind), allocatable    :: drmask(:,:,:,:) !< Mask in the data domain
53   logical,            allocatable    :: clmask(:,:,:,:) !< Logical mask in the compute domain
54   logical,            allocatable    :: dlmask(:,:,:,:) !< Logical mask in the data domain
55   type(time_type)                    :: Time            !< Time of the simulation
56   type(time_type)                    :: Time_step       !< Time of the simulation
57   integer                            :: ntimes          !< Number of times
58   integer                            :: id_x            !< axis id for the x dimension
59   integer                            :: id_xc           !< axis id for the x dimension (corner)
60   integer                            :: id_y            !< axis id for the y dimension
61   integer                            :: id_yc           !< axis id for the y dimension (corner)
62   integer                            :: id_z            !< axis id for the z dimension
63   integer                            :: id_w            !< axis id for the w dimension
64   integer                            :: id_var0         !< diag_field id for 0d var
65   integer                            :: id_var1         !< diag_field id for 1d var
66   integer                            :: id_var2         !< diag_field id for 2d var
67   integer                            :: id_var2missing  !< diag_field id for a var that is not masked but has missing
68                                                         !! values passed into send_data
69   integer                            :: id_var2c        !< diag_field id for 2d var_corner
70   integer                            :: id_var3         !< diag_field id for 3d var
71   integer                            :: id_var4         !< diag_field id for 4d var
72   integer                            :: id_var999       !< diag_field id for a var that send_data is not called for
73   integer                            :: io_status       !< Status after reading the namelist
74   type(block_control_type)           :: my_block        !< Returns instantiated @ref block_control_type
75   logical                            :: message         !< Flag for outputting debug message
76   integer                            :: isd1            !< Starting x data domain index (1-based)
77   integer                            :: ied1            !< Ending x data domain index (1-based)
78   integer                            :: jsd1            !< Starting y data domain index (1-based)
79   integer                            :: jed1            !< Ending y data domain index (1-based)
80   integer                            :: isw             !< Starting index for each thread in the x direction
81   integer                            :: iew             !< Ending index for each thread in the x direction
82   integer                            :: jsw             !< Starting index for each thread in the y direction
83   integer                            :: jew             !< Ending index for each thread in the y direction
84   integer                            :: is1             !< Starting index for each thread in the x direction (1-based)
85   integer                            :: ie1             !< Ending index for each thread in the x direction (1-based)
86   integer                            :: js1             !< Starting index for each thread in the y direction (1-based)
87   integer                            :: je1             !< Ending index for each thread in the y direction (1-based)
88   integer                            :: iblock          !< For looping through the blocks
89   integer                            :: i               !< For do loops
90   logical                            :: used            !< Dummy argument to send_data
91   real(kind=r8_kind)                 :: missing_value   !< Missing value to use
93   !< Configuration parameters
94   integer :: test_case = test_normal !< Indicates which test case to run
95   integer :: mask_case = no_mask     !< Indicates which masking option to run
96   logical :: use_pow_data = .false.  !< uses simplified smaller dataset for the pow reduction to simplify checks
98   namelist / test_reduction_methods_nml / test_case, mask_case, use_pow_data
100   call fms_init
101   call set_calendar_type(JULIAN)
102   call diag_manager_init
104   read (input_nml_file, test_reduction_methods_nml, iostat=io_status)
105   if (io_status > 0) call mpp_error(FATAL,'=>test_modern_diag: Error reading input.nml')
107   Time = set_date(2,1,1,0,0,0)
108   Time_step = set_time (3600,0) !< 1 hour
109   nx = 96
110   ny = 96
111   nz = 5
112   nw = 2
113   layout = (/1, mpp_npes()/)
114   io_layout = (/1, 1/)
115   nhalox = 2
116   nhaloy = 2
117   ntimes = 48
119   !< Create a lat/lon domain
120   call mpp_define_domains( (/1,nx,1,ny/), layout, Domain, name='2D domain', symmetry=.true., &
121     xhalo=nhalox, yhalo=nhaloy)
122   call mpp_define_io_domain(Domain, io_layout)
123   call mpp_get_compute_domain(Domain, isc, iec, jsc, jec)
124   call mpp_get_data_domain(Domain, isd, ied, jsd, jed)
125   cdata = allocate_buffer(isc, iec, jsc, jec, nz, nw)
126   cdata_corner = allocate_buffer(isc, iec+1, jsc, jec+1, nz, nw)
127   call init_buffer(cdata, isc, iec, jsc, jec, 0)
128   call init_buffer(cdata_corner, isc, iec+1, jsc, jec+1, 0)
130   select case (test_case)
131   case (test_normal)
132     if (mpp_pe() .eq. mpp_root_pe()) print *, "Testing the normal send_data calls"
133   case (test_halos)
134     if (mpp_pe() .eq. mpp_root_pe()) print *, "Testing the send_data calls with halos"
135     ddata = allocate_buffer(isd, ied, jsd, jed, nz, nw)
136     call init_buffer(ddata, isc, iec, jsc, jec, 2) !< The halos never get set
137   case (test_openmp)
138     message = .true.
139     if (mpp_pe() .eq. mpp_root_pe()) print *, "Testing the send_data calls with openmp blocks"
140      call define_blocks ('testing_model', my_block, isc, iec, jsc, jec, kpts=0, &
141                          nx_block=1, ny_block=4, message=message)
142   end select
144   select case (mask_case)
145   case (logical_mask)
146     clmask = allocate_logical_mask(isc, iec, jsc, jec, nz, nw)
147     if (mpp_pe() .eq. 0) clmask(isc, jsc, 1, :) = .False.
149     if (test_case .eq. test_halos) then
150       dlmask = allocate_logical_mask(isd, ied, jsd, jed, nz, nw)
151       if (mpp_pe() .eq. 0) dlmask(1+nhalox, 1+nhaloy, 1, :) = .False.
152     endif
153   case (real_mask)
154     crmask = allocate_real_mask(isc, iec, jsc, jec, nz, nw)
155     if (mpp_pe() .eq. 0) crmask(isc, jsc, 1, :) = 0_r8_kind
157     if (test_case .eq. test_halos) then
158       drmask = allocate_real_mask(isd, ied, jsd, jed, nz, nw)
159       if (mpp_pe() .eq. 0) drmask(1+nhalox, 1+nhaloy, 1, :) = 0_r8_kind
160     endif
161   end select
163   !< Register the axis
164   id_x  = diag_axis_init('x',  real((/ (i, i = 1,nx) /), kind=r8_kind),  'point_E', 'x', long_name='point_E', &
165     Domain2=Domain)
166   id_xc  = diag_axis_init('xc',  real((/ (i, i = 1,nx+1) /), kind=r8_kind),  'point_E corner', 'x', &
167     long_name='point_E', Domain2=Domain, domain_position=EAST)
168   id_y  = diag_axis_init('y',  real((/ (i, i = 1,ny) /), kind=r8_kind),  'point_N', 'y', long_name='point_N', &
169     Domain2=Domain)
170   id_yc  = diag_axis_init('yc',  real((/ (i, i = 1,ny) /), kind=r8_kind),  'point_N corner', 'y', &
171     long_name='point_N', Domain2=Domain, domain_position=NORTH)
172   id_z  = diag_axis_init('z',  real((/ (i, i = 1,nz) /), kind=r8_kind),  'point_Z', 'z', long_name='point_Z')
173   id_w  = diag_axis_init('w',  real((/ (i, i = 1,nw) /), kind=r8_kind),  'point_W', 'n', long_name='point_W')
175   missing_value = -666._r8_kind
176   !< Register the fields
177   id_var0 = register_diag_field  ('ocn_mod', 'var0', Time, 'Var0d', &
178     'mullions', missing_value = missing_value)
179   id_var1 = register_diag_field  ('ocn_mod', 'var1', (/id_x/), Time, 'Var1d', &
180     'mullions', missing_value = missing_value)
181   id_var2 = register_diag_field  ('ocn_mod', 'var2', (/id_x, id_y/), Time, 'Var2d', &
182     'mullions', missing_value = missing_value)
183   id_var2missing = register_diag_field  ('ocn_mod', 'var2missing', (/id_x, id_y/), Time, 'Var2d', &
184     'mullions', missing_value = missing_value)
185   id_var2c = register_diag_field  ('ocn_mod', 'var2c', (/id_xc, id_yc/), Time, 'Var2d corner', &
186     'mullions', missing_value = missing_value)
187   id_var3 = register_diag_field  ('ocn_mod', 'var3', (/id_x, id_y, id_z/), Time, 'Var3d', &
188     'mullions', missing_value = missing_value)
189   id_var4 = register_diag_field  ('ocn_mod', 'var4', (/id_x, id_y, id_z, id_w/), Time, 'Var4d', &
190     'mullions', missing_value = missing_value)
191   id_var999 = register_diag_field  ('ocn_mod', 'IOnASphere', Time, missing_value=missing_value)
193   !< Get the data domain indices (1 based)
194   isd1 = isc-isd+1
195   jsd1 = jsc-jsd+1
196   ied1 = isd1 + iec-isc
197   jed1 = jsd1 + jec-jsc
199   call diag_manager_set_time_end(set_date(2,1,3,0,0,0))
200   do i = 1, ntimes
201     Time = Time + Time_step
203     call set_buffer(cdata, i)
204     call set_buffer(cdata_corner, i)
206     ! This is passing in the data with missing values, but the variable is not masked.
207     ! An error is expected in this case.
208     used = send_data(id_var2missing, cdata(:,:,1,1)*0_r8_kind + missing_value, Time)
210     used = send_data(id_var2c, cdata_corner(:,:,1,1), Time)
211     used = send_data(id_var0, cdata(1,1,1,1), Time)
213     select case(test_case)
214     case (test_normal)
215       select case (mask_case)
216       case (no_mask)
217         used = send_data(id_var1, cdata(:,1,1,1), Time)
218         used = send_data(id_var2, cdata(:,:,1,1), Time)
219         used = send_data(id_var3, cdata(:,:,:,1), Time)
220         used = send_data(id_var4, cdata(:,:,:,:), Time)
221       case (real_mask)
222         used = send_data(id_var1, cdata(:,1,1,1), Time, rmask=crmask(:,1,1,1))
223         used = send_data(id_var2, cdata(:,:,1,1), Time, rmask=crmask(:,:,1,1))
224         used = send_data(id_var3, cdata(:,:,:,1), Time, rmask=crmask(:,:,:,1))
225         used = send_data(id_var4, cdata(:,:,:,:), Time, rmask=crmask(:,:,:,:))
226       case (logical_mask)
227         used = send_data(id_var1, cdata(:,1,1,1), Time, mask=clmask(:,1,1,1))
228         used = send_data(id_var2, cdata(:,:,1,1), Time, mask=clmask(:,:,1,1))
229         used = send_data(id_var3, cdata(:,:,:,1), Time, mask=clmask(:,:,:,1))
230         used = send_data(id_var4, cdata(:,:,:,:), Time, mask=clmask(:,:,:,:))
231       end select
232     case (test_halos)
233       call set_buffer(ddata, i)
234       select case (mask_case)
235       case (no_mask)
236         used = send_data(id_var1, cdata(:,1,1,1), Time)
237         used = send_data(id_var2, ddata(:,:,1,1), Time, &
238           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1)
239         used = send_data(id_var3, ddata(:,:,:,1), Time, &
240           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1)
241         used = send_data(id_var4, ddata(:,:,:,:), Time, &
242           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1)
243       case (real_mask)
244         used = send_data(id_var1, cdata(:,1,1,1), Time, &
245           rmask=crmask(:,1,1,1))
246         used = send_data(id_var2, ddata(:,:,1,1), Time, &
247           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
248           rmask=drmask(:,:,1,1))
249         used = send_data(id_var3, ddata(:,:,:,1), Time, &
250           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
251           rmask=drmask(:,:,:,1))
252         used = send_data(id_var4, ddata(:,:,:,:), Time, &
253           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
254           rmask=drmask(:,:,:,:))
255       case (logical_mask)
256         used = send_data(id_var1, cdata(:,1,1,1), Time, &
257           mask=clmask(:,1,1,1))
258         used = send_data(id_var2, ddata(:,:,1,1), Time, &
259           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
260           mask=dlmask(:,:,1,1))
261         used = send_data(id_var3, ddata(:,:,:,1), Time, &
262           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
263           mask=dlmask(:,:,:,1))
264         used = send_data(id_var4, ddata(:,:,:,:), Time, &
265           is_in=isd1, ie_in=ied1, js_in=jsd1, je_in=jed1, &
266           mask=dlmask(:,:,:,:))
267       end select
268     case (test_openmp)
269       select case(mask_case)
270       case (no_mask)
271         used=send_data(id_var1, cdata(:, 1, 1, 1), time)
272       case (logical_mask)
273         used=send_data(id_var1, cdata(:, 1, 1, 1), time, &
274             mask=clmask(:, 1, 1, 1))
275       case (real_mask)
276         used=send_data(id_var1, cdata(:, 1, 1, 1), time, &
277             rmask=crmask(:, 1, 1, 1))
278       end select
279 !$OMP parallel do default(shared) private(iblock, isw, iew, jsw, jew, is1, ie1, js1, je1)
280       do iblock=1, 4
281         isw = my_block%ibs(iblock)
282         jsw = my_block%jbs(iblock)
283         iew = my_block%ibe(iblock)
284         jew = my_block%jbe(iblock)
286       !--- indices for 1-based arrays ---
287         is1 = isw-isc+1
288         ie1 = iew-isc+1
289         js1 = jsw-jsc+1
290         je1 = jew-jsc+1
292         select case (mask_case)
293         case (no_mask)
294           used=send_data(id_var2, cdata(is1:ie1, js1:je1, 1, 1), time, is_in=is1, js_in=js1)
295           used=send_data(id_var3, cdata(is1:ie1, js1:je1, :, 1), time, is_in=is1, js_in=js1)
296           used=send_data(id_var4, cdata(is1:ie1, js1:je1, :, :), time, is_in=is1, js_in=js1)
297         case (real_mask)
298           used=send_data(id_var2, cdata(is1:ie1, js1:je1, 1, 1), time, is_in=is1, js_in=js1, &
299             rmask=crmask(is1:ie1, js1:je1, 1, 1))
300           used=send_data(id_var3, cdata(is1:ie1, js1:je1, :, 1), time, is_in=is1, js_in=js1, &
301             rmask=crmask(is1:ie1, js1:je1, :, 1))
302           used=send_data(id_var4, cdata(is1:ie1, js1:je1, :, :), time, is_in=is1, js_in=js1, &
303             rmask=crmask(is1:ie1, js1:je1, :, :))
304         case (logical_mask)
305           used=send_data(id_var2, cdata(is1:ie1, js1:je1, 1, 1), time, is_in=is1, js_in=js1, &
306             mask=clmask(is1:ie1, js1:je1, 1, 1))
307           used=send_data(id_var3, cdata(is1:ie1, js1:je1, :, 1), time, is_in=is1, js_in=js1, &
308             mask=clmask(is1:ie1, js1:je1, :, 1))
309           used=send_data(id_var4, cdata(is1:ie1, js1:je1, :, :), time, is_in=is1, js_in=js1, &
310             mask=clmask(is1:ie1, js1:je1, :, :))
311         end select
312       enddo
313     end select
314     call diag_send_complete(Time_step)
315     call diag_send_complete(Time_step)
316   enddo
318   call diag_manager_end(Time)
320   call fms_end
322   contains
324   !> @brief Allocate the logical mask based on the starting/ending indices
325   !! @return logical mask initiliazed to .True.
326   function allocate_logical_mask(is, ie, js, je, k, l) &
327   result(buffer)
328     integer, intent(in) :: is !< Starting x index
329     integer, intent(in) :: ie !< Ending x index
330     integer, intent(in) :: js !< Starting y index
331     integer, intent(in) :: je !< Ending y index
332     integer, intent(in) :: k  !< Number of points in the 4th dimension
333     integer, intent(in) :: l  !< Number of points in the 5th dimension
335     logical, allocatable :: buffer(:,:,:,:)
337     allocate(buffer(is:ie, js:je, 1:k, 1:l))
338     buffer = .True.
339   end function allocate_logical_mask
341   !> @brief Allocate the real mask based on the starting/ending indices
342   !! @returnreal mask initiliazed to 1_r8_kind
343   function allocate_real_mask(is, ie, js, je, k, l) &
344   result(buffer)
345     integer, intent(in) :: is !< Starting x index
346     integer, intent(in) :: ie !< Ending x index
347     integer, intent(in) :: js !< Starting y index
348     integer, intent(in) :: je !< Ending y index
349     integer, intent(in) :: k  !< Number of points in the 4th dimension
350     integer, intent(in) :: l  !< Number of points in the 5th dimension
351     real(kind=r8_kind), allocatable :: buffer(:,:,:,:)
353     allocate(buffer(is:ie, js:je, 1:k, 1:l))
354     buffer = 1.0_r8_kind
355   end function allocate_real_mask
357   !> @brief initiliazed the buffer based on the starting/ending indices
358   subroutine init_buffer(buffer, is, ie, js, je, nhalo)
359     real(kind=r8_kind), intent(inout) :: buffer(:,:,:,:) !< output buffer
360     integer,            intent(in)    :: is              !< Starting x index
361     integer,            intent(in)    :: ie              !< Ending x index
362     integer,            intent(in)    :: js              !< Starting y index
363     integer,            intent(in)    :: je              !< Ending y index
364     integer,            intent(in)    :: nhalo           !< Number of halos
366     integer :: ii, j, k, l
368     do ii = is, ie
369       do j = js, je
370         do k = 1, size(buffer, 3)
371           do l = 1, size(buffer,4)
372             if(.not. use_pow_data) then
373               buffer(ii-is+1+nhalo, j-js+1+nhalo, k, l) = real(ii, kind=r8_kind)* 1000_r8_kind + &
374                 real(j, kind=r8_kind)* 10_r8_kind + &
375                 real(k, kind=r8_kind)
376             else
377               ! just sends the sum of indices for pow
378               buffer(ii-is+1+nhalo, j-js+1+nhalo, k, l) = ii + j + k + l
379             endif
380           enddo
381         enddo
382       enddo
383     enddo
385   end subroutine init_buffer
387   !> @brief Set the buffer based on the time_index
388   subroutine set_buffer(buffer, time_index)
389     real(kind=r8_kind), intent(inout) :: buffer(:,:,:,:) !< Output buffer
390     integer,            intent(in)    :: time_index      !< Time index
392     if(use_pow_data) return
393     buffer = nint(buffer) + real(time_index, kind=r8_kind)/100_r8_kind
395   end subroutine set_buffer
397 end program test_reduction_methods