YAC 3.12.0
Yet Another Coupler
Loading...
Searching...
No Matches
yac_replay.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2# Copyright (c) 2024 The YAC Authors
3#
4# SPDX-License-Identifier: BSD-3-Clause
5
6import numpy as np
7import uxarray as ux
8import xarray as xr
9import yac
10import argparse
11import logging
12from mpi4py import MPI # < some MPI impl. need this to be imported explicitly
13
14
def _parse_log_level(name):
    """Translate a logging level name (case-insensitive) into its numeric value.

    Raises argparse.ArgumentTypeError for unknown names so that argparse
    reports a proper usage error instead of an AttributeError traceback.
    """
    level = getattr(logging, name.upper(), None)
    # Only the integer level constants (DEBUG, INFO, WARNING, ...) are valid;
    # this also rejects unrelated module attributes such as BASIC_FORMAT.
    if not isinstance(level, int):
        raise argparse.ArgumentTypeError(f"unknown log level: {name!r}")
    return level


parser = argparse.ArgumentParser("yac_replay",
                                 description="""Replay simulation from a dataset.
The given files are loaded with `uxarray.open_mfdataset`.
Running this component in parallel is currently not supported.""")
# Positional arguments are required, so `default=` would be ignored here;
# the descriptive text belongs in `help=`.
parser.add_argument("gridfile", type=str,
                    help="path to the gridfile")
parser.add_argument("datafile", type=str, nargs='+',
                    help="path(s) to the dataset file(s)")
parser.add_argument("--compname", type=str, default="replay",
                    help="Name for the yac component (default: replay)")
parser.add_argument("--gridname", type=str, default="replay_grid",
                    help="Name for the yac grid (default: replay_grid)")
parser.add_argument("--log-level", default=logging.WARNING, type=_parse_log_level,
                    help="Configure the logging level, e.g. DEBUG, INFO or WARNING (default: WARNING).")
parser.add_argument("--coupling-config", type=str,
                    help="If given the yac config is read with yac.read_config_yaml")
parser.add_argument("--cell-mask", type=str,
                    help="Expression for the cell mask used when defining the fields. All variables in the grid file can be used.")
parser.add_argument("--edge-mask", type=str,
                    help="Expression for the edge mask used when defining the fields. All variables in the grid file can be used.")
parser.add_argument("--vertex-mask", type=str,
                    help="Expression for the vertex mask used when defining the fields. All variables in the grid file can be used.")

args = parser.parse_args()
39
# Send all log records through a single StreamHandler whose format prefixes
# every message with the timestamp, logger name and level.
log_handler = logging.StreamHandler()
log_handler.setFormatter(
    logging.Formatter(fmt="%(asctime)-10s %(name)-10s %(levelname)-8s %(message)s"))
logging.basicConfig(handlers=[log_handler], level=args.log_level)
43
# Create the YAC instance; every later YAC call (config, component, enddef,
# start date) goes through `y`.
logging.info("Instantiate yac")
y = yac.YAC()

# Optional coupling configuration from a YAML file.
if args.coupling_config:
    y.read_config_yaml(args.coupling_config)

# Grid and data files are combined into one uxarray dataset.
logging.info(f"open dataset: {args.gridfile=}, {args.datafile=}")
ds = ux.open_mfdataset(args.gridfile, args.datafile)

logging.info(f"define component: {args.compname}")
comp = y.def_comp(args.compname)
55
# (yac.Field, data variable) pairs of every variable that is replayed.
fields = []

# Flatten the padded face->edge table in row-major order and drop the fill
# entries, so the edges of all cells follow each other in one array; the
# number of entries per cell is passed separately as n_nodes_per_face.
cell_to_edge = np.asarray(ds.uxgrid.face_edge_connectivity).ravel()
cell_to_edge = cell_to_edge[cell_to_edge != ux.INT_FILL_VALUE]

# uxarray stores coordinates in degrees; YAC receives radians.
grid = yac.UnstructuredGridEdge(
    args.gridname,
    ds.uxgrid.n_nodes_per_face,
    np.radians(ds.uxgrid.node_lon),
    np.radians(ds.uxgrid.node_lat),
    cell_to_edge,
    ds.uxgrid.edge_node_connectivity,
)

# One point set per location: cell centers, edge centers and the vertices.
cell_point_id, edge_point_id, vertex_point_id = (
    grid.def_points(location, np.radians(lon), np.radians(lat))
    for location, lon, lat in (
        (yac.Location.CELL, ds.uxgrid.face_lon, ds.uxgrid.face_lat),
        (yac.Location.EDGE, ds.uxgrid.edge_lon, ds.uxgrid.edge_lat),
        (yac.Location.CORNER, ds.uxgrid.node_lon, ds.uxgrid.node_lat),
    )
)
76
# define the masks
# Mask expressions are evaluated against the variables of the grid file as
# opened by plain xarray.
grid_ds = xr.open_dataset(args.gridfile)


def _define_mask(expression, location, mask_name):
    """Evaluate a mask expression and register the result on ``grid``.

    Parameters
    ----------
    expression : str or None
        Python expression given on the command line; the grid file variables
        are available as local names. ``None`` means "no mask".
    location : yac.Location
        Location the mask is defined on.
    mask_name : str
        Name under which the mask is registered.

    Returns
    -------
    The value returned by ``grid.def_mask``, or ``None`` if no expression
    was given.
    """
    if expression is None:
        return None
    # SECURITY NOTE: eval() executes arbitrary Python code. The expression
    # comes from the command line of whoever starts the replay; never pass
    # untrusted input to the --*-mask options.
    # globals/locals are passed positionally: eval() only accepts them as
    # keyword arguments since Python 3.13 (TypeError on older versions).
    mask_values = eval(expression, globals(), grid_ds.variables)
    return grid.def_mask(location, mask_values, mask_name)


cell_mask = _define_mask(args.cell_mask, yac.Location.CELL, "replay_cell_mask")
edge_mask = _define_mask(args.edge_mask, yac.Location.EDGE, "replay_edge_mask")
vertex_mask = _define_mask(args.vertex_mask, yac.Location.CORNER, "replay_vertex_mask")
97
# Derive the uniform time step of the dataset. It becomes the timestep of
# every YAC field, expressed as an integer number of milliseconds, so the
# time axis must have at least two entries, be strictly increasing (time
# stamps are later looked up with np.searchsorted) and have a spacing that
# is a whole number of milliseconds.
# Explicit exceptions instead of `assert`, which is stripped under `python -O`.
dt = np.diff(ds.coords["time"])
if dt.size == 0:
    raise ValueError("At least two time coordinates are required to derive a time step")
if not np.all(dt[0] == dt):
    raise ValueError("Time coordinates are not equidistant")
dt = dt[0]
if dt <= np.timedelta64(0):
    raise ValueError("Time coordinates are not strictly increasing")
if dt % np.timedelta64(1, "ms") != np.timedelta64(0):
    raise ValueError("Time step is not a whole number of milliseconds")
101
# Create one YAC field for every dataset variable that can be replayed: it
# must have a "time" dimension, exactly one spatial dimension (n_face, n_edge
# or n_node) and at most one additional dimension, whose length becomes the
# collection size. Each field is stored together with its data, transposed
# to (time, [additional dimension], spatial dimension).
for varname, var in ds.variables.items():
    logging.info(f"checking variable {varname}")
    var_dims = set(var.dims)

    if "time" not in var_dims:
        logging.info(f"{varname} was skipped: lacking time coordinate")
        continue

    spatial_dims = var_dims & {"n_face", "n_edge", "n_node"}
    if len(spatial_dims) != 1:
        logging.info(
            f"{varname} was skipped: exactly one of n_face, n_edge or n_node must be a coordinate.")
        continue

    # everything besides time and the (single) spatial dimension
    extra_dims = var_dims - spatial_dims - {"time"}
    if len(extra_dims) > 1:
        logging.info(f"{varname} was skipped: too many coordinates")
        continue

    if extra_dims:
        (extra_dim,) = extra_dims
        collection_size = np.asarray(ds[extra_dim]).shape[0]
    else:
        collection_size = 1

    (spatial_dim,) = spatial_dims
    if spatial_dim == "n_face":
        point_id, mask = cell_point_id, cell_mask
    elif spatial_dim == "n_edge":
        point_id, mask = edge_point_id, edge_mask
    else:  # "n_node"
        point_id, mask = vertex_point_id, vertex_mask

    # dataset time step in milliseconds, truncated to an integer
    timestep_ms = str(int(dt / np.timedelta64(1, 'ms')))
    field = yac.Field.create(varname, comp, point_id, collection_size,
                             timestep_ms, yac.TimeUnit.MILLISECOND, mask)
    fields.append((field, var.transpose("time", *extra_dims, spatial_dim)))
134
logging.info("calling enddef")
y.enddef()

# After enddef only fields configured as sources need to be served.
fields = [(field, var) for field, var in fields if field.role == yac.ExchangeType.SOURCE]
logging.info(f"{len(fields)} fields are coupled")

# The YAC start date must be one of the dataset's time stamps. searchsorted
# returns len(time) when the start lies after the last time stamp, so the
# index is bounds-checked before it is used. An explicit exception replaces
# the former `assert`, which is stripped under `python -O`.
_time_values = np.asarray(ds.coords["time"])
_start = np.datetime64(y.start_datetime)
t0 = np.searchsorted(_time_values, _start)
if t0 >= _time_values.size or _time_values[t0] != _start:
    raise ValueError("starttime not in time axis")
143
# Serve every coupled field, one time step per sweep, until YAC signals the
# last put of the run (PUT_FOR_RESTART). All fields were created with the
# same time step, so they are expected to reach the end of the run in the
# same sweep; that sweep is completed so every field gets its final put, and
# then the replay stops. (A bare `break` inside the `for` would only leave
# the inner loop and keep `while True` running forever.)
time_values = np.asarray(ds["time"])
finished = not fields  # nothing coupled -> do not spin forever
while not finished:
    for field, var in fields:
        logging.info("processing %s at %s", field.name, field.datetime)
        t = np.datetime64(field.datetime)
        t_idx = np.searchsorted(time_values, t)
        if t_idx >= time_values.size or time_values[t_idx] != t:
            raise RuntimeError(
                f"time {field.datetime} of field {field.name} is not in the time axis of the dataset")
        # .values always yields a NumPy array, whereas .data may be a lazy
        # dask array for datasets opened with open_mfdataset.
        info = field.put(var.isel(time=t_idx).values)
        if info == yac.Action.PUT_FOR_RESTART:
            finished = True

logging.info("done")
create(cls, str field_name, Component comp, points, collection_size, str timestep, TimeUnit timeunit, masks=None)
Definition yac.pyx:1570
Initializes a YAC instance and provides further functionality.
Definition yac.pyx:693