YAC 3.13.0
Yet Another Coupler
Loading...
Searching...
No Matches
yac_generate_weights.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2025 The YAC Authors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6
7
10
try:
    # Importing mpi4py.MPI calls MPI_Init (with mpi4py's own special
    # treatments) before yac.core is imported below. mpi4py is optional:
    # when it is not installed, carry on without it. Only ImportError is
    # caught so that e.g. KeyboardInterrupt/SystemExit are not swallowed.
    from mpi4py import MPI
except ImportError:
    pass
15
16from yac.core import (BasicGrid,
17 yac_location,
18 InterpField,
19 InterpolationStack,
20 compute_weights,
21 )
22
23import uxarray as ux
24import numpy as np
25from argparse import ArgumentParser
26import json
27from pathlib import Path
28import logging
29
30
def parse_interp_method(s):
    """Parse one interpolation-stack entry given on the command line.

    The expected format is ``method_name[,JSON mapping]``, e.g. ``nnn`` or
    ``nnn,{"n": 4}``. Only the first comma separates the method name from
    its arguments, so the JSON mapping may itself contain commas (i.e. it
    may have more than one key).

    Args:
        s: The raw command-line string.

    Returns:
        A ``(method_name, args)`` tuple; ``args`` is the decoded mapping,
        or an empty dict when no mapping was given.

    Raises:
        argparse.ArgumentTypeError: If the part after the first comma is
            not valid JSON or does not decode to a JSON object. Raising this
            (instead of asserting) lets argparse report a usage error rather
            than a traceback, and is not stripped under ``python -O``.
    """
    from argparse import ArgumentTypeError

    method, sep, arg_spec = s.partition(",")
    if not sep:
        return s, {}
    try:
        args = json.loads(arg_spec)
    except json.JSONDecodeError as err:
        raise ArgumentTypeError(
            f"Invalid JSON arguments for interpolation {method!r}: {err}"
        ) from err
    if not isinstance(args, dict):
        raise ArgumentTypeError("Invalid argument spec for interpolation")
    return method, args
39
40
def get_coords(loc, uxgrid, yac_grid):
    """Register the coordinates for location ``loc`` with ``yac_grid``.

    Args:
        loc: The ``yac_location`` at which the field is defined. Only
            ``YAC_LOC_CELL``, ``YAC_LOC_CORNER`` and ``YAC_LOC_EDGE`` are
            supported.
        uxgrid: The uxarray grid that ``yac_grid`` was built from. Its face
            and edge longitudes/latitudes are read in degrees and converted
            to radians here.
        yac_grid: The YAC ``BasicGrid`` the coordinates are added to.

    Returns:
        The result of ``yac_grid.add_coordinates`` for ``loc``.

    Raises:
        ValueError: If ``loc`` is not one of the supported locations
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if loc == yac_location.YAC_LOC_CELL:
        lon, lat = np.deg2rad([uxgrid.face_lon, uxgrid.face_lat])
    elif loc == yac_location.YAC_LOC_CORNER:
        # No explicit coordinates are passed for corners; presumably
        # add_coordinates then falls back to the grid's own vertices
        # — TODO confirm against the yac.core documentation.
        lon = lat = None
    elif loc == yac_location.YAC_LOC_EDGE:
        lon, lat = np.deg2rad([uxgrid.edge_lon, uxgrid.edge_lat])
    else:
        raise ValueError(f"Unsupported location: {loc!r}")
    return yac_grid.add_coordinates(loc, lon, lat)
49
50
# Render yac_location members by name (e.g. YAC_LOC_CELL); presumably this is
# so argparse lists readable location choices in its --help output.
yac_location.__str__ = lambda loc: loc.name

# Locations for which get_coords can provide coordinates.
_SUPPORTED_LOCATIONS = (yac_location.YAC_LOC_CELL,
                        yac_location.YAC_LOC_CORNER,
                        yac_location.YAC_LOC_EDGE)


def _parse_location(name):
    """Convert a ``yac_location`` member name into the enum member.

    Unknown names raise ``argparse.ArgumentTypeError`` instead of the
    ``KeyError`` from the enum lookup, so argparse reports a usage error
    rather than a traceback.
    """
    from argparse import ArgumentTypeError
    try:
        return yac_location[name]
    except KeyError:
        raise ArgumentTypeError(f"unknown location {name!r}") from None


def _parse_log_level(name):
    """Convert a logging level name (case-insensitive) into its number.

    Accepts the standard names such as DEBUG, info, WARN. Unknown names
    raise ``argparse.ArgumentTypeError``. Looking the name up with
    ``getattr(logging, name)`` would either raise an uncaught
    ``AttributeError`` or, for lowercase names, return a logging *function*
    that later crashes ``logging.basicConfig``.
    """
    from argparse import ArgumentTypeError
    level = logging.getLevelName(name.upper())
    if not isinstance(level, int):
        raise ArgumentTypeError(f"unknown log level {name!r}")
    return level


def _add_grid_arguments(parser, side):
    """Add the grid path, location and mask options for ``side`` ("src"/"tgt")."""
    parser.add_argument(f"{side}_grid_path", type=str)
    parser.add_argument(f"--{side}-location",
                        type=_parse_location,
                        choices=_SUPPORTED_LOCATIONS,
                        default=yac_location.YAC_LOC_CELL,
                        help=f"Location at which the {side} field is defined.")
    parser.add_argument(f"--{side}-mask-var", type=str,
                        help="Variable name of the mask")
    parser.add_argument(f"--{side}-mask-filename", type=str,
                        help="Filename to read the mask from "
                             "(defaults to the grid file).")
    parser.add_argument(f"--{side}-mask-valid-values", type=int, nargs="*",
                        default=[1],
                        help="Valid values for the mask")
    parser.add_argument(f"--{side}-grid-name", type=str,
                        default=f"{side}_grid")


parser = ArgumentParser(prog="gen_weights.py",
                        description="Generate YAC weights files")
# Positional order stays: src_grid_path, tgt_grid_path, interpolation_stack.
_add_grid_arguments(parser, "src")
_add_grid_arguments(parser, "tgt")
parser.add_argument("interpolation_stack",
                    type=parse_interp_method,
                    nargs="+",
                    help="format: method_name[,JSON mapping]")
parser.add_argument("--output-file", type=Path,
                    default="weights.nc")
parser.add_argument("--log-level",
                    default=logging.INFO,
                    type=_parse_log_level,
                    help="Configure the logging level.",
                    )

args = parser.parse_args()

logging.basicConfig(level=args.log_level)
99
logging.info("Reading src grid")
src_uxgrid = ux.open_grid(args.src_grid_path)
src_grid = BasicGrid.from_uxgrid(args.src_grid_name, src_uxgrid)
src_coords = get_coords(args.src_location, src_uxgrid, src_grid)

# Optional source mask: entries whose value appears in
# --src-mask-valid-values become True (presumably "valid" for YAC).
src_mask = None
if args.src_mask_var is not None:
    # Read the mask variable from the grid file itself unless a separate
    # file was given via --src-mask-filename.
    src_mask_file = args.src_mask_filename or args.src_grid_path
    with ux.open_dataset(args.src_grid_path, src_mask_file) as ds:
        m = np.isin(ds[args.src_mask_var], args.src_mask_valid_values)
        src_mask = src_grid.add_mask(args.src_location, m)

src_field = InterpField(src_coords, src_mask)
112
logging.info("Reading tgt grid")
tgt_uxgrid = ux.open_grid(args.tgt_grid_path)
tgt_grid = BasicGrid.from_uxgrid(args.tgt_grid_name, tgt_uxgrid)
tgt_coords = get_coords(args.tgt_location, tgt_uxgrid, tgt_grid)

# Optional target mask: entries whose value appears in
# --tgt-mask-valid-values become True (presumably "valid" for YAC).
tgt_mask = None
if args.tgt_mask_var is not None:
    # Read the mask variable from the grid file itself unless a separate
    # file was given via --tgt-mask-filename.
    tgt_mask_file = args.tgt_mask_filename or args.tgt_grid_path
    with ux.open_dataset(args.tgt_grid_path, tgt_mask_file) as ds:
        m = np.isin(ds[args.tgt_mask_var], args.tgt_mask_valid_values)
        tgt_mask = tgt_grid.add_mask(args.tgt_location, m)

tgt_field = InterpField(tgt_coords, tgt_mask)
125
# Build the interpolation stack from the (method_name, args) pairs produced
# by parse_interp_method for each command-line entry.
interp_stack = InterpolationStack.from_list(args.interpolation_stack)

logging.info("Computing weights")
weights = compute_weights(interp_stack, src_field, tgt_field)

logging.info("Write weights to file")
# --output-file is parsed as a Path; write_to_file is handed a plain string.
weights.write_to_file(str(args.output_file))
get_coords(loc, uxgrid, yac_grid)