YetAnotherCoupler 3.2.0_a
Loading...
Searching...
No Matches
interp_method_callback.c
Go to the documentation of this file.
1// Copyright (c) 2024 The YAC Authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifdef HAVE_CONFIG_H
6// Get the definition of the 'restrict' keyword.
7#include "config.h"
8#endif
9
10#include <string.h>
11
13#include "utils_core.h"
14#include "yac_mpi_internal.h"
15#include "ensure_array_size.h"
17
23
24static size_t do_search_callback(struct interp_method * method,
25 struct yac_interp_grid * interp_grid,
26 size_t * tgt_points, size_t count,
27 struct yac_interp_weights * weights);
28static void delete_callback(struct interp_method * method);
29
30static struct interp_method_vtable
34
41
42static void get_orig_data(
43 struct remote_point const * remote_point,
44 int * owner_rank, size_t * orig_pos) {
45
46 struct remote_point_info const * point_infos = NULL;
47
48 if (remote_point->data.count == 1) {
49 point_infos = &(remote_point->data.data.single);
50 } else {
51 int min_rank = INT_MAX;
52 for (int i = 0; i < remote_point->data.count; ++i) {
53 int curr_rank = remote_point->data.data.multi[i].rank;
54 if (curr_rank < min_rank) {
55 point_infos = remote_point->data.data.multi + i;
56 min_rank = curr_rank;
57 }
58 }
59 }
60
61 YAC_ASSERT(point_infos != NULL, "ERROR(get_orig_data): internal error")
62
63 *owner_rank = point_infos->rank;
64 *orig_pos = point_infos->orig_pos;
65}
66
67static MPI_Datatype yac_get_request_data_mpi_datatype(MPI_Comm comm) {
68
69 struct tgt_request_data dummy;
70 MPI_Datatype tgt_request_data_dt;
71 int array_of_blocklengths[] = {1, 1, 3};
72 const MPI_Aint array_of_displacements[] =
73 {(MPI_Aint)(intptr_t)(const void *)&(dummy.src_orig_pos) -
74 (MPI_Aint)(intptr_t)(const void *)&dummy,
75 (MPI_Aint)(intptr_t)(const void *)&(dummy.src_global_id) -
76 (MPI_Aint)(intptr_t)(const void *)&dummy,
77 (MPI_Aint)(intptr_t)(const void *)&(dummy.tgt_coord[0]) -
78 (MPI_Aint)(intptr_t)(const void *)&dummy};
79 const MPI_Datatype array_of_types[] =
80 {MPI_UINT64_T, yac_int_dt, MPI_DOUBLE};
82 MPI_Type_create_struct(3, array_of_blocklengths, array_of_displacements,
83 array_of_types, &tgt_request_data_dt), comm);
84 return yac_create_resized(tgt_request_data_dt, sizeof(dummy), comm);
85}
86
// Partitions the search results in place: all target points that have a
// valid source cell are moved to the front of the arrays, all target points
// without one (src_cell == SIZE_MAX) end up behind them.
//
// src_cells   source cell per target point (SIZE_MAX marks "no cell found")
// count       number of entries in src_cells and tgt_points
// tgt_points  target points; permuted together with src_cells
//
// Returns the number of valid results (entries at the front of the arrays).
static size_t get_valid_results(
  size_t * src_cells, size_t count, size_t * tgt_points) {

  size_t front = 0;    // entries [0, front) are known to be valid
  size_t back = count; // entries [back, count) are known to be invalid

  while (front < back) {

    // valid entry at the front is already in place
    if (src_cells[front] != SIZE_MAX) {
      ++front;
      continue;
    }

    // front is invalid: look for a valid entry coming from the end
    --back;
    if (src_cells[back] == SIZE_MAX) continue;

    // exchange the invalid front entry with the valid back entry
    size_t invalid_src_cell = src_cells[front];
    size_t invalid_tgt_point = tgt_points[front];
    src_cells[front] = src_cells[back];
    tgt_points[front] = tgt_points[back];
    src_cells[back] = invalid_src_cell;
    tgt_points[back] = invalid_tgt_point;
    ++front;
  }

  return front;
}
121
122static size_t do_search_callback (struct interp_method * method,
123 struct yac_interp_grid * interp_grid,
124 size_t * tgt_points, size_t count,
125 struct yac_interp_weights * weights) {
126
127 struct interp_method_callback * method_callback =
128 (struct interp_method_callback *)method;
129
130 MPI_Comm comm = yac_interp_grid_get_MPI_Comm(interp_grid);
131 int comm_rank, comm_size;
132 yac_mpi_call(MPI_Comm_rank(comm, &comm_rank), comm);
133 yac_mpi_call(MPI_Comm_size(comm, &comm_size), comm);
134
135 // get coordinates of target points
136 yac_coordinate_pointer tgt_coords = xmalloc(count * sizeof(*tgt_coords));
138 interp_grid, tgt_points, count, tgt_coords);
139
140 // get matching source cells for all target points
141 size_t * src_cells = xmalloc(count * sizeof(*src_cells));
143 interp_grid, tgt_coords, count, src_cells);
144 free(tgt_coords);
145
146 // move the target points, which have a valid source cells
147 // (invalid src_cell == SIZE_MAX) to the front of the array
148 size_t temp_result_count = get_valid_results(src_cells, count, tgt_points);
149
150 // get all unique source result cells
151 yac_quicksort_index_size_t_size_t(src_cells, temp_result_count, tgt_points);
152 size_t num_unique_src_cells = 0;
153 size_t * src_to_unique_src =
154 xmalloc(temp_result_count * sizeof(*src_to_unique_src));
155 for (size_t i = 0, prev_src_cell = SIZE_MAX; i < temp_result_count; ++i) {
156 size_t curr_src_cell = src_cells[i];
157 if (curr_src_cell != prev_src_cell) {
158 src_cells[num_unique_src_cells++] = curr_src_cell;
159 prev_src_cell = curr_src_cell;
160 }
161 src_to_unique_src[i] = num_unique_src_cells - 1;
162 }
163
164 // get remote point information for all unique source result cell
165 struct remote_point * src_remote_points =
167 interp_grid, YAC_LOC_CELL, src_cells, num_unique_src_cells);
168 free(src_cells);
169
170 // get original owners of all unique source result cells
171 int * orig_src_cell_ranks =
172 xmalloc(num_unique_src_cells * sizeof(*orig_src_cell_ranks));
173 size_t * orig_src_cell_pos =
174 xmalloc(num_unique_src_cells * sizeof(*orig_src_cell_pos));
175 for (size_t i = 0; i < num_unique_src_cells; ++i)
177 src_remote_points + i, orig_src_cell_ranks + i, orig_src_cell_pos + i);
178
179 // set up communication buffers
180 size_t * sendcounts, * recvcounts, * sdispls, * rdispls;
182 1, &sendcounts, &recvcounts, &sdispls, &rdispls, comm);
183 // count number of target requests per rank
184 for (size_t i = 0; i < temp_result_count; ++i)
185 sendcounts[orig_src_cell_ranks[src_to_unique_src[i]]]++;
187 1, sendcounts, recvcounts, sdispls, rdispls, comm);
188 size_t request_count = recvcounts[comm_size - 1] + rdispls[comm_size - 1];
189
190 // pack target request data
191 yac_const_coordinate_pointer tgt_field_coords =
193 struct tgt_request_data * request_buffer =
194 xmalloc((temp_result_count + request_count) * sizeof(*request_buffer));
195 struct tgt_request_data * request_send_buffer = request_buffer;
196 struct tgt_request_data * request_recv_buffer =
197 request_buffer + temp_result_count;
198 size_t * new_tgt_points =
199 xmalloc(temp_result_count * sizeof(*new_tgt_points));
200 for (size_t i = 0; i < temp_result_count; ++i) {
201 size_t unique_src_cell_idx = src_to_unique_src[i];
202 size_t pos = sdispls[orig_src_cell_ranks[unique_src_cell_idx] + 1]++;
203 request_send_buffer[pos].src_orig_pos =
204 orig_src_cell_pos[unique_src_cell_idx];
205 request_send_buffer[pos].src_global_id =
206 src_remote_points[unique_src_cell_idx].global_id;
207 memcpy(request_send_buffer[pos].tgt_coord, tgt_field_coords[tgt_points[i]],
208 3 * sizeof(double));
209 new_tgt_points[pos] = tgt_points[i];
210 }
211 free(src_remote_points);
212 free(orig_src_cell_pos);
213 free(src_to_unique_src);
214 free(orig_src_cell_ranks);
215
216 // bring target points into the order in which we will receive the results
217 memcpy(tgt_points, new_tgt_points, temp_result_count * sizeof(*tgt_points));
218 free(new_tgt_points);
219
220 // transfer tgt coords, orig src pos and src id
221 MPI_Datatype request_data_dt = yac_get_request_data_mpi_datatype(comm);
223 request_send_buffer, sendcounts, sdispls,
224 request_recv_buffer, recvcounts, rdispls,
225 sizeof(*request_recv_buffer), request_data_dt, comm);
226 yac_mpi_call(MPI_Type_free(&request_data_dt), comm);
227
229 method_callback->compute_weights_callback;
230 void * user_data = method_callback->user_data;
231 size_t num_src_fields = yac_interp_grid_get_num_src_fields(interp_grid);
232 uint64_t * uint64_t_buffer =
233 xmalloc(num_src_fields * (request_count + temp_result_count) *
234 sizeof(*uint64_t_buffer));
235 uint64_t * temp_num_results_per_src_field_per_tgt = uint64_t_buffer;
236 uint64_t * num_results_per_src_field_per_tgt_uint64 =
237 uint64_t_buffer + num_src_fields * request_count;
238 yac_int * temp_global_ids = NULL;
239 size_t temp_global_ids_array_size = 0;
240 double * temp_w = NULL;
241 size_t temp_w_array_size = 0;
242 size_t temp_weights_count = 0;
243
244 // the weight computation function should be available on
245 // all source
247 (request_count == 0) || (compute_weights != NULL),
248 "ERROR(do_search_callback): "
249 "no callback routine defined on source process")
250
251 // compute weights and store results
252 for (size_t i = 0, k = 0; i < request_count; ++i) {
253
254 // get weights for the current target point from the user
255 int const * curr_global_result_points[num_src_fields];
256 double * curr_result_weights[num_src_fields];
257 size_t curr_result_counts[num_src_fields];
259 (double const *)(request_recv_buffer[i].tgt_coord),
260 (int)(request_recv_buffer[i].src_global_id),
261 (size_t)(request_recv_buffer[i].src_orig_pos),
262 curr_global_result_points, curr_result_weights, curr_result_counts,
263 user_data);
264
265 // copy current results
266 for (size_t j = 0; j < num_src_fields; ++j, ++k) {
267 size_t curr_count = curr_result_counts[j];
268 temp_num_results_per_src_field_per_tgt[k] = (uint64_t)curr_count;
270 temp_global_ids, temp_global_ids_array_size,
271 temp_weights_count + curr_count);
273 temp_w, temp_w_array_size,
274 temp_weights_count + curr_count);
275 for (size_t l = 0; l < curr_count; ++l)
276 temp_global_ids[temp_weights_count + l] =
277 (yac_int)(curr_global_result_points[j][l]);
278 memcpy(temp_w + temp_weights_count, curr_result_weights[j],
279 curr_count * sizeof(*curr_result_weights));
280 temp_weights_count += curr_count;
281 }
282 }
283 free(request_buffer);
284
285 // return number of results per source field per target point
286 for (int i = 0; i < comm_size; ++i) {
287 sendcounts[i] *= num_src_fields;
288 recvcounts[i] *= num_src_fields;
289 sdispls[i] *= num_src_fields;
290 rdispls[i] *= num_src_fields;
291 }
292 yac_alltoallv_uint64_p2p(
293 temp_num_results_per_src_field_per_tgt, recvcounts, rdispls,
294 num_results_per_src_field_per_tgt_uint64, sendcounts, sdispls, comm);
295
296 // set up comm buffers for exchanging of the interpolation results and
297 // count the total number of weights per source field
298 size_t num_weights = 0;
299 size_t * total_num_results_per_src_field =
300 xcalloc(num_src_fields, sizeof(*total_num_results_per_src_field));
301 {
302 uint64_t * curr_num_send_results = temp_num_results_per_src_field_per_tgt;
303 uint64_t * curr_num_recv_results = num_results_per_src_field_per_tgt_uint64;
304 size_t saccu = 0, raccu = 0;
305 for (int i = 0; i < comm_size; ++i) {
306 size_t num_send_results = recvcounts[i] / num_src_fields;
307 size_t num_recv_results = sendcounts[i] / num_src_fields;
308 size_t num_send_weights = 0;
309 size_t num_recv_weights = 0;
310 for (size_t j = 0; j < num_send_results; ++j)
311 for (size_t k = 0; k < num_src_fields; ++k, ++curr_num_send_results)
312 num_send_weights += (size_t)*curr_num_send_results;
313 for (size_t j = 0; j < num_recv_results; ++j) {
314 for (size_t k = 0; k < num_src_fields; ++k, ++curr_num_recv_results) {
315 size_t curr_count = (size_t)*curr_num_recv_results;
316 num_recv_weights += curr_count;
317 total_num_results_per_src_field[k] += curr_count;
318 }
319 }
320 sdispls[i] = saccu;
321 rdispls[i] = raccu;
322 sendcounts[i] = num_send_weights;
323 recvcounts[i] = num_recv_weights;
324 saccu += num_send_weights;
325 raccu += num_recv_weights;
326 num_weights += num_recv_weights;
327 }
328 }
329
330 double * w = xmalloc(num_weights * sizeof(*w));
331 yac_int * global_ids = xmalloc(num_weights * sizeof(*global_ids));
332
333 // return interpolation results
334 yac_alltoallv_dble_p2p(
335 temp_w, sendcounts, sdispls, w, recvcounts, rdispls, comm);
336 yac_alltoallv_yac_int_p2p(
337 temp_global_ids, sendcounts, sdispls,
338 global_ids, recvcounts, rdispls, comm);
339 yac_free_comm_buffers(sendcounts, recvcounts, sdispls, rdispls);
340 free(temp_w);
341 free(temp_global_ids);
342
343 // check which target points have results
344 size_t * interpolated_flag =
345 xmalloc(temp_result_count * sizeof(*interpolated_flag));
346 size_t result_count = 0;
347 for (size_t i = 0, k = 0; i < temp_result_count; ++i) {
348 int flag = 0;
349 for (size_t j = 0; j < num_src_fields; ++j, ++k)
350 flag |= (num_results_per_src_field_per_tgt_uint64[k] > 0);
351 if (flag) {
352 if (result_count != i)
353 memmove(
354 num_results_per_src_field_per_tgt_uint64 + result_count * num_src_fields,
355 num_results_per_src_field_per_tgt_uint64 + i * num_src_fields,
356 num_src_fields * sizeof(*num_results_per_src_field_per_tgt_uint64));
357 interpolated_flag[i] = result_count++;
358 } else {
359 interpolated_flag[i] = SIZE_MAX;
360 }
361 }
362
363 // sort the target points that can be interpolated to the beginning
364 // of the array
366 interpolated_flag, temp_result_count, tgt_points);
367 free(interpolated_flag);
368
369 // sort global result ids into a per tgt per src_field order
370 size_t * global_id_reorder_idx =
371 xmalloc((num_weights + num_src_fields) * sizeof(*global_id_reorder_idx));
372 size_t * src_field_displ = global_id_reorder_idx + num_weights;
373 size_t max_num_results_per_src_field = 0;
374 size_t * num_results_per_src_field_per_tgt =
375 xmalloc(result_count * num_src_fields *
376 sizeof(*num_results_per_src_field_per_tgt));
377 for (size_t i = 0, accu = 0; i < num_src_fields; ++i) {
378 src_field_displ[i] = accu;
379 accu += total_num_results_per_src_field[i];
380 if (max_num_results_per_src_field < total_num_results_per_src_field[i])
381 max_num_results_per_src_field = total_num_results_per_src_field[i];
382 }
383 for (size_t i = 0, k = 0, l = 0; i < result_count; ++i) {
384 for (size_t j = 0; j < num_src_fields; ++j, ++k) {
385 num_results_per_src_field_per_tgt[k] =
386 (size_t)(num_results_per_src_field_per_tgt_uint64[k]);
387 size_t curr_count =
388 (size_t)(num_results_per_src_field_per_tgt_uint64[k]);
389 for (size_t m = 0; m < curr_count; ++m, ++l)
390 global_id_reorder_idx[l] = src_field_displ[j]++;
391 }
392 }
394 global_id_reorder_idx, num_weights, global_ids);
395 free(uint64_t_buffer);
396 free(global_id_reorder_idx);
397
398 // get remote data for all required source points
399 struct remote_point * srcs_per_field[num_src_fields];
400 size_t * result_point_buffer =
401 xmalloc(max_num_results_per_src_field * sizeof(*result_point_buffer));
402 for (size_t i = 0, offset = 0; i < num_src_fields; ++i) {
403 size_t curr_count = total_num_results_per_src_field[i];
405 interp_grid, i, global_ids + offset, curr_count, result_point_buffer);
406 srcs_per_field[i] =
408 interp_grid, i, result_point_buffer, curr_count);
409 offset += curr_count;
410 }
411 free(result_point_buffer);
412 free(global_ids);
413 free(total_num_results_per_src_field);
414
415 struct remote_points tgts = {
416 .data =
418 interp_grid, tgt_points, result_count),
419 .count = result_count};
420
421 // generate weights
423 weights, &tgts, num_results_per_src_field_per_tgt, srcs_per_field, w,
424 num_src_fields);
425
426 free(tgts.data);
427 for (size_t i = 0; i < num_src_fields; ++i) free(srcs_per_field[i]);
428 free(num_results_per_src_field_per_tgt);
429 free(w);
430
431 return result_count;
432}
433
435 yac_func_compute_weights compute_weights_callback, void * user_data) {
436
437 struct interp_method_callback * method = xmalloc(1 * sizeof(*method));
438
441 method->user_data = user_data;
442
443 return (struct interp_method*)method;
444}
445
// Frees the memory of the given interpolation method object.
// (the object itself is the only allocation that is released here)
static void delete_callback(struct interp_method * method) {

  free(method);
}
449
450typedef void (*func_dummy)(void);
451static struct {
453 void * user_data;
454 char * key;
458
460 yac_func_compute_weights compute_weights_callback,
461 void * user_data, char const * key) {
462
463 for (size_t i = 0; i < callback_lookup_table_size; ++i)
465 !strcmp(callback_lookup_table[i].key, key) &&
466 ((callback_lookup_table[i].callback == compute_weights_callback) &&
468 "ERROR(interp_method_callback_add_do_search_callback): "
469 "identical key has been set before with different callbacks")
470
473
475 xmalloc((strlen(key)+1) * sizeof(*key));
478 compute_weights_callback;
481}
482
484 char const * key, yac_func_compute_weights * compute_weights_callback,
485 void ** user_data) {
486
487 for (size_t i = 0; i < callback_lookup_table_size; ++i) {
488 if (!strcmp(callback_lookup_table[i].key, key)) {
489 *compute_weights_callback = callback_lookup_table[i].callback;
490 *user_data = callback_lookup_table[i].user_data;
491 return;
492 }
493 }
494 *compute_weights_callback = NULL;
495 *user_data = NULL;
496 return;
497}
#define ENSURE_ARRAY_SIZE(arrayp, curr_array_size, req_size)
void yac_interp_grid_do_points_search(struct yac_interp_grid *interp_grid, yac_coordinate_pointer search_coords, size_t count, size_t *src_cells)
struct remote_point * yac_interp_grid_get_src_remote_points2(struct yac_interp_grid *interp_grid, enum yac_location location, size_t *src_points, size_t count)
size_t yac_interp_grid_get_num_src_fields(struct yac_interp_grid *interp_grid)
MPI_Comm yac_interp_grid_get_MPI_Comm(struct yac_interp_grid *interp_grid)
struct remote_point * yac_interp_grid_get_tgt_remote_points(struct yac_interp_grid *interp_grid, size_t *tgt_points, size_t count)
struct remote_point * yac_interp_grid_get_src_remote_points(struct yac_interp_grid *interp_grid, size_t src_field_idx, size_t *src_points, size_t count)
void yac_interp_grid_get_tgt_coordinates(struct yac_interp_grid *interp_grid, size_t *tgt_points, size_t count, yac_coordinate_pointer tgt_coordinates)
yac_const_coordinate_pointer yac_interp_grid_get_tgt_field_coords(struct yac_interp_grid *interp_grid)
void yac_interp_grid_src_global_to_local(struct yac_interp_grid *interp_grid, size_t src_field_idx, yac_int *src_global_ids, size_t count, size_t *src_local_ids)
void * user_data
void yac_interp_method_callback_get_compute_weights_callback(char const *key, yac_func_compute_weights *compute_weights_callback, void **user_data)
static MPI_Datatype yac_get_request_data_mpi_datatype(MPI_Comm comm)
static size_t do_search_callback(struct interp_method *method, struct yac_interp_grid *interp_grid, size_t *tgt_points, size_t count, struct yac_interp_weights *weights)
void(* func_dummy)(void)
yac_func_compute_weights callback
static void delete_callback(struct interp_method *method)
char * key
static size_t get_valid_results(size_t *src_cells, size_t count, size_t *tgt_points)
static struct interp_method_vtable interp_method_callback_vtable
static size_t callback_lookup_table_size
struct interp_method * yac_interp_method_callback_new(yac_func_compute_weights compute_weights_callback, void *user_data)
void yac_interp_method_callback_add_compute_weights_callback(yac_func_compute_weights compute_weights_callback, void *user_data, char const *key)
static size_t callback_lookup_table_array_size
static void get_orig_data(struct remote_point const *remote_point, int *owner_rank, size_t *orig_pos)
static struct @7 * callback_lookup_table
void(* yac_func_compute_weights)(double const tgt_coords[3], int src_cell_id, size_t src_cell_idx, int const **global_results_points, double **result_weights, size_t *result_count, void *user_data)
static void compute_weights(struct tgt_point_search_data *tgt_point_data, size_t num_tgt_points, struct edge_interp_data *edge_data, size_t num_edges, struct triangle_interp_data *triangle_data, size_t num_triangles, struct weight_vector_data **weights, size_t **num_weights_per_tgt, size_t *total_num_weights)
void yac_interp_weights_add_wsum_mf(struct yac_interp_weights *weights, struct remote_points *tgts, size_t *num_src_per_field_per_tgt, struct remote_point **srcs_per_field, double *w, size_t num_src_fields)
@ YAC_LOC_CELL
Definition location.h:14
#define xcalloc(nmemb, size)
Definition ppm_xfuncs.h:64
#define xmalloc(size)
Definition ppm_xfuncs.h:66
struct interp_method_vtable * vtable
yac_func_compute_weights compute_weights_callback
size_t(* do_search)(struct interp_method *method, struct yac_interp_grid *grid, size_t *tgt_points, size_t count, struct yac_interp_weights *weights)
struct remote_point_info single
struct remote_point_info * multi
union remote_point_infos::@1 data
struct remote_point_infos data
struct remote_point * data
void yac_quicksort_index_size_t_yac_int(size_t *a, size_t n, yac_int *idx)
void yac_quicksort_index_size_t_size_t(size_t *a, size_t n, size_t *idx)
#define YAC_ASSERT(exp, msg)
Definition yac_assert.h:15
void yac_generate_alltoallv_args(int count, size_t const *sendcounts, size_t *recvcounts, size_t *sdispls, size_t *rdispls, MPI_Comm comm)
Definition yac_mpi.c:569
void yac_free_comm_buffers(size_t *sendcounts, size_t *recvcounts, size_t *sdispls, size_t *rdispls)
Definition yac_mpi.c:624
void yac_get_comm_buffers(int count, size_t **sendcounts, size_t **recvcounts, size_t **sdispls, size_t **rdispls, MPI_Comm comm)
Definition yac_mpi.c:593
MPI_Datatype yac_create_resized(MPI_Datatype dt, size_t new_size, MPI_Comm comm)
Definition yac_mpi.c:548
void yac_alltoallv_p2p(void const *send_buffer, size_t const *sendcounts, size_t const *sdispls, void *recv_buffer, size_t const *recvcounts, size_t const *rdispls, size_t dt_size, MPI_Datatype dt, MPI_Comm comm)
Definition yac_mpi.c:129
#define yac_mpi_call(call, comm)
Xt_int yac_int
Definition yac_types.h:15
#define yac_int_dt
Definition yac_types.h:16
double const (*const yac_const_coordinate_pointer)[3]
Definition yac_types.h:20
double(* yac_coordinate_pointer)[3]
Definition yac_types.h:19