YAC 3.8.0
Yet Another Coupler
Loading...
Searching...
No Matches
plot_weights.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2# Copyright (c) 2024 The YAC Authors
3#
4# SPDX-License-Identifier: BSD-3-Clause
5
6from netCDF4 import Dataset
7import matplotlib
8import matplotlib.pyplot as plt
9import cartopy.crs as ccrs
10import cartopy.feature as cfeature
11import numpy as np
12from grid_utils import get_grid
13import logging
14from logging import info
15
16
# Names of the grid attributes holding the longitude/latitude coordinates of
# the points at each YAC point location.
_LOCATION_COORDS = {
    "CELL": ("clon", "clat"),
    "CORNER": ("vlon", "vlat"),
    "EDGE": ("elon", "elat"),
}


def _location_coord_names(locstr):
    """Return the (lon, lat) grid attribute names for ``locstr``.

    Raises ValueError for a location string that is not in _LOCATION_COORDS.
    """
    try:
        return _LOCATION_COORDS[locstr]
    except KeyError:
        raise ValueError(f"Unknown location string {locstr}") from None


def _location_points(grid, locstr):
    """Stack the lon and lat coordinates of ``grid``'s points at ``locstr`` as two rows."""
    lon_name, lat_name = _location_coord_names(locstr)
    return np.stack([getattr(grid, lon_name), getattr(grid, lat_name)])


def _decode_location(chars):
    """Decode a NUL-padded character variable read from the weight file into a str."""
    return bytes(np.array(chars)).decode().rstrip("\0")


def _inside_extent(points, extent):
    """Return a boolean mask of the projected ``points`` lying inside ``extent``.

    ``points`` holds projected x in column 0 and y in column 1 (as returned by
    ``transform_points``); ``extent`` is ``(x0, x1, y0, y1)`` as returned by
    ``ax.get_extent()``.
    """
    return ((points[:, 0] >= extent[0]) & (points[:, 0] <= extent[1]) &
            (points[:, 1] >= extent[2]) & (points[:, 1] <= extent[3]))


def plot_weights(ax, weightsfile, src_grid, tgt_grid,
                 src_idx=None, tgt_idx=None, zoom=1, quiver_kwargs=None):
    """Draw the links of a YAC weight file as arrows from source to target points.

    Every link ``src_address -> dst_address`` is drawn as an arrow on ``ax``
    coloured by the first column of ``remap_matrix``. If ``src_idx`` or
    ``tgt_idx`` is given, only links touching that index are drawn, the axes
    extent is fitted around them, the weights are written next to the arrows
    and a hover pop-up showing the link is installed. Otherwise all links with
    at least one end inside the current axes extent are drawn.

    Parameters
    ----------
    ax :
        Cartopy axes; ``ax.projection`` is used to transform lon/lat points.
    weightsfile : str
        Path of the YAC weight file (netCDF).
    src_grid, tgt_grid :
        Grid objects providing the coordinate attributes named in
        ``_LOCATION_COORDS`` and an ``idx_offset``.
    src_idx, tgt_idx : int, optional
        Restrict the plot to links with this source / target index
        (given including the grid's ``idx_offset``).
    zoom : float
        Scale factor applied to the fitted extent when restricting the plot.
    quiver_kwargs : dict, optional
        Extra keyword arguments forwarded to ``ax.quiver``.

    Returns
    -------
    (dict, dict)
        ``{"<location>_idx": addresses}`` for source and target, holding the
        offset-corrected addresses of the plotted links.

    Raises
    ------
    ValueError
        If the weight file names an unknown point location.
    Exception
        If the weight file references indices beyond the grids, or if no
        link remains after restricting to ``src_idx`` / ``tgt_idx``.
    """
    if quiver_kwargs is None:
        quiver_kwargs = {}

    yac_weights_format_address_offset = 1
    # Read everything needed from the weight file up front so the file is
    # closed before drawing (the hover callback installed below outlives
    # this call and must not keep it open).
    with Dataset(weightsfile, "r") as weights:
        if 'version' in weights.__dict__:
            info(f"{weights.version=}")
            if weights.version != "yac weight file 1.0":
                print("WARNING: You are using an incompatible weight file version\n" +
                      weights.version + " != yac weight file 1.0")
        # NOTE(review): num_src_fields may be stored as a dimension rather than
        # a variable, in which case this always falls back to 1 — confirm.
        num_src_fields = weights.num_src_fields if "num_src_fields" in weights.variables else 1
        info(f"{num_src_fields=}")
        # Every source field's location is validated; the last one decides
        # which source points are plotted.
        src_locstr = "CELL"
        for src_field in range(num_src_fields):
            if "src_locations" in weights.variables:
                src_locstr = _decode_location(weights["src_locations"][src_field, :])
            else:
                src_locstr = "CELL"
            info(f"source loc: {src_locstr}")
            _location_coord_names(src_locstr)

        # Check for the same variable name that is actually read.
        if "dst_location" in weights.variables:
            tgt_locstr = _decode_location(weights["dst_location"])
        else:
            tgt_locstr = "CELL"
        info(f"target loc: {tgt_locstr}")

        src_adr = (np.asarray(weights["src_address"])
                   - yac_weights_format_address_offset - src_grid.idx_offset)
        tgt_adr = (np.asarray(weights["dst_address"])
                   - yac_weights_format_address_offset - tgt_grid.idx_offset)
        # First weight column of every link; masked together with the addresses.
        link_weights = np.ma.asarray(weights["remap_matrix"][:, 0])

    src_points = _location_points(src_grid, src_locstr)
    tgt_points = _location_points(tgt_grid, tgt_locstr)
    n_src = src_points.shape[1]
    n_tgt = tgt_points.shape[1]

    info(f"Source indices in weight file range from {np.min(src_adr)} to {np.max(src_adr)}")
    info(f"Number of source points is {n_src}")
    info(f"Target indices in weight file range from {np.min(tgt_adr)} to {np.max(tgt_adr)}")
    info(f"Number of target points is {n_tgt}")

    # Valid array indices are 0 .. n-1, so an index equal to n is already too large.
    if np.max(src_adr) >= n_src:
        raise Exception(f"Source grid too small. Max index in weight file is {np.max(src_adr)}, number of source points are {n_src}")
    if np.max(tgt_adr) >= n_tgt:
        raise Exception(f"Target grid too small. Max index in weight file is {np.max(tgt_adr)}, number of target points are {n_tgt}")

    # Remove redundant links as in SCRIP bilinear and bicubic weights
    mask = src_adr >= 0

    # Restrain plot for targeted cells
    if src_idx is not None:
        mask &= src_adr == (src_idx - src_grid.idx_offset)
    if tgt_idx is not None:
        mask &= tgt_adr == (tgt_idx - tgt_grid.idx_offset)

    if not np.any(mask):
        raise Exception("All points are mask.")

    geodetic = ccrs.PlateCarree()
    focused = src_idx is not None or tgt_idx is not None
    if focused:
        info(f"Mask: {np.count_nonzero(mask)}/{len(mask)}")
        src_adr = src_adr[mask]
        tgt_adr = tgt_adr[mask]
        # transform to the projection that is used
        src_t = ax.projection.transform_points(geodetic,
                                               src_points[0, src_adr], src_points[1, src_adr])
        info(f"src: {src_points[0, src_adr], src_points[1, src_adr]} -> {src_t[:, 0], src_t[:, 1]}")
        tgt_t = ax.projection.transform_points(geodetic,
                                               tgt_points[0, tgt_adr], tgt_points[1, tgt_adr])
        info(f"tgt: {tgt_points[0, tgt_adr], tgt_points[1, tgt_adr]} -> {tgt_t[:, 0], tgt_t[:, 1]}")
        # Fit the view around all link end points. nanmin/nanmax keep any NaN
        # coordinates coming out of the transform from spoiling the extent.
        both = np.concatenate([src_t, tgt_t])
        e = np.array([np.nanmin(both[:, 0]), np.nanmax(both[:, 0]),
                      np.nanmin(both[:, 1]), np.nanmax(both[:, 1])])
        center = np.array([0.5*(e[0]+e[1]), 0.5*(e[0]+e[1]),
                           0.5*(e[2]+e[3]), 0.5*(e[2]+e[3])])
        extent = (e - center)*1.15*zoom + center
        info(f"c={center!r}")
        info(f"{e=}")
        info(f"{extent=}")
        ax.set_extent(extent, crs=ax.projection)
    else:
        # transform to the projection that is used
        src_t = ax.projection.transform_points(geodetic,
                                               src_points[0, src_adr], src_points[1, src_adr])
        tgt_t = ax.projection.transform_points(geodetic,
                                               tgt_points[0, tgt_adr], tgt_points[1, tgt_adr])

        extent = ax.get_extent()
        info(f"Extent: {extent}")
        # Keep only links with at least one end inside the visible extent.
        mask &= _inside_extent(src_t, extent) | _inside_extent(tgt_t, extent)
        info(f"Mask: {np.count_nonzero(mask)}/{len(mask)}")

        src_adr = src_adr[mask]
        tgt_adr = tgt_adr[mask]
        src_t = src_t[mask]
        tgt_t = tgt_t[mask]

    n_links = np.count_nonzero(mask)
    if n_links > 3000:
        logging.warning(f"Trying to display a lot of points ({n_links}). "
                        "This may take some time. "
                        "Consider a smaller --zoom parameter to reduce the number of points")

    uv = tgt_t - src_t

    c = link_weights[mask]

    norm = matplotlib.colors.Normalize()
    cm = matplotlib.cm.Oranges  # decide for colormap
    ax.quiver(src_t[:, 0], src_t[:, 1],
              uv[:, 0], uv[:, 1], angles='xy', scale_units='xy', scale=1,
              color=cm(norm(c)),
              **quiver_kwargs)
    sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
    sm.set_array([])
    clb = plt.colorbar(sm, ax=ax)
    clb.ax.zorder = -1

    if src_idx is not None:
        ax.set_title(f"Interpolation for source index {src_idx}\nSum of weights: {c.sum():.8f}")
    if tgt_idx is not None:
        ax.set_title(f"Interpolation for target index {tgt_idx}\nSum of weights: {c.sum():.8f}")

    # add labels and hover pop ups if restricted to one cell
    if focused:
        for x, y, t in zip(src_t[:, 0] + 0.5*uv[:, 0],
                           src_t[:, 1] + 0.5*uv[:, 1], c):
            ax.text(x, y, f"{t:.3}",
                    horizontalalignment='center', verticalalignment='center')

        # Invisible lines along each link serve as hit targets for the pop-up.
        wlines = ax.plot([src_t[:, 0], tgt_t[:, 0]],
                         [src_t[:, 1], tgt_t[:, 1]],
                         color='white', alpha=0.)
        wlabel = np.vstack([src_adr, tgt_adr, c.data])
        annotation = ax.annotate(text='', xy=(0, 0), xytext=(15, 15),
                                 textcoords='offset points',
                                 bbox={'boxstyle': 'round', 'fc': 'w'},
                                 arrowprops={'arrowstyle': '->'},
                                 zorder=9999)
        # Must be the boolean False: the string 'False' is truthy.
        annotation.set_visible(False)

        def motion_hover(event):
            """Show the link under the mouse pointer, hide the pop-up otherwise."""
            if event.inaxes != ax:
                return
            hit = next((i for i, wl in enumerate(wlines) if wl.contains(event)[0]), None)
            if hit is not None:
                annotation.xy = (event.xdata, event.ydata)
                text_label = f"src: {wlabel[0, hit] + src_grid.idx_offset:.0f} tgt: {wlabel[1, hit] + tgt_grid.idx_offset:.0f}\nweight: {wlabel[2, hit]:.3f}"
                annotation.set_text(text_label)
                annotation.set_visible(True)
                ax.figure.canvas.draw_idle()
            elif annotation.get_visible():
                annotation.set_visible(False)
                ax.figure.canvas.draw_idle()

        ax.figure.canvas.mpl_connect('motion_notify_event', motion_hover)

    src_mask = {f"{src_locstr.lower()}_idx": src_adr}
    tgt_mask = {f"{tgt_locstr.lower()}_idx": tgt_adr}

    return src_mask, tgt_mask
187
188
def cell_extent(grid, idx, zoom=4):
    """Return ``[lon_min, lon_max, lat_min, lat_max]`` of a cell's vertices, scaled by ``zoom``.

    The bounding box of the vertices of cell ``idx`` (given including
    ``grid.idx_offset``) is stretched by ``zoom`` about its own midpoint.
    The cell's center coordinates are printed as a side effect.
    """
    local = idx - grid.idx_offset
    print(f"cell {idx}: ", grid.clon[local], grid.clat[local])

    corners = grid.vertex_of_cell[:, local]
    lons = grid.vlon[corners]
    lats = grid.vlat[corners]
    bounds = np.array([np.min(lons), np.max(lons),
                       np.min(lats), np.max(lats)])

    mid_lon = 0.5 * (bounds[0] + bounds[1])
    mid_lat = 0.5 * (bounds[2] + bounds[3])
    midpoint = np.array([mid_lon, mid_lon, mid_lat, mid_lat])
    return (bounds - midpoint) * zoom + midpoint
198
199
def main(source_grid, target_grid, weights_file, center=None,
         source_idx=None, target_idx=None, zoom=1,
         label_src_grid=None, label_tgt_grid=None,
         coast_res="50m", projection="orthographic",
         stencil_only=False,
         save_as=None, log_level=logging.INFO):
    """Plot a source grid, an optional target grid and optional YAC weights.

    Parameters
    ----------
    source_grid, target_grid :
        Grid specifications passed to ``get_grid``; ``target_grid`` may be None.
    weights_file : str or None
        YAC weight file drawn with ``plot_weights``.
    center : sequence of two floats, optional
        (lon, lat) center of the projection, default (0, 0). If
        ``source_idx`` or ``target_idx`` is given, the center of that cell is
        used instead.
    source_idx, target_idx : int, optional
        Cell to focus on (index including the grid's ``idx_offset``).
    zoom : float
        Zoom factor for the displayed extent.
    label_src_grid, label_tgt_grid : str, optional
        Forwarded as ``label`` to the grids' ``plot`` methods.
    coast_res : str or None
        Natural Earth scale of the land feature; a falsy value disables it.
    projection : str or None
        "orthographic", "stereographic" or "platecarree". ``None`` and the
        legacy spelling "orthografic" select "orthographic".
    stencil_only : bool
        Forward the addresses of the plotted links to the grids' ``plot``
        methods (presumably restricting them to the stencil — see grid_utils).
    save_as : str or file object, optional
        Passed to ``plt.savefig``; if not given the figure is shown.
    log_level : int
        Level passed to ``logging.basicConfig``.

    Raises
    ------
    ValueError
        For an unknown ``projection`` or a ``target_idx`` without target grid.
    """
    logging.basicConfig(level=log_level)
    src_grid = get_grid(source_grid)
    tgt_grid = get_grid(target_grid) if target_grid is not None else None
    fig = plt.figure(figsize=[10, 10])

    if center is None:
        center = [0, 0]
    # A focused cell decides the projection center.
    if source_idx is not None:
        center = [src_grid.clon[source_idx-src_grid.idx_offset],
                  src_grid.clat[source_idx-src_grid.idx_offset]]
    elif target_idx is not None:
        if tgt_grid is None:
            raise ValueError("target_idx requires a target grid")
        center = [tgt_grid.clon[target_idx-tgt_grid.idx_offset],
                  tgt_grid.clat[target_idx-tgt_grid.idx_offset]]

    # Without a fallback, an unmatched name (including the old default
    # spelling) would leave `proj` unbound.
    if projection in (None, "orthographic", "orthografic"):
        proj = ccrs.Orthographic(*center)
    elif projection == "stereographic":
        # Stereographic takes (central_latitude, central_longitude).
        proj = ccrs.Stereographic(*center[::-1])
    elif projection == "platecarree":
        proj = ccrs.PlateCarree(center[0])
    else:
        raise ValueError(f"Unknown projection {projection!r}")
    info(f"{center=}")
    ax = fig.add_subplot(1, 1, 1, projection=proj)

    if weights_file is None:
        if source_idx is not None:
            extent = cell_extent(src_grid, source_idx, zoom)
            ax.set_extent(extent, crs=ccrs.PlateCarree())
        elif target_idx is not None:
            extent = cell_extent(tgt_grid, target_idx, zoom)
            ax.set_extent(extent, crs=ccrs.PlateCarree())
    if source_idx is None and target_idx is None:
        # NOTE(review): center is lon/lat while this extent is given in the
        # projection's own coordinates (crs=proj) — confirm this is intended.
        ax.set_extent([center[0]-1000000*zoom, center[0]+1000000*zoom,
                       center[1]-1000000*zoom, center[1]+1000000*zoom], crs=proj)

    # Put a background image on for nice sea rendering.
    ax.set_facecolor("#9ddbff")
    if coast_res:
        feature = cfeature.NaturalEarthFeature(name='land',
                                               category='physical',
                                               scale=coast_res,
                                               edgecolor='#000000',
                                               facecolor='#cbe7be')
        ax.add_feature(feature, zorder=-999)
    glin = ax.gridlines(draw_labels=True, alpha=0)
    if source_idx is not None or target_idx is not None:
        glin.top_labels = False

    # Empty unless weights are plotted, so stencil_only without a weights
    # file simply plots the full grids.
    src_mask, tgt_mask = {}, {}
    if weights_file is not None:
        src_mask, tgt_mask = plot_weights(ax, weights_file, src_grid, tgt_grid,
                                          source_idx, target_idx, zoom,
                                          quiver_kwargs={"zorder": 2})

    # Plot grids
    if src_grid is not None:
        gridname = src_grid.gridname if hasattr(src_grid, 'gridname') else type(src_grid).__name__
        src_grid.plot(ax, label=label_src_grid,
                      plot_kwargs={"color": "green",
                                   "zorder": 1,
                                   "linewidth": 3,
                                   "label": f"Source grid: {gridname}"},
                      key="S",
                      **(src_mask if stencil_only else {}))
    if tgt_grid is not None:
        gridname = tgt_grid.gridname if hasattr(tgt_grid, 'gridname') else type(tgt_grid).__name__
        tgt_grid.plot(ax, label=label_tgt_grid,
                      plot_kwargs={"color": "blue",
                                   "zorder": 2,
                                   "linewidth": 2,
                                   "label": f"Target grid: {gridname}"},
                      key="T",
                      **(tgt_mask if stencil_only else {}))

    # remove duplicate legend entry
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper center',
              fancybox=True, shadow=True, ncol=2)

    if save_as:
        plt.savefig(save_as)
    else:
        ax.text(0.0, 0.0, "Press 'S' or 'T' to show source or target grid indices",
                verticalalignment='bottom', horizontalalignment='left',
                transform=fig.transFigure)
        plt.show()
292
293
if __name__ == '__main__':
    import argparse

    def _log_level(name):
        """Translate a level name such as ``info`` into its numeric logging level."""
        # getLevelName maps a known level name to its number and returns a
        # "Level ..." string for unknown names.
        level = logging.getLevelName(name.upper())
        if not isinstance(level, int):
            raise argparse.ArgumentTypeError(f"invalid log level: {name!r}")
        return level

    parser = argparse.ArgumentParser(prog="plot_weights.py",
                                     description="""
    Plot grids and yac weights file.
    """)
    parser.add_argument("source_grid", type=str,
                        help="source grid (for an overview of grids and how they are specified here see grid_utils.py)")
    parser.add_argument("target_grid", type=str, nargs='?',
                        default=None,
                        help="target grid (for an overview of grids and how they are specified here see grid_utils.py)")
    parser.add_argument("weights_file", type=str, help="YAC weights file", nargs='?',
                        default=None)
    parser.add_argument("--center", "-c", type=float, nargs=2, help="center of the projection",
                        default=(0, 0), metavar=("LON", "LAT"))
    parser.add_argument("--source_idx", "-s", type=int,
                        help="index of source cell to focus")
    parser.add_argument("--target_idx", "-t", type=int,
                        help="index of target cell to focus")
    parser.add_argument("--zoom", "-z", type=float, default=1,
                        help="zoom around the cell")
    parser.add_argument("--label_src_grid", type=str, default=None,
                        choices=("vertex", "edge", "cell"),
                        help="Add labels at the source grid")
    parser.add_argument("--label_tgt_grid", type=str, default=None,
                        choices=("vertex", "edge", "cell"),
                        help="Add labels at the target grid")
    parser.add_argument("--coast_res", type=str, default="50m",
                        nargs='?',
                        choices=("10m", "50m", "110m"),
                        help="Resolution of coastlines (def 50m).\nOmit argument to disable coastlines.")
    # const: a bare "--projection" keeps the default instead of yielding None.
    parser.add_argument("--projection", type=str, default="orthographic",
                        const="orthographic",
                        choices=("orthographic", "stereographic", "platecarree"),
                        nargs="?", help="Type of projection")
    parser.add_argument("--stencil_only", action="store_true")
    # A path (not an open file) lets matplotlib infer the output format from
    # the extension and avoids truncating/leaking a file handle at parse time.
    parser.add_argument("--save_as", type=str, help="Save to file instead of showing the figure")
    parser.add_argument("--log-level", default=logging.WARNING, type=_log_level,
                        help="Configure the logging level.")
    args = parser.parse_args()
    main(**vars(args))
cell_extent(grid, idx, zoom=4)