YetAnotherCoupler 3.2.0_a
Loading...
Searching...
No Matches
plot_weights.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2024 The YAC Authors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6
7from netCDF4 import Dataset
8import matplotlib
9import matplotlib.pyplot as plt
10import cartopy.crs as ccrs
11import cartopy.feature as cfeature
12import numpy as np
13from grid_utils import get_grid
14
15
def _decode_location_string(raw):
    """Return the location name stored in a NetCDF character array.

    The characters are joined into one byte string, decoded and stripped of
    trailing NUL padding.
    """
    return bytes(np.array(raw)).decode().rstrip("\0")


def _location_points(grid, location):
    """Return the lon/lat coordinates of *grid* at *location* as a (2, n) array.

    ``location`` is one of the YAC location names "CELL", "CORNER" or "EDGE".

    Raises:
        ValueError: for any other location name.
    """
    if location == "CELL":
        return np.stack([grid.clon, grid.clat])
    if location == "CORNER":
        return np.stack([grid.vlon, grid.vlat])
    if location == "EDGE":
        return np.stack([grid.elon, grid.elat])
    raise ValueError(f"Unknown location string {location}")


def _read_weights(weightsfile, src_grid, tgt_grid):
    """Read the links of a YAC weight file and close the file again.

    Returns:
        ``(src_points, tgt_points, src_adr, tgt_adr, link_weights)``:
        ``src_points``/``tgt_points`` are (2, n) lon/lat arrays,
        ``src_adr``/``tgt_adr`` are 0-based indices into them (after removing
        the file's address offset and the grid's ``idx_offset``), and
        ``link_weights`` is the first column of ``remap_matrix``.
    """
    yac_weights_format_address_offset = 1  # addresses in the file are 1-based
    with Dataset(weightsfile, "r") as weights:
        if "version" in weights.ncattrs():
            if weights.version != "yac weight file 1.0":
                print("WARNING: You are using an incompatible weight file version\n" +
                      weights.version + " != yac weight file 1.0")

        # One location string per source field; files without the variable
        # describe a single cell-based source field.
        if "src_locations" in weights.variables:
            src_locations = [_decode_location_string(row)
                             for row in weights["src_locations"][:]]
        else:
            src_locations = ["CELL"]
        # NOTE(review): the addresses of links belonging to later source
        # fields are used unchanged as indices into the concatenated points;
        # confirm this against the YAC multi-field weight file layout.
        src_points = np.hstack([_location_points(src_grid, location)
                                for location in src_locations])

        if "dst_location" in weights.variables:
            tgt_location = _decode_location_string(weights["dst_location"][:])
        else:
            tgt_location = "CELL"
        tgt_points = _location_points(tgt_grid, tgt_location)

        src_adr = (np.asarray(weights["src_address"])
                   - yac_weights_format_address_offset - src_grid.idx_offset)
        tgt_adr = (np.asarray(weights["dst_address"])
                   - yac_weights_format_address_offset - tgt_grid.idx_offset)
        # Read the column completely so that nothing refers to the file once
        # it is closed.
        link_weights = np.ma.getdata(weights["remap_matrix"][:, 0])
    return src_points, tgt_points, src_adr, tgt_adr, link_weights


def _points_in_extent(points, adr, extent):
    """Return a boolean mask telling which points ``points[:, adr]`` lie in
    the lon/lat ``extent`` (x0, x1, y0, y1)."""
    lon = points[0, adr]
    lat = points[1, adr]
    return ((lon >= extent[0]) & (lon <= extent[1]) &
            (lat >= extent[2]) & (lat <= extent[3]))


def _connect_link_hover(fig, ax, src_t, tgt_t, midpoints,
                        src_adr, tgt_adr, link_weights):
    """Show a pop-up with the addresses and weight of the link under the mouse.

    All coordinates are projection (data) coordinates of ``ax``, so the
    invisible hit-test lines and the pop-up arrow coincide with the quiver
    arrows.
    """
    # Invisible lines along the arrows, only used to detect hovering.
    hit_lines = ax.plot([src_t[:, 0], tgt_t[:, 0]],
                        [src_t[:, 1], tgt_t[:, 1]],
                        color='white', alpha=0.,
                        transform=ax.transData)
    # xy is interpreted in data (= projection) coordinates.
    annotation = ax.annotate(text='', xy=(0, 0), xytext=(15, 15),
                             textcoords='offset points',
                             bbox={'boxstyle': 'round', 'fc': 'w'},
                             arrowprops={'arrowstyle': '->'},
                             zorder=9999)
    annotation.set_visible(False)

    def motion_hover(event):
        if event.inaxes != ax:
            return
        hit = next((i for i, line in enumerate(hit_lines)
                    if line.contains(event)[0]), None)
        if hit is not None:
            annotation.xy = (midpoints[hit, 0], midpoints[hit, 1])
            annotation.set_text('src: {} tgt: {}\nweight: {:.3f}'.format(
                int(src_adr[hit]), int(tgt_adr[hit]), link_weights[hit]))
            annotation.set_visible(True)
            fig.canvas.draw_idle()
        elif annotation.get_visible():
            annotation.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', motion_hover)


def plot_weights(fig, ax, weightsfile, src_grid, tgt_grid,
                 src_idx=None, tgt_idx=None, zoom=1, quiver_kwargs=None):
    """Draw the links of a YAC weight file as arrows on a cartopy GeoAxes.

    Each link is drawn as an arrow from its source point to its target
    point, coloured by its weight, and a colorbar is added. Only links with
    at least one end point inside the current view are drawn. If ``src_idx``
    and/or ``tgt_idx`` is given, only the links of that source/target point
    are kept, the view is zoomed onto them, the title shows the sum of their
    weights, each arrow is labelled with its weight, and hovering over an
    arrow shows a pop-up with its addresses.

    Args:
        fig: matplotlib figure owning ``ax`` (used for the hover callback).
        ax: cartopy GeoAxes to draw into.
        weightsfile: path of the YAC weight file (NetCDF).
        src_grid, tgt_grid: grid objects providing lon/lat coordinates
            (``clon``/``clat``, ``vlon``/``vlat``, ``elon``/``elat``) and
            ``idx_offset``.
        src_idx, tgt_idx: optional source/target index, in the grid's own
            indexing (``idx_offset`` is subtracted), to focus on.
        zoom: scales the view around the focused links (larger shows more).
        quiver_kwargs: extra keyword arguments passed to ``ax.quiver``.

    Raises:
        ValueError: on an unknown location string in the weight file, or if
            no link matches ``src_idx``/``tgt_idx``.
    """
    if quiver_kwargs is None:
        quiver_kwargs = {}
    plate_carree = ccrs.PlateCarree()

    src_points, tgt_points, src_adr, tgt_adr, link_weights = _read_weights(
        weightsfile, src_grid, tgt_grid)

    # Remove redundant links as in SCRIP bilinear and bicubic weights
    mask = src_adr >= 0
    # Restrict the plot to the requested source/target point
    if src_idx is not None:
        mask &= src_adr == src_idx - src_grid.idx_offset
    if tgt_idx is not None:
        mask &= tgt_adr == tgt_idx - tgt_grid.idx_offset
    src_adr = src_adr[mask]
    tgt_adr = tgt_adr[mask]
    link_weights = link_weights[mask]

    focused = src_idx is not None or tgt_idx is not None
    if focused:
        if src_adr.size == 0:
            raise ValueError(f"No links found for source index {src_idx} "
                             f"/ target index {tgt_idx}")
        lon = np.concatenate([src_points[0, src_adr], tgt_points[0, tgt_adr]])
        lat = np.concatenate([src_points[1, src_adr], tgt_points[1, tgt_adr]])
        bounds = np.array([lon.min(), lon.max(), lat.min(), lat.max()])
        centre = np.repeat([0.5*(bounds[0]+bounds[1]),
                            0.5*(bounds[2]+bounds[3])], 2)
        # Enlarge the bounding box by 15 % (times zoom) around its centre
        extent = (bounds-centre)*1.15*zoom + centre
        ax.set_extent(extent, crs=plate_carree)
    else:
        extent = ax.get_extent(crs=plate_carree)

    # Keep only links with at least one end point inside the view
    visible = (_points_in_extent(src_points, src_adr, extent) |
               _points_in_extent(tgt_points, tgt_adr, extent))
    src_adr = src_adr[visible]
    tgt_adr = tgt_adr[visible]
    link_weights = link_weights[visible]

    ax_proj = ax.projection
    src_t = ax_proj.transform_points(plate_carree,
                                     src_points[0, src_adr], src_points[1, src_adr])
    tgt_t = ax_proj.transform_points(plate_carree,
                                     tgt_points[0, tgt_adr], tgt_points[1, tgt_adr])
    uv = tgt_t-src_t

    norm = matplotlib.colors.Normalize()
    cmap = matplotlib.cm.Oranges  # decide for colormap
    ax.quiver(src_t[:, 0], src_t[:, 1],
              uv[:, 0], uv[:, 1], angles='xy', scale_units='xy', scale=1,
              color=cmap(norm(link_weights)),
              **quiver_kwargs)
    sm = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ax=ax)

    weight_sum = np.sum(link_weights)
    if src_idx is not None:
        ax.set_title(f"Interpolation for source index {src_idx}\nSum of weights: {weight_sum:.8f}")
    if tgt_idx is not None:
        ax.set_title(f"Interpolation for target index {tgt_idx}\nSum of weights: {weight_sum:.8f}")

    # Add weight labels and hover pop-ups if restricted to one point
    if focused:
        midpoints = src_t[:, :2] + 0.5*uv[:, :2]
        for (x, y), weight in zip(midpoints, link_weights):
            ax.text(x, y, f"{weight:.3}",
                    horizontalalignment='center', verticalalignment='center')
        _connect_link_hover(fig, ax, src_t, tgt_t, midpoints,
                            src_adr, tgt_adr, link_weights)
152 else:
153 if annotation_visible:
154 annotation.set_visible(False)
155 fig.canvas.draw_idle()
156
157 fig.canvas.mpl_connect('motion_notify_event', motion_hover)
158
159
def cell_extent(grid, idx, zoom=4):
    """Return a lon/lat extent around cell *idx* of *grid*, scaled by *zoom*.

    Prints the cell's centre coordinates. The extent is the bounding box of
    the cell's vertices, ``[min_lon, max_lon, min_lat, max_lat]``, with each
    half-width multiplied by ``zoom`` around the box centre. ``idx`` is given
    in the grid's own indexing (``grid.idx_offset`` is subtracted).
    """
    pos = idx - grid.idx_offset
    print(f"cell {idx}: ", grid.clon[pos],
          grid.clat[pos])
    corners = grid.vertex_of_cell[:, pos]
    corner_lon = grid.vlon[corners]
    corner_lat = grid.vlat[corners]
    bbox = np.array([np.min(corner_lon), np.max(corner_lon),
                     np.min(corner_lat), np.max(corner_lat)])
    mid_lon = 0.5*(bbox[0]+bbox[1])
    mid_lat = 0.5*(bbox[2]+bbox[3])
    mid = np.array([mid_lon, mid_lon, mid_lat, mid_lat])
    return (bbox-mid)*zoom + mid
169
170
def main(source_grid, target_grid, weights_file, center=None,
         source_idx=None, target_idx=None, zoom=1,
         label_src_grid=None, label_tgt_grid=None,
         coast_res="50m", save_as=None):
    """Plot the source grid and, optionally, a target grid and YAC weights.

    Args:
        source_grid, target_grid: grid specifications passed to ``get_grid``
            (``target_grid`` may be None).
        weights_file: optional YAC weight file drawn with ``plot_weights``.
        center: (lon, lat) centre of the orthographic projection. If None,
            the centre of the focused source/target cell is used, otherwise
            (0, 0).
        source_idx, target_idx: optional cell index to focus on.
        zoom: scales the visible area (larger values show more).
        label_src_grid, label_tgt_grid: forwarded as ``label`` to the grid
            plots.
        coast_res: Natural Earth resolution for the land feature; a falsy
            value disables it.
        save_as: if given, save the figure to this path instead of showing it.

    Raises:
        ValueError: if ``target_idx`` is given without a target grid.
    """
    src_grid = get_grid(source_grid)
    tgt_grid = get_grid(target_grid) if target_grid is not None else None
    if target_idx is not None and tgt_grid is None:
        raise ValueError("target_idx requires a target grid")

    fig = plt.figure(figsize=[10, 10])
    if center is None:
        if source_idx is not None:
            center = [src_grid.clon[source_idx-src_grid.idx_offset],
                      src_grid.clat[source_idx-src_grid.idx_offset]]
        elif target_idx is not None:
            center = [tgt_grid.clon[target_idx-tgt_grid.idx_offset],
                      tgt_grid.clat[target_idx-tgt_grid.idx_offset]]
        else:
            center = [0, 0]

    proj = ccrs.Orthographic(*center)
    ax = fig.add_subplot(1, 1, 1, projection=proj)

    if weights_file is None:
        # Without weights nothing else zooms onto the focused cell.
        if source_idx is not None:
            extent = cell_extent(src_grid, source_idx, zoom)
            ax.set_extent(extent, crs=ccrs.PlateCarree())
        elif target_idx is not None:
            extent = cell_extent(tgt_grid, target_idx, zoom)
            ax.set_extent(extent, crs=ccrs.PlateCarree())
    if source_idx is None and target_idx is None:
        # The orthographic projection maps `center` to the origin, so the
        # default view is a square around (0, 0) in projection coordinates
        # (metres); lon/lat degrees must not be mixed in here.
        half_width = 1000000*zoom
        ax.set_extent([-half_width, half_width, -half_width, half_width], crs=proj)

    # Put a background image on for nice sea rendering.
    ax.set_facecolor("#9ddbff")
    if coast_res:
        feature = cfeature.NaturalEarthFeature(name='land',
                                               category='physical',
                                               scale=coast_res,
                                               edgecolor='#000000',
                                               facecolor='#cbe7be')
        ax.add_feature(feature, zorder=-999)

    if weights_file is not None:
        plot_weights(fig, ax, weights_file, src_grid, tgt_grid,
                     source_idx, target_idx, zoom, quiver_kwargs={"zorder": 2})

    # Plot grids
    if src_grid is not None:
        src_grid.plot(ax, label=label_src_grid, plot_kwargs={"color": "green", "zorder": 1})
    if tgt_grid is not None:
        tgt_grid.plot(ax, label=label_tgt_grid, plot_kwargs={"color": "blue", "zorder": 1})

    if save_as:
        plt.savefig(save_as)
    else:
        plt.show()
228
229
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(prog="plot_weights.py",
                                     description="""
    plot grids and yac weights file.

    Grids can be specified either as a filename,
    which is interpreted as an ICON grid file, or as
    a string like "g360,180" from which a structured
    grid is generated. Use a capital G for 1-based
    indexing and a small g for 0-based indexing,
    followed by the resolution (lon,lat). Optionally
    you can add an extent by appending 4 further numbers
    (min_lon, max_lon, min_lat, max_lat), e.g.
    g100,100,-50,-45,-5,5
    """)
    parser.add_argument("source_grid", type=str,
                        help="source grid (netCDF file or [gG]lon,lat[,min_lon,max_lon,min_lat,max_lat])")
    parser.add_argument("target_grid", type=str, nargs='?',
                        default=None,
                        help="target grid (netCDF file or [gG]lon,lat[,min_lon,max_lon,min_lat,max_lat])")
    parser.add_argument("weights_file", type=str, help="YAC weights file", nargs='?',
                        default=None)
    # default None lets main() centre on the focused cell (or 0 0 otherwise)
    parser.add_argument("--center", "-c", type=float, nargs=2,
                        help="center of the orthographic projection "
                             "(default: the focused cell, otherwise 0 0)",
                        default=None, metavar=("LON", "LAT"))
    parser.add_argument("--source_idx", "-s", type=int,
                        help="index of source cell to focus")
    parser.add_argument("--target_idx", "-t", type=int,
                        help="index of target cell to focus")
    parser.add_argument("--zoom", "-z", type=float, default=1,
                        help="zoom around the cell")
    parser.add_argument("--label_src_grid", type=str, default=None,
                        choices=("vertex", "edge", "cell"),
                        help="Add labels at the source grid")
    parser.add_argument("--label_tgt_grid", type=str, default=None,
                        choices=("vertex", "edge", "cell"),
                        help="Add labels at the target grid")
    # nargs='?' without const: a bare "--coast_res" yields None (no coastlines)
    parser.add_argument("--coast_res", type=str, default="50m",
                        nargs='?',
                        choices=("10m", "50m", "110m"),
                        help="Resolution of coastlines (def 50m).\nOmit argument to disable coastlines.")
    parser.add_argument("--save_as", type=str, help="Save to file instead of showing the figure")
    args = parser.parse_args()
    main(**vars(args))
cell_extent(grid, idx, zoom=4)
main(source_grid, target_grid, weights_file, center=None, source_idx=None, target_idx=None, zoom=1, label_src_grid=None, label_tgt_grid=None, coast_res="50m", save_as=None)