YAC 3.12.0
Yet Another Coupler
Loading...
Searching...
No Matches
dist_merge.c
Go to the documentation of this file.
1// Copyright (c) 2024 The YAC Authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#include <string.h>
6
7#include "dist_merge.h"
8#include "utils_common.h"
9
10static size_t get_pack_size(
11 size_t count, unsigned char * array, size_t element_size, MPI_Comm comm,
12 size_t(*element_get_pack_size)(void * element, MPI_Comm comm)) {
13
14 if (count == 0) return 0;
15
16 int count_pack_size;
17 yac_mpi_call(MPI_Pack_size(1, MPI_INT, comm, &count_pack_size), comm);
18
19 size_t array_pack_size = 0;
20 for (size_t i = 0; i < count; ++i)
21 array_pack_size +=
22 element_get_pack_size((void*)(array + i * element_size), comm);
23
24 return (size_t)count_pack_size + array_pack_size;
25}
26
27static void pack(
28 size_t count, unsigned char * array, size_t element_size, void * buffer,
29 int buffer_size, int * position, MPI_Comm comm,
30 void(*element_pack)(
31 void * element, void * buffer, int buffer_size,
32 int * position, MPI_Comm)) {
33
34 int count_int = (int)count;
36 MPI_Pack(
37 &count_int, 1, MPI_INT, buffer, buffer_size, position, comm), comm);
38
39 for (size_t i = 0; i < count; ++i)
40 element_pack(
41 (void*)(array + i * element_size), buffer, buffer_size, position, comm);
42}
43
44static void unpack(
45 void * buffer, int buffer_size, int * position, size_t * count,
46 unsigned char ** array, size_t element_size, MPI_Comm comm,
47 void(*element_unpack)(
48 void * buffer, int buffer_size, int * position,
49 void * element, MPI_Comm comm)) {
50
51 int count_int;
53 MPI_Unpack(
54 buffer, buffer_size, position, &count_int, 1, MPI_INT, comm), comm);
55
56 *count = (size_t)count_int;
57 *array = xmalloc(*count * element_size);
58
59 for (size_t i = 0; i < *count; ++i)
60 element_unpack(
61 buffer, buffer_size, position, (void*)(*array + i * element_size), comm);
62}
63
65 size_t * count, void ** array, size_t element_size, MPI_Comm comm,
66 struct yac_dist_merge_vtable * vtable, size_t ** idx_old_to_new) {
67
68 int rank;
69 MPI_Comm_rank(comm, &rank);
70
71 // initialise
72 unsigned char * input = *array;
73 size_t input_len = *count;
74 size_t* idx = NULL;
75 if (idx_old_to_new) {
76 *idx_old_to_new = xmalloc(input_len * sizeof(**idx_old_to_new));
77 idx = xmalloc(input_len * sizeof(*idx));
78 for(size_t i = 0;i<input_len;++i) idx[i] = i;
79 }
80 unsigned char * arr_new = NULL;
81 size_t len_new = 0;
82
84 input_len < INT_MAX, "ERROR(yac_dist_merge): too many elements");
85
86 // sort input data so that the processing order on all processes is identical
87 if (idx)
88 yac_qsort_index(input, input_len, element_size, vtable->compare, idx);
89 else
90 qsort((void*)input, input_len, element_size, vtable->compare);
91
92 // loop until no process has any unsychronised elements left
93 void * buffer = NULL;
94 while(1){
95
96 // determine pack size of all remaining local data
97 size_t pack_size =
99 input_len, input, element_size, comm, vtable->get_pack_size);
101 pack_size <= LONG_MAX, "ERROR(yac_dist_merge): packing size too big");
102
103 // determine rank with most amount of data
104 struct {
105 long pack_size;
106 int rank;
107 } data_pair = {.pack_size = (long)pack_size, .rank = rank};
109 MPI_Allreduce(
110 MPI_IN_PLACE, &data_pair, 1, MPI_LONG_INT, MPI_MAXLOC, comm), comm);
111
112 // if there is no more data to exchange, the sychronisation is finished
113 if (data_pair.pack_size == 0) break;
114
115 // allocate the buffer according to the processes with the most amount
116 // of unsychronised data
117 pack_size = (size_t)data_pair.pack_size;
118 if (!buffer) buffer = xmalloc(pack_size);
119 int position = 0;
120
121 // the process with the most amount of data packs it
122 if(data_pair.rank == rank)
123 pack(
124 input_len, input, element_size, buffer, pack_size, &position, comm,
125 vtable->pack);
126
127 // broadcast and unpack data (this only contains the basic information
128 // required to identify an element)
130 MPI_Bcast(buffer, pack_size, MPI_PACKED, data_pair.rank, comm), comm);
131 unsigned char * recved = NULL;
132 size_t num_recved;
133 position = 0;
134 unpack(
135 buffer, pack_size, &position, &num_recved, &recved, element_size, comm,
136 vtable->unpack);
137
138 // copy the received elements into the result array
139 arr_new = xrealloc(arr_new, (len_new + num_recved)*element_size);
140 memcpy(
141 arr_new + len_new*element_size, (void*)recved, num_recved*element_size);
142 free(recved);
143
144 // merge received elements into the result array
145 size_t input_idx, input_len_new, i = len_new;
146 len_new += num_recved;
147 for(input_idx = 0, input_len_new = 0; i < len_new; ++i) {
148 void* recved_element = arr_new + i*element_size;
149
150 // search for matching element in input list until an element is
151 // found, which is "bigger" (as defined by the compare function) or the
152 // end of the input list was reached
153 int cmp = 0;
154 void * input_element = NULL;
155 while ((input_idx < input_len) &&
156 (((cmp =
157 vtable->compare(
158 ((input_element = input + input_idx * element_size)),
159 recved_element))) < 0)) {
160
161 // elements from the input list that matched with a received element
162 // are removed from the input list
163 if (input_idx != input_len_new) {
164 memcpy(
165 input + input_len_new * element_size,
166 input_element, element_size);
167 if (idx) idx[input_len_new] = idx[input_idx];
168 }
169 ++input_len_new;
170 ++input_idx;
171 }
172
173 // if the end of the local input list was reached, no more merging is required
174 if (input_idx == input_len) break;
175
176 // if a matching element was found in the input list
177 if (!cmp) {
178
179 // merge input list element with received element
180 if (vtable->merge)
181 vtable->merge(recved_element, input_element, comm);
182 // free the element from the input list
183 vtable->free_data(input_element);
184 // keep track on where the original element is in the result list
185 if (idx_old_to_new) (*idx_old_to_new)[idx[input_idx]] = i;
186 // upate input list idx
187 input_idx++;
188
189 // if no matching element was found in the input list
190 } else if (vtable->merge) {
191 // since the merge operation can potentially be collective, we have
192 // to call it, even if no matching element was found
193 vtable->merge(recved_element, NULL, comm);
194 }
195 }
196
197 // if the end of the input list was reched:
198 // process remaining received elements
199 if (vtable->merge)
200 for(; i < len_new; ++i)
201 vtable->merge(arr_new + i*element_size, NULL, comm);
202
203 // if all received elements where processed:
204 // compress remaining elements in input list
205 for (; input_idx < input_len; ++input_idx, ++input_len_new) {
206 if (input_idx != input_len_new) {
207 memcpy(
208 input + input_len_new * element_size,
209 input + input_idx * element_size, element_size);
210 if (idx) idx[input_len_new] = idx[input_idx];
211 }
212 }
213 // update length of input list
214 input_len = input_len_new;
215 }
216 free(buffer);
217 free(input);
218 free(idx);
219 *array = arr_new;
220 *count = len_new;
221}
#define YAC_ASSERT(exp, msg)
static void pack(size_t count, unsigned char *array, size_t element_size, void *buffer, int buffer_size, int *position, MPI_Comm comm, void(*element_pack)(void *element, void *buffer, int buffer_size, int *position, MPI_Comm))
Definition dist_merge.c:27
void yac_dist_merge(size_t *count, void **array, size_t element_size, MPI_Comm comm, struct yac_dist_merge_vtable *vtable, size_t **idx_old_to_new)
Definition dist_merge.c:64
static void unpack(void *buffer, int buffer_size, int *position, size_t *count, unsigned char **array, size_t element_size, MPI_Comm comm, void(*element_unpack)(void *buffer, int buffer_size, int *position, void *element, MPI_Comm comm))
Definition dist_merge.c:44
static size_t get_pack_size(size_t count, unsigned char *array, size_t element_size, MPI_Comm comm, size_t(*element_get_pack_size)(void *element, MPI_Comm comm))
Definition dist_merge.c:10
#define xrealloc(ptr, size)
Definition ppm_xfuncs.h:67
#define xmalloc(size)
Definition ppm_xfuncs.h:66
void yac_qsort_index(void *a_, size_t count, size_t size, int(*compare)(void const *, void const *), size_t *idx)
Definition quicksort.c:185
void(* unpack)(void *buffer, int buffer_size, int *position, void *element, MPI_Comm comm)
Definition dist_merge.h:32
int(* compare)(void const *a, void const *b)
Definition dist_merge.h:39
void(* free_data)(void *element)
Definition dist_merge.h:61
void(* merge)(void *to, void *from, MPI_Comm comm)
Definition dist_merge.h:56
size_t(* get_pack_size)(void *element, MPI_Comm comm)
Definition dist_merge.h:20
void(* pack)(void *element, void *buffer, int buffer_size, int *position, MPI_Comm)
Definition dist_merge.h:25
double * buffer
#define yac_mpi_call(call, comm)