YAC 3.13.0
Yet Another Coupler
Loading...
Searching...
No Matches
test_core.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2024 The YAC Authors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6
7
11
# mpi4py is optional for this test: when it is available the world
# communicator is used, otherwise the script runs without a communicator
# and reports itself as rank 0.
try:
    from mpi4py import MPI  # calls MPI_Init with some special treatments
except (ImportError, RuntimeError):
    # ImportError: mpi4py is not installed or cannot be loaded.
    # RuntimeError: NOTE(review) presumably what mpi4py raises when MPI
    # initialisation fails at import time -- confirm for the version in use.
    # A bare ``except:`` is deliberately avoided so that KeyboardInterrupt
    # and SystemExit are not swallowed.
    comm = None
    rank = 0
else:
    comm = MPI.COMM_WORLD
    rank = comm.rank
19
20from yac.core import (BasicGrid,
21 InterpField,
22 InterpolationStack,
23 compute_weights,
24 yac_location,
25 yac_weight_file_on_existing,
26 lonlat2xyz,
27 get_io_ranks,
28 )
29import numpy as np
30import sys
31import asyncio
32import logging
33
# Include the MPI rank in every log record.
logging.basicConfig(level=logging.WARNING, format=f'%(name)s rank={rank} - %(levelname)s - %(message)s')
# The root logger stays at WARNING; the "yac.core" logger is lowered to INFO.
logging.getLogger("yac.core").setLevel(logging.INFO)

# The first command-line argument is the directory that holds the ICON grid
# file. It is combined with the file name by plain string concatenation
# further down, so make sure it ends in a path separator; otherwise
# "dir" + "file.nc" would silently become "dirfile.nc". An empty argument is
# left untouched so that it keeps referring to the current directory.
grid_dir = sys.argv[1]
if grid_dir and not grid_dir.endswith("/"):
    grid_dir += "/"
38
# Regular 2D grid built from 20 x 20 vertex coordinates, i.e. 400 corners.
n_per_axis = 20
n_corners = n_per_axis * n_per_axis
corner = yac_location.YAC_LOC_CORNER

first_axis_vertices = np.linspace(-2, 2, n_per_axis)
second_axis_vertices = np.linspace(-1.5, 1.5, n_per_axis)
reg_grid = BasicGrid.reg_2d_new("reg_grid", first_axis_vertices, second_axis_vertices)

# Core mask of ones for every corner.
reg_grid.set_core_mask(corner, np.ones(n_corners, dtype=int))

# assert that set_global_index raises an exception if indices are out of bounds
huge_offset = 10**99
try:
    reg_grid.set_global_index(corner, range(huge_offset, huge_offset + n_corners))
except AssertionError:
    # Expected: the oversized indices were rejected.
    pass
else:
    raise RuntimeError("set_global_index did not throw an exception for very large indices.")

# Now set valid global indices 0..399 and dump the grid to a file.
reg_grid.set_global_index(corner, range(n_corners))
reg_grid.to_file("reg_grid.nc", comm)
53
# Source side: an interpolation field on the corner coordinates of the
# regular grid.
coords = reg_grid.add_coordinates(yac_location.YAC_LOC_CORNER)
src_field = InterpField(coords)

# Target side: grid and cell field read from the ICON grid file found in
# the directory given on the command line.
icon_grid_file = grid_dir+"icon_grid_0030_R02B03_G.nc"
icon_grid, icon_cell_field = BasicGrid.read_icon(icon_grid_file, "icon", comm=comm)

# Interpolation stack: "nnn" first, then the "fixed" value 42.0.
nnn_method = ("nnn", {"n": 3, "max_search_distance": 2.})
fixed_method = ("fixed", {"value": 42.0})
interp_stack = InterpolationStack.from_list([nnn_method, fixed_method])

weights = compute_weights(interp_stack, src_field, icon_cell_field)

# Write the weights to disk, overwriting a file left over from a previous run.
overwrite = yac_weight_file_on_existing.YAC_WEIGHT_FILE_OVERWRITE
weights.write_to_file("weights.nc", on_existing=overwrite)
68
# Build an interpolation that selects entries 1, 4 and 2 of the source
# collection; the checks below expect the results in exactly that order.
interpolate = weights.get_interpolation(collection_selection=[1, 4, 2])

# Source data: five layers, layer i is filled with the constant value i.
src_size = src_field.basic_grid.get_data_size(yac_location.YAC_LOC_CORNER)
src_layers = [i*np.ones(src_size, dtype=np.float64) for i in range(5)]
src = np.stack(src_layers, axis=0)

# Each target layer must reproduce the constant of the selected source layer.
tgt = interpolate(src)
for slot, expected in enumerate((1.0, 4.0, 2.0)):
    assert np.all(tgt[slot,...] == expected), f"Wrong result: {tgt=}"

# The local process must be an I/O rank, and rank 0 must be the only one.
local_is_io, io_ranks = get_io_ranks()
assert local_is_io == True
assert len(io_ranks) == 1
assert io_ranks[0] == 0
83
async def _main():
    """Run the interpolation through ``execute_async`` and check the result.

    The expected target layers are the constants 1.0, 4.0 and 2.0, in that
    order.
    """
    tgt = await interpolate.execute_async(src)
    for slot, expected in enumerate((1.0, 4.0, 2.0)):
        assert np.all(tgt[slot,...] == expected), f"Wrong result: {tgt=}"


asyncio.run(_main())