#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include <stdbool.h>
#include <stdlib.h>
#include <mpi.h>
#include <yaxt.h>
#include "tests.h"
#include "ctest_common.h"
#include "test_xmap_common.h"
/*
 * Forward declarations of the file-local test cases that
 * xt_xmap_parallel_test_main calls before their definitions appear.
 *
 * NOTE(review): the definition of test_maxpos further down omits the
 * `static` keyword. Its linkage stays internal because this earlier
 * declaration is visible at that point, but repeating the keyword there
 * would make that obvious.
 */
static void
test_xmap_allgather_analog(xmap_constructor xmap_new,
size_t num_indices_per_rank,
MPI_Comm comm);
static void
test_pair(xmap_constructor xmap_new,
MPI_Comm comm);
static void
test_maxpos(xmap_constructor xmap_new,
MPI_Comm comm,
int indices_per_rank);
/**
 * Common driver of the parallel xmap tests: runs the test cases of this
 * file against the xmap constructor passed in.
 *
 * Case selection by communicator size, as coded below:
 *  - allgather analog (1 and 1024 indices per rank) and maxpos
 *    (5 and 501 indices per rank): always
 *  - 1-D ring: more than two ranks
 *  - pair: exactly two ranks
 *  - ping-pong with ping rank 0 and pong rank comm_size - 1: more than
 *    one rank
 *
 * @param xmap_new constructor forwarded unchanged to every test case
 * @return TEST_EXIT_CODE, which is defined outside this file
 *
 * NOTE(review): `comm` is used throughout but is neither declared nor
 * initialised in the visible body, and no MPI or yaxt initialisation
 * call precedes MPI_Comm_rank -- confirm that no statements were lost.
 * NOTE(review): comm_rank is queried but not read afterwards, and the
 * return codes of MPI_Comm_rank/MPI_Comm_size are not checked here.
 */
int
xt_xmap_parallel_test_main(xmap_constructor xmap_new)
{
int comm_rank, comm_size;
MPI_Comm_rank(comm, &comm_rank);
MPI_Comm_size(comm, &comm_size);
test_xmap_allgather_analog(xmap_new, 1, comm);
test_xmap_allgather_analog(xmap_new, 1024, comm);
/* the ring case is only run with three or more ranks */
if (comm_size > 2)
test_ring_1d(xmap_new, comm);
/* the pair case is only run with exactly two ranks */
if (comm_size == 2)
test_pair(xmap_new, comm);
/* ping-pong uses the first and the last rank of the communicator */
if (comm_size > 1)
test_ping_pong(xmap_new, comm, 0, comm_size - 1);
test_maxpos(xmap_new, comm, 5);
test_maxpos(xmap_new, comm, 501);
MPI_Finalize();
return TEST_EXIT_CODE;
}
/**
 * Verifies the partner ranks of an xmap: for both the destination and
 * the source side the rank list is expected to be exactly
 * 0, 1, ..., comm_size - 1.
 *
 * NOTE(review): the declarator between `void` and `{` is missing in this
 * view. The calls `check_xmap_allgather_analog_xmap(xmap, comm)` made by
 * test_xmap_allgather_analog suggest that this is that function, taking
 * an Xt_xmap and an MPI_Comm -- confirm.
 * NOTE(review): several statements appear to be missing: the line that
 * starts with `=` has no left-hand side, both count-related PUT_ERR
 * calls are unconditional, comm_size is never assigned, and nothing
 * fills `ranks` before it is compared. Presumably the
 * xt_xmap_get_num_* / xt_xmap_get_*_ranks queries belong here -- verify
 * against the upstream file before relying on this block.
 */
void
{
int comm_rank, comm_size, is_inter;
xt_mpi_call(MPI_Comm_test_inter(comm, &is_inter), comm);
/* right-hand side picks the remote group size for inter-communicators
 * and the ordinary size otherwise */
= is_inter ? MPI_Comm_remote_size : MPI_Comm_size;
PUT_ERR("error in xt_xmap_get_num_destinations\n");
PUT_ERR("error in xt_xmap_get_num_sources\n");
/* scratch buffer with one entry per rank; released at the end */
int *ranks =
xmalloc(
sizeof (*ranks) * (
size_t)comm_size);
/* destination side: accumulate any deviation from ranks[i] == i */
bool mismatch = false;
for (int i = 0; i < comm_size; ++i)
mismatch |= (ranks[i] != i);
if (mismatch)
PUT_ERR("error in xt_xmap_get_destination_ranks\n");
/* source side: same comparison, reported under a different message */
mismatch = false;
for (int i = 0; i < comm_size; ++i)
mismatch |= (ranks[i] != i);
if (mismatch)
PUT_ERR("error in xt_xmap_get_source_ranks\n");
free(ranks);
}
/**
 * Test case in which every rank provides its own contiguous range of
 * indices and requests the indices of all ranks.
 *
 * Source indices of a rank are
 *   comm_rank * num_indices_per_rank + i,  0 <= i < num_indices_per_rank
 * and the destination indices are 0 .. comm_size * num_indices_per_rank - 1.
 * The resulting xmap, and a second handle named xmap_copy, are passed to
 * check_xmap_allgather_analog_xmap.
 *
 * @param xmap_new             constructor under test
 * @param num_indices_per_rank number of source indices each rank provides
 *
 * NOTE(review): the parameter list ends with a comma and, according to
 * the forward declaration, lacks the trailing `MPI_Comm comm`.
 * NOTE(review): comm_rank/comm_size are never assigned, the two index
 * buffers are freed but never declared or allocated, and src_idxlist,
 * dst_idxlist and xmap_copy are not declared in the visible body --
 * presumably the allocation, index-list construction, xmap copy and
 * cleanup statements were lost; confirm against the upstream file.
 */
static void
test_xmap_allgather_analog(xmap_constructor xmap_new,
size_t num_indices_per_rank,
{
int comm_rank, comm_size;
/* this rank's slice of the global index space */
for (size_t i = 0; i < num_indices_per_rank; ++i)
src_index_list[i]
= (
Xt_int)((size_t)comm_rank * num_indices_per_rank + i);
free(src_index_list);
/* every rank requests the complete global index range */
size_t num_gathered = (size_t)comm_size * num_indices_per_rank;
for (size_t i = 0; i < num_gathered; ++i)
dst_index_list[i] = (
Xt_int)i;
free(dst_index_list);
Xt_xmap xmap = xmap_new(src_idxlist, dst_idxlist, comm);
check_xmap_allgather_analog_xmap(xmap, comm);
check_xmap_allgather_analog_xmap(xmap_copy, comm);
}
/**
 * Verifies the partner ranks of an xmap built by the ring test.
 *
 * @param xmap           xmap to inspect
 * @param dst_index_list the two expected partner ranks; the comparisons
 *                       below match entry 0 against the first reported
 *                       rank and entry 1 against the last one
 * @param is_inter       true if the communicator is an
 *                       inter-communicator; the destination-side checks
 *                       are skipped in that case
 *
 * NOTE(review): num_dst and num_src are not declared in the visible
 * body, and nothing writes to `ranks` before it is read -- presumably
 * the xt_xmap_get_num_* and xt_xmap_get_*_ranks calls were lost; confirm
 * against the upstream file.
 */
static void
check_ring_xmap(
Xt_xmap xmap,
const Xt_int dst_index_list[],
bool is_inter)
{
/* one or two partners are accepted on either side */
if (!is_inter && (num_dst > 2 || num_dst < 1))
PUT_ERR("error in xt_xmap_get_num_destinations\n");
if (num_src > 2 || num_src < 1)
PUT_ERR("error in xt_xmap_get_num_sources\n");
int ranks[2];
if (!is_inter) {
/* `num_dst > 1` evaluates to 0 or 1, so with a single partner
 * ranks[0] is compared against both expected entries */
if (ranks[0] != dst_index_list[0] ||
ranks[num_dst > 1] != dst_index_list[1])
PUT_ERR("error in xt_xmap_get_destination_ranks\n");
}
/* same indexing trick for the source side */
if (ranks[0] != dst_index_list[0] ||
ranks[num_src > 1] != dst_index_list[1])
PUT_ERR("error in xt_xmap_get_source_ranks\n");
}
/**
 * Test case for a periodic 1-D ring: each rank requests the two indices
 * (comm_rank - 1) mod comm_size and (comm_rank + 1) mod comm_size.
 *
 * The pair is put into ascending order before use, and the same array is
 * handed to check_ring_xmap as the expected partner ranks.
 *
 * @param xmap_new constructor under test
 * @param comm     communicator to test on; it is queried with
 *                 MPI_Comm_test_inter
 *
 * NOTE(review): the line that starts with `=` has no left-hand side,
 * comm_rank/comm_size are never assigned, and src_idxlist, dst_idxlist
 * and xmap_copy are not declared in the visible body -- presumably the
 * size query, index-list construction, xmap copy and cleanup statements
 * were lost; confirm against the upstream file.
 */
void
test_ring_1d(xmap_constructor xmap_new,
MPI_Comm comm)
{
int comm_rank, comm_size, is_inter;
xt_mpi_call(MPI_Comm_test_inter(comm, &is_inter), comm);
/* right-hand side picks the remote group size for inter-communicators
 * and the ordinary size otherwise */
= is_inter ? MPI_Comm_remote_size : MPI_Comm_size;
/* adding comm_size keeps the left neighbour non-negative before % */
Xt_int dst_index_list[2] = {(
Xt_int)((comm_rank + comm_size - 1)%comm_size),
(
Xt_int)((comm_rank + 1)%comm_size)};
/* swap so that the smaller index comes first */
if (dst_index_list[0] > dst_index_list[1]) {
Xt_int temp = dst_index_list[0];
dst_index_list[0] = dst_index_list[1];
dst_index_list[1] = temp;
}
Xt_xmap xmap = xmap_new(src_idxlist, dst_idxlist, comm);
check_ring_xmap(xmap, dst_index_list, (bool)is_inter);
check_ring_xmap(xmap_copy, dst_index_list, (bool)is_inter);
}
/**
 * Test case for the maximum source/destination position queries.
 *
 * Index layout per rank (ipr = indices_per_rank):
 *  - source: the rank's own block, i + comm_rank * ipr
 *  - destination, first ipr/2 entries: the last ipr/2 indices of the
 *    preceding rank's block, wrapping around at world_size
 *  - destination, middle entry (only when ipr is odd): the rank's own
 *    index at that position
 *  - destination, last ipr/2 entries: the first ipr/2 indices of the
 *    following rank's block, wrapping around at world_size
 *
 * Four groups of checks follow; each compares a maximum destination
 * position and a maximum source position against a bound:
 *  - at least ipr - 1
 *  - at least 2 * (ipr - 1), next to the table pos_update1[i] = 2*i
 *  - below ipr, next to the table pos_update2[i] = i/2
 *  - at least 3 * (ipr - 1), next to the table spread = { 0, 3 * ipr }
 *
 * @param xmap_new         constructor under test
 * @param comm             communicator to test on
 * @param indices_per_rank number of indices per rank; also the length of
 *                         several variable-length arrays on the stack
 *
 * NOTE(review): comm_rank/comm_size are read without a visible
 * assignment, and src_idxlist, dst_idxlist and all max_pos_* names are
 * not declared in the visible body. pos_update1, pos_update2 and spread
 * are filled but never read -- presumably they feed
 * xt_xmap_update_positions / xt_xmap_spread calls that were lost;
 * confirm against the upstream file.
 * NOTE(review): defined without `static` although the forward
 * declaration at the top of the file is `static`.
 */
void
test_maxpos(xmap_constructor xmap_new,
MPI_Comm comm,
int indices_per_rank)
{
int comm_rank, comm_size;
int world_size = comm_size * indices_per_rank;
Xt_int src_index[indices_per_rank];
for (int i = 0; i < indices_per_rank; ++i)
src_index[i] = (
Xt_int)(i + comm_rank * indices_per_rank);
Xt_int dst_index[indices_per_rank];
/* tail of the preceding rank's block; adding comm_size blocks keeps
 * the operand of % non-negative */
for (int i = 0; i < indices_per_rank/2; ++i)
dst_index[i]
= (
Xt_int)((i - indices_per_rank/2
+ (comm_rank+comm_size) * indices_per_rank)%world_size);
/* runs once for odd indices_per_rank and not at all for even values */
for (int i = indices_per_rank/2; i < (indices_per_rank+1)/2; ++i)
dst_index[i] = (
Xt_int)(i + comm_rank * indices_per_rank);
/* head of the following rank's block */
for (int i = 0; i < indices_per_rank/2; ++i)
dst_index[(indices_per_rank+1)/2+i]
= (
Xt_int)(((
int)i + (comm_rank+1) * indices_per_rank)%world_size);
Xt_xmap xmap = xmap_new(src_idxlist, dst_idxlist, comm);
if (max_pos_dst < indices_per_rank-1)
PUT_ERR("error in xt_xmap_get_max_dst_pos\n");
if (max_pos_src < indices_per_rank-1)
PUT_ERR("error in xt_xmap_get_max_src_pos\n");
/* position table that maps entry i to 2*i */
int pos_update1[indices_per_rank];
for (int i = 0; i < indices_per_rank; ++i)
pos_update1[i] = 2*i;
if (max_pos_dst_u < (indices_per_rank-1)*2)
PUT_ERR("error in xt_xmap_get_max_dst_pos\n");
if (max_pos_src_u < (indices_per_rank-1)*2)
PUT_ERR("error in xt_xmap_get_max_src_pos\n");
/* position table with twice as many entries that maps entry i to i/2 */
int pos_update2[2*indices_per_rank];
for (int i = 0; i < 2*indices_per_rank; ++i)
pos_update2[i] = i/2;
if (max_pos_dst_u2 >= indices_per_rank)
PUT_ERR("error in xt_xmap_get_max_dst_pos\n");
if (max_pos_src_u2 >= indices_per_rank)
PUT_ERR("error in xt_xmap_get_max_src_pos\n");
/* two displacements: 0 and three times indices_per_rank */
int spread[] = { 0, indices_per_rank*3 };
if (max_pos_dst_s < (indices_per_rank-1)*3)
PUT_ERR("error in xt_xmap_get_max_dst_pos\n");
if (max_pos_src_s < (indices_per_rank-1)*3)
PUT_ERR("error in xt_xmap_get_max_src_pos\n");
}
/**
 * Verifies the partner ranks of an xmap for the two-rank case: both the
 * destination and the source rank list are expected to be { 0, 1 }.
 *
 * NOTE(review): the declarator between `static void` and `{` is missing
 * in this view. The calls `check_pair_xmap(xmap)` made by test_pair
 * suggest that this is that function, taking a single Xt_xmap --
 * confirm.
 * NOTE(review): both count-related PUT_ERR calls are unconditional and
 * nothing writes to `ranks` before it is read -- presumably the
 * xt_xmap_get_num_* and xt_xmap_get_*_ranks calls were lost; confirm
 * against the upstream file.
 */
static void
{
PUT_ERR("error in xt_xmap_get_num_destinations\n");
PUT_ERR("error in xt_xmap_get_num_sources\n");
int ranks[2];
/* destination side */
if (ranks[0] != 0 || ranks[1] != 1)
PUT_ERR("error in xt_xmap_get_destination_ranks\n");
/* source side */
if (ranks[0] != 0 || ranks[1] != 1)
PUT_ERR("error in xt_xmap_get_source_ranks\n");
}
/**
 * Test case with fixed index tables for two ranks. The driver in this
 * file only calls it when the communicator has exactly two ranks, which
 * matches the two rows of each table.
 *
 * The resulting xmap, and a second handle named xmap_copy, are passed to
 * check_pair_xmap.
 *
 * @param xmap_new constructor under test
 * @param comm     communicator to test on
 *
 * NOTE(review): the two tables are never referenced in the visible body,
 * and src_idxlist, dst_idxlist and xmap_copy are not declared --
 * presumably the index-list construction (from row comm_rank of each
 * table), xmap copy and cleanup statements were lost; confirm against
 * the upstream file.
 * NOTE(review): the return codes of MPI_Comm_rank/MPI_Comm_size are not
 * checked here, unlike the xt_mpi_call-wrapped MPI calls elsewhere in
 * this file.
 */
static void
test_pair(xmap_constructor xmap_new,
MPI_Comm comm)
{
int comm_rank, comm_size;
MPI_Comm_rank(comm, &comm_rank);
MPI_Comm_size(comm, &comm_size);
enum { numIdx = 20 };
/* one row of numIdx source indices per rank */
static const Xt_int src_index_list[2][numIdx] = {
{1,2,3,4,5,9,10,11,12,13,17,18,19,20,21,25,26,27,28,29},
{4,5,6,7,8,12,13,14,15,16,20,21,22,23,24,28,29,30,31,32} };
/* one row of numIdx destination indices per rank; some values occur
 * more than once within a row */
static const Xt_int dst_index_list[2][numIdx] = {
{10,15,14,13,12,15,10,11,12,13,23,18,19,20,21,31,26,27,28,29},
{13,12,11,10,15,12,13,14,15,10,20,21,22,23,18,28,29,30,31,26}};
Xt_xmap xmap = xmap_new(src_idxlist, dst_idxlist, comm);
check_pair_xmap(xmap);
check_pair_xmap(xmap_copy);
}
/**
 * Verifies an xmap of the ping-pong test on the calling rank.
 *
 * On the ping rank the single destination rank is expected to equal
 * pong_rank; on the pong rank the single source rank is expected to
 * equal ping_rank. Other ranks skip both rank comparisons.
 *
 * @param xmap      xmap to inspect
 * @param comm      communicator used to determine the calling rank
 * @param ping_rank rank expected to be the source side of the exchange
 * @param pong_rank rank expected to be the destination side
 *
 * NOTE(review): both count-related PUT_ERR calls are unconditional, and
 * dst_rank/src_rank are compared without ever being assigned (reading
 * them is undefined behaviour as written) -- presumably the
 * xt_xmap_get_num_* conditions and xt_xmap_get_*_ranks calls were lost;
 * confirm against the upstream file.
 */
static void
test_ping_pong_xmap(
Xt_xmap xmap,
MPI_Comm comm,
int ping_rank,
int pong_rank)
{
int comm_rank;
MPI_Comm_rank(comm, &comm_rank);
PUT_ERR("error in xt_xmap_get_num_destinations (rank == %d)\n", comm_rank);
PUT_ERR("error in xt_xmap_get_num_sources (rank == %d)\n", comm_rank);
/* the ping rank must see the pong rank as its destination */
if (comm_rank == ping_rank) {
int dst_rank;
if (dst_rank != pong_rank)
PUT_ERR("error in xt_xmap_get_destination_ranks\n");
}
/* the pong rank must see the ping rank as its source */
if (comm_rank == pong_rank) {
int src_rank;
if (src_rank != ping_rank)
PUT_ERR("error in xt_xmap_get_source_ranks\n");
}
}
/**
 * Test case for a one-directional exchange between two ranks.
 *
 * The source index list is chosen depending on whether the calling rank
 * is ping_rank, the destination index list depending on whether it is
 * pong_rank. The resulting xmap, and a second handle named xmap_copy,
 * are passed to test_ping_pong_xmap.
 *
 * @param xmap_new  constructor under test
 * @param comm      communicator to test on
 * @param ping_rank rank selected for the source side
 * @param pong_rank rank selected for the destination side
 *
 * NOTE(review): both conditional expressions stop after `?`, so neither
 * operand is visible; index_list is never referenced and xmap_copy is
 * not declared. Presumably the operands are an index vector built from
 * index_list and an empty index list (xt_idxvec_new / xt_idxempty_new),
 * followed by an xmap copy and cleanup -- confirm against the upstream
 * file.
 */
void
test_ping_pong(xmap_constructor xmap_new,
MPI_Comm comm,
int ping_rank, int pong_rank)
{
int comm_rank;
MPI_Comm_rank(comm, &comm_rank);
enum { numIdx = 5 };
static const Xt_int index_list[numIdx] = {0,1,2,3,4};
Xt_idxlist src_idxlist = (comm_rank == ping_rank)?
Xt_idxlist dst_idxlist = (comm_rank == pong_rank)?
Xt_xmap xmap = xmap_new(src_idxlist, dst_idxlist, comm);
test_ping_pong_xmap(xmap, comm, ping_rank, pong_rank);
test_ping_pong_xmap(xmap_copy, comm, ping_rank, pong_rank);
}
Adds versions of standard API functions that do not return on error.
void xt_initialize(MPI_Comm default_comm)
struct Xt_xmap_ * Xt_xmap
struct Xt_idxlist_ * Xt_idxlist
Xt_idxlist xt_idxempty_new(void)
void xt_idxlist_delete(Xt_idxlist idxlist)
Xt_idxlist xt_idxvec_new(const Xt_int *idxlist, int num_indices)
#define xt_mpi_call(call, comm)
void(* xt_sort_int)(int *a, size_t n)
Xt_xmap xt_xmap_update_positions(Xt_xmap xmap, const int *src_positions, const int *dst_positions)
void xt_xmap_delete(Xt_xmap xmap)
Xt_xmap xt_xmap_spread(Xt_xmap xmap, int num_repetitions, const int src_displacements[num_repetitions], const int dst_displacements[num_repetitions])
int xt_xmap_get_num_destinations(Xt_xmap xmap)
Xt_xmap xt_xmap_copy(Xt_xmap xmap)
int xt_xmap_get_max_dst_pos(Xt_xmap xmap)
int xt_xmap_get_num_sources(Xt_xmap xmap)
void xt_xmap_get_source_ranks(Xt_xmap xmap, int *ranks)
void xt_xmap_get_destination_ranks(Xt_xmap xmap, int *ranks)
int xt_xmap_get_max_src_pos(Xt_xmap xmap)