YAC 3.13.0
Yet Another Coupler
Loading...
Searching...
No Matches
test_group_comm.c
Go to the documentation of this file.
1// Copyright (c) 2024 The YAC Authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#include <stdlib.h>
6#include <mpi.h>
7#include <math.h>
8
9#include "tests.h"
10#include "yac_mpi_internal.h"
11
17int main (void) {
18
19 MPI_Init(NULL, NULL);
20
21 int comm_rank, comm_size;
22 MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
23 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
24
25 if (comm_size != 9) {
26
27 PUT_ERR("ERROR: wrong number of processes\n");
28 return TEST_EXIT_CODE;
29 }
30
31 struct yac_group_comm world_group_comm = yac_group_comm_new(MPI_COMM_WORLD);
32
33 if (yac_group_comm_get_global_rank(world_group_comm) != comm_rank)
34 PUT_ERR("ERROR in yac_group_comm_get_global_rank");
35 if (yac_group_comm_get_global_size(world_group_comm) != comm_size)
36 PUT_ERR("ERROR in yac_group_comm_get_global_size");
37 if (yac_group_comm_get_rank(world_group_comm) != comm_rank)
38 PUT_ERR("ERROR in yac_group_comm_get_rank");
39 if (yac_group_comm_get_size(world_group_comm) != comm_size)
40 PUT_ERR("ERROR in yac_group_comm_get_size");
41
42 struct yac_group_comm local_group_comm, remote_group_comm;
43 int split_rank = 3;
45 world_group_comm, split_rank, &local_group_comm, &remote_group_comm);
46
47 int group_idx = comm_rank >= split_rank;
48
49 int ref_group_size[2] = {3, 6};
50 int group_rank_offset[2] = {0, 3};
51
52 if (yac_group_comm_get_global_rank(local_group_comm) != comm_rank)
53 PUT_ERR("ERROR in yac_group_comm_get_global_rank");
54 if (yac_group_comm_get_global_size(local_group_comm) != comm_size)
55 PUT_ERR("ERROR in yac_group_comm_get_global_size");
56 if (yac_group_comm_get_global_rank(remote_group_comm) != comm_rank)
57 PUT_ERR("ERROR in yac_group_comm_get_global_rank");
58 if (yac_group_comm_get_global_size(remote_group_comm) != comm_size)
59 PUT_ERR("ERROR in yac_group_comm_get_global_size");
60
61 if (yac_group_comm_get_rank(local_group_comm) !=
62 (comm_rank - group_rank_offset[group_idx]))
63 PUT_ERR("ERROR in yac_group_comm_get_rank");
64 if (yac_group_comm_get_size(local_group_comm) != ref_group_size[group_idx])
65 PUT_ERR("ERROR in yac_group_comm_get_size");
66 if (yac_group_comm_get_rank(remote_group_comm) !=
67 (comm_rank - group_rank_offset[group_idx^1]))
68 PUT_ERR("ERROR in yac_group_comm_get_rank");
69 if (yac_group_comm_get_size(remote_group_comm) != ref_group_size[group_idx^1])
70 PUT_ERR("ERROR in yac_group_comm_get_size");
71
72 double dble_proc_data[9][3] = {{1,2,3},
73 {4,5,6},
74 {7,8,9},
75 {10,11,12},
76 {13,14,15},
77 {16,17,18},
78 {19,20,21},
79 {22,23,24},
80 {25,26,27}};
81 double ref_dble_sum[2][3] = {{1+4+7,2+5+8,3+6+9},
82 {10+13+16+19+22+25,
83 11+14+17+20+23+26,
84 12+15+18+21+24+27}};
85
86 double local_dble_proc_data[3];
87 for (int i = 0; i < 3; ++i)
88 local_dble_proc_data[i] = dble_proc_data[comm_rank][i];
89 yac_allreduce_sum_dble(local_dble_proc_data, 3, local_group_comm);
90 for (int i = 0; i < 3; ++i)
91 if (fabs(local_dble_proc_data[i] - ref_dble_sum[group_idx][i]) > 1e-9)
92 PUT_ERR("ERROR in yac_allreduce_sum_dble");
93
94 size_t size_t_proc_data[9][2] = {{1,2},
95 {3,4},
96 {5,6},
97 {7,8},
98 {9,10},
99 {11,12},
100 {13,14},
101 {15,16},
102 {17,18}};
103 size_t size_t_recv_buffer[12];
104 size_t ref_size_t_allgather[2][12] = {{1,2,3,4,5,6},
105 {7,8,9,10,11,12,13,14,15,16,17,18}};
106 size_t ref_size_t_allgather_size[2] = {2*3, 2*6};
107
109 size_t_proc_data[comm_rank], size_t_recv_buffer, 2, local_group_comm);
110 for (size_t i = 0; i < ref_size_t_allgather_size[group_idx]; ++i)
111 if (size_t_recv_buffer[i] != ref_size_t_allgather[group_idx][i])
112 PUT_ERR("ERROR in yac_allgather_size_t");
113
114 for (int i = 0; i < comm_size; ++i) {
115
116 // if current rank is root
117 if (i == comm_rank) {
118
119 // bcast to remote group
121 dble_proc_data[comm_rank], 3, MPI_DOUBLE, i, remote_group_comm);
122
123 // bcast to local group
125 dble_proc_data[comm_rank], 3, MPI_DOUBLE, i, local_group_comm);
126
127 } else {
128
129 // bcast within local group
130 double bcast_data[3];
131 yac_bcast_group(bcast_data, 3, MPI_DOUBLE, i, local_group_comm);
132 for (int j = 0; j < 3; ++j)
133 if (fabs(bcast_data[j] - dble_proc_data[i][j]) > 1e-9)
134 PUT_ERR("ERROR in yac_bcast_group");
135 }
136 }
137
138 yac_group_comm_delete(world_group_comm);
139
140 MPI_Finalize();
141
142 return TEST_EXIT_CODE;
143}
#define TEST_EXIT_CODE
Definition tests.h:14
#define PUT_ERR(string)
Definition tests.h:10
int yac_group_comm_get_global_rank(struct yac_group_comm group_comm)
Definition yac_mpi.c:500
int yac_group_comm_get_rank(struct yac_group_comm group_comm)
Definition yac_mpi.c:492
void yac_group_comm_split(struct yac_group_comm group_comm, int split_rank, struct yac_group_comm *local_group_comm, struct yac_group_comm *remote_group_comm)
Definition yac_mpi.c:512
int yac_group_comm_get_global_size(struct yac_group_comm group_comm)
Definition yac_mpi.c:506
struct yac_group_comm yac_group_comm_new(MPI_Comm comm)
Definition yac_mpi.c:477
void yac_allreduce_sum_dble(double *buffer, int count, struct yac_group_comm group_comm)
Definition yac_mpi.c:296
int yac_group_comm_get_size(struct yac_group_comm group_comm)
Definition yac_mpi.c:496
void yac_bcast_group(void *buffer, int count, MPI_Datatype datatype, int root, struct yac_group_comm group_comm)
Definition yac_mpi.c:412
void yac_allgather_size_t(const size_t *sendbuf, size_t *recvbuf, int count, struct yac_group_comm group_comm)
Definition yac_mpi.c:368
void yac_group_comm_delete(struct yac_group_comm group_comm)
Definition yac_mpi.c:487