YAC 3.12.0
Yet Another Coupler
Loading...
Searching...
No Matches
toy_multi_curve2d.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2024 The YAC Authors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6
7from yac import YAC, Curve2dGrid, Location, Field, TimeUnit
8from itertools import product
9import math
10import numpy as np
11from mpi4py import MPI
12
13
def lonlat2xyz(lon, lat):
    """Map longitude/latitude (radians) to Cartesian coordinates on the unit sphere.

    Works element-wise on scalars or NumPy arrays and returns the tuple
    ``(x, y, z)``.
    """
    cos_lat = np.cos(lat)
    x = cos_lat * np.cos(lon)
    y = cos_lat * np.sin(lon)
    z = np.sin(lat)
    return x, y, z
18
19
def xyz2lonlat(x, y, z):
    """Convert Cartesian coordinates to longitude/latitude in radians.

    The input does not have to lie on the unit sphere: only the direction of
    the vector ``(x, y, z)`` matters. Callers may pass shortened vectors, e.g.
    cell centres built as the mean of corner vectors. The result also stays
    finite when rounding pushes ``|z|`` marginally above 1.

    Works element-wise on scalars or NumPy arrays.

    Returns:
        ``(lon, lat)`` with ``lon`` in ``[-pi, pi]`` and ``lat`` in
        ``[-pi/2, pi/2]``.
    """
    lon = np.arctan2(y, x)
    # arctan2(z, hypot(x, y)) == arcsin(z / |v|). Unlike a plain arcsin(z),
    # it does not depend on |v| and never yields NaN for |z| slightly > 1.
    lat = np.arctan2(z, np.hypot(x, y))
    return lon, lat
24
25
def rotmat(theta, plane_dims, d=2):
    """Return a ``d x d`` matrix that rotates by ``theta`` radians in one plane.

    ``plane_dims`` names the two axes ``(i, j)`` spanning the rotation plane.
    All other axes are left unchanged (identity rows/columns).
    """
    c, s = np.cos(theta), np.sin(theta)
    rot = np.eye(d)
    i, j = plane_dims
    rot[i, i] = c
    rot[i, j] = -s
    rot[j, i] = s
    rot[j, j] = c
    return rot
31
32
# Synchronise every process of the MPI job before YAC is initialised.
# NOTE(review): COMM_WORLD barriers are collective; they are presumably
# matched by barriers in the other coupled programs, so count and order must
# stay in sync with those programs.
MPI.COMM_WORLD.Barrier()

# Initialise YAC and register this program as the "PISM" component.
yac = YAC()
my_comp_name = "PISM"
comp = yac.def_comp(my_comp_name)

# Size and rank within this component's communicator (not COMM_WORLD); the
# domain decomposition below is based on these.
size = comp.comp_comm.size
rank = comp.comp_comm.rank
41
42
def msg(m):
    """Print ``m`` prefixed with the script name, on component rank 0 only."""
    if rank != 0:
        return
    print("toy_multi_curve2d: ", m)
46
47
# Factorise the number of processes into a near-square fac x (size // fac)
# process grid: start at floor(sqrt(size)) and step down to the first divisor
# (fac == 1 at worst). math.isqrt is an exact integer square root, so there is
# no float rounding as with int(np.sqrt(size)).
fac = math.isqrt(size)
while size % fac:
    fac -= 1
msg(f"decomposition: {fac} x {size//fac}")
53
# 2-D block decomposition of a dim2d[0] x dim2d[1] vertex grid over a
# fac x (size // fac) process grid. Along each axis a process starts at
# rank * chunk with chunk = ceil(n / p). Its vertex slice is one element longer
# than the chunk, so adjacent processes share their boundary vertices.
dim2d = [100, 50]
size2d = [fac, size // fac]
rank2d = [rank % fac, rank // fac]
chunk2d = [math.ceil(n / p) for n, p in zip(dim2d, size2d)]
slice2d = [slice(r * c, (r + 1) * c + 1) for r, c in zip(rank2d, chunk2d)]
60
# Local vertex coordinates. The global axes span [-0.5, 0.5] (used as lon/lat
# in radians by lonlat2xyz below) and are cut down to this process's slice.
x = np.linspace(-.5, .5, dim2d[0])
y = np.linspace(-.5, .5, dim2d[1])
x = x[slice2d[0]]
y = y[slice2d[1]]
# meshgrid's default "xy" indexing gives arrays of shape (len(y), len(x)).
vx, vy = np.meshgrid(x, y)

# Rotate the patch on the sphere (presumably so it is not aligned with lon/lat
# lines). v_xyz@r multiplies row vectors by r, i.e. applies r.T to each point.
r = rotmat(-.45*np.pi, [0, 2], 3)@rotmat(0.1*np.pi, [0, 1], 3)
v_xyz = np.stack(lonlat2xyz(vx, vy), axis=-1)
v_xyz = v_xyz@r
# Cell centres as the mean of the four surrounding corner vectors.
# NOTE(review): this mean is shorter than unit length, so converting it back is
# only correct if xyz2lonlat depends on the vector's direction alone, not on
# its length.
c_xyz = 0.25*(v_xyz[:-1, :-1, :]+v_xyz[1:, :-1, :]+v_xyz[:-1, 1:, :]+v_xyz[1:, 1:, :])
vx, vy = xyz2lonlat(v_xyz[..., 0], v_xyz[..., 1], v_xyz[..., 2])
cx, cy = xyz2lonlat(c_xyz[..., 0], c_xyz[..., 1], c_xyz[..., 2])
73
# Curvilinear grid from the local (rotated) vertex coordinates.
grid = Curve2dGrid(f"{my_comp_name}_grid", vx, vy)

# Global corner ids: position in a row-major dim2d[0] x dim2d[1] array
# (id = i*dim2d[1] + j, with i the x index), restricted to the local slice.
# The transpose makes the flattened order follow the (len(y), len(x)) layout
# of vx/vy.
corner_idx = np.arange(math.prod(dim2d)).reshape(dim2d)[slice2d[0], slice2d[1]]
grid.set_global_index(corner_idx.T.flatten(), Location.CORNER)

# Same numbering for the (dim2d[0]-1) x (dim2d[1]-1) cells. The cell slices
# have no "+1" overlap, so the local cell ranges of all processes are disjoint.
cell_idx = np.arange(math.prod([d-1 for d in dim2d])).reshape([d-1 for d in dim2d])[
    slice(rank2d[0]*chunk2d[0], (rank2d[0]+1)*chunk2d[0]),
    slice(rank2d[1]*chunk2d[1], (rank2d[1]+1)*chunk2d[1])]
grid.set_global_index(cell_idx.T.flatten(), Location.CELL)
83
# Point sets on which fields are defined: cell centres and grid vertices.
points_cell = grid.def_points(Location.CELL, cx, cy)
points_vertex = grid.def_points(Location.CORNER, vx, vy)

# (interpolation name, location) pairs. One outgoing field per pair and per
# component is created below; "avg" is the only one defined on corners.
interpolations = [("conserv", Location.CELL),
                  ("2nd_conserv", Location.CELL),
                  ("avg", Location.CORNER),
                  ("hcsbb", Location.CELL),
                  ("rbf", Location.CELL)]
92
# NOTE(review): these paired COMM_WORLD barriers are collective and presumably
# mirror barriers in the other programs of this MPI job; keep count and order
# in sync with them.
MPI.COMM_WORLD.Barrier()
MPI.COMM_WORLD.Barrier()

# One outgoing field per (component, interpolation) pair, keyed by field name.
# "avg" fields live on the vertices, all others on the cell centres.
# NOTE(review): the args 1 and "2" are presumably collection size and timestep
# (in TimeUnit.SECOND); confirm against Field.create.
field = {f"{interp[0]}_{comp_name}_field_out":
         Field.create(f"{interp[0]}_{comp_name}_field_out", comp,
                      points_vertex if interp[1] == Location.CORNER else points_cell,
                      1, "2", TimeUnit.SECOND)
         for comp_name, interp in product(yac.component_names, interpolations)}

MPI.COMM_WORLD.Barrier()
MPI.COMM_WORLD.Barrier()
# Close the YAC definition phase; the put/get exchange follows afterwards.
yac.enddef()
MPI.COMM_WORLD.Barrier()
106
107
def data(lon, lat):
    """Return the analytic test field ``sin(8*lon) * cos(12*lat)``.

    Evaluated element-wise and flattened to a 1-D array in row-major order.
    """
    values = np.sin(8 * lon) * np.cos(12 * lat)
    return values.flatten()
110
111
# Sample the analytic test field at this process's vertices and cell centres.
data_vertex = data(vx, vy)
data_cell = data(cx, cy)

MPI.COMM_WORLD.Barrier()
# Send: one put per interpolation on this component's own fields ("avg" on
# vertices, all others on cell centres).
for interp in interpolations:
    field[f"{interp[0]}_{my_comp_name}_field_out"].put(
        data_vertex if interp[1] == Location.CORNER else data_cell)

# Receive the matching fields of every other component. The result gets its
# own name: rebinding ``data`` (as before) shadowed the data() function.
for comp_name, interp in product(yac.component_names, interpolations):
    if comp_name == my_comp_name:
        continue
    received = field[f"{interp[0]}_{comp_name}_field_out"].get()
MPI.COMM_WORLD.Barrier()
A curvilinear structured 2D grid.
Definition yac.pyx:1398
Initializes a YAC instance and provides further functionality.
Definition yac.pyx:693
rotmat(theta, plane_dims, d=2)