18 src_idx=None, tgt_idx=None, zoom=1, quiver_kwargs={}):
19 weights = Dataset(weightsfile,
"r")
20 if 'version' in weights.__dict__:
21 info(f
"{weights.version=}")
22 if weights.version !=
"yac weight file 1.0":
23 print(
"WARNING: You are using an incompatible weight file version\n" +
24 weights.version +
" != yac weight file 1.0")
25 num_src_fields = weights.num_src_fields
if "num_src_fields" in weights.variables
else 1
26 info(f
"{num_src_fields=}")
27 for src_field
in range(num_src_fields):
28 if "src_locations" in weights.variables:
29 src_locstr = bytes(np.array(weights[
"src_locations"][src_field, :])).decode().rstrip(
"\0")
32 info(f
"source loc: {src_locstr}")
33 if src_locstr ==
"CELL":
34 src_points = np.stack([src_grid.clon, src_grid.clat])
35 elif src_locstr ==
"CORNER":
36 src_points = np.stack([src_grid.vlon, src_grid.vlat])
37 elif src_locstr ==
"EDGE":
38 src_points = np.stack([src_grid.vlon, src_grid.vlat])
40 raise f
"Unknown location string {src_locstr}"
41 if "dst_locations" in weights.variables:
42 tgt_locstr = bytes(np.array(weights[
"dst_location"])).decode().rstrip(
"\0")
45 info(f
"target loc: {tgt_locstr}")
46 if tgt_locstr ==
"CELL":
47 tgt_points = np.stack([tgt_grid.clon, tgt_grid.clat])
48 elif tgt_locstr ==
"CORNER":
49 tgt_points = np.stack([tgt_grid.vlon, tgt_grid.vlat])
50 elif tgt_locstr ==
"EDGE":
51 tgt_points = np.stack([tgt_grid.elon, tgt_grid.elat])
53 raise f
"Unknown location string {tgt_locstr}"
55 yac_weights_format_address_offset = 1
56 src_adr = np.asarray(weights[
"src_address"])-yac_weights_format_address_offset-src_grid.idx_offset
57 tgt_adr = np.asarray(weights[
"dst_address"])-yac_weights_format_address_offset-tgt_grid.idx_offset
59 info(f
"Source indices in weight file range from {np.min(src_adr)} to {np.max(src_adr)}")
60 info(f
"Number of source points is {src_points.shape[1]}")
61 info(f
"Target indices in weight file range from {np.min(tgt_adr)} to {np.max(tgt_adr)}")
62 info(f
"Number of target points is {tgt_points.shape[1]}")
64 if np.max(src_adr) > src_points.shape[1]:
65 raise Exception(f
"Source grid too small. Max index in weight file is {np.max(src_adr)}, number of source points are {src_points.shape[1]}")
66 if np.max(tgt_adr) > tgt_points.shape[1]:
67 raise Exception(f
"Target grid too small. Max index in weight file is {np.max(tgt_adr)}, number of source points are {tgt_points.shape[1]}")
73 if src_idx
is not None:
74 mask *= (src_adr == (src_idx-src_grid.idx_offset))
75 if tgt_idx
is not None:
76 mask *= (tgt_adr == (tgt_idx-tgt_grid.idx_offset))
79 raise Exception(
"All points are mask.")
81 if src_idx
is not None or tgt_idx
is not None:
82 info(f
"Mask: {sum(mask)}/{len(mask)}")
83 src_adr = src_adr[mask, ...]
84 tgt_adr = tgt_adr[mask, ...]
86 src_t = ax.projection.transform_points(ccrs.PlateCarree(),
87 src_points[0, src_adr], src_points[1, src_adr])
88 info(f
"src: {src_points[0, src_adr], src_points[1, src_adr]} -> {src_t[:, 0], src_t[:, 1]}")
89 tgt_t = ax.projection.transform_points(ccrs.PlateCarree(),
90 tgt_points[0, tgt_adr], tgt_points[1, tgt_adr])
91 info(f
"tgt: {tgt_points[0, tgt_adr], tgt_points[1, tgt_adr]} -> {tgt_t[:, 0], tgt_t[:, 1]}")
92 e = np.array([min(np.min(src_t[:, 0]), np.min(tgt_t[:, 0])),
93 max(np.max(src_t[:, 0]), np.max(tgt_t[:, 0])),
94 min(np.min(src_t[:, 1]), np.min(tgt_t[:, 1])),
95 max(np.max(src_t[:, 1]), np.max(tgt_t[:, 1]))])
96 c = np.array([0.5*(e[0]+e[1]), 0.5*(e[0]+e[1]),
97 0.5*(e[2]+e[3]), 0.5*(e[2]+e[3])])
98 extent = (e-c)*1.15*zoom + c
102 ax.set_extent(extent, crs=ax.projection)
105 src_t = ax.projection.transform_points(ccrs.PlateCarree(),
106 src_points[0, src_adr], src_points[1, src_adr])
107 tgt_t = ax.projection.transform_points(ccrs.PlateCarree(),
108 tgt_points[0, tgt_adr], tgt_points[1, tgt_adr])
110 extent = ax.get_extent()
111 info(f
"Extent: {extent}")
113 mask *= (((src_t[:, 0] >= extent[0]) * (src_t[:, 0] <= extent[1]) *
114 (src_t[:, 1] >= extent[2]) * (src_t[:, 1] <= extent[3])) +
115 ((tgt_t[:, 0] >= extent[0]) * (tgt_t[:, 0] <= extent[1]) *
116 (tgt_t[:, 1] >= extent[2]) * (tgt_t[:, 1] <= extent[3])))
117 info(f
"Mask: {sum(mask)}/{len(mask)}")
119 src_adr = src_adr[mask, ...]
120 tgt_adr = tgt_adr[mask, ...]
121 src_t = src_t[mask, ...]
122 tgt_t = tgt_t[mask, ...]
125 logging.warning(f
"Trying to display a lot of points ({sum(mask)}). "
126 "This may take some time. "
127 "Consider a smaller --zoom parameter to reduce the number of points")
131 c = weights[
"remap_matrix"][mask, 0]
133 norm = matplotlib.colors.Normalize()
134 cm = matplotlib.cm.Oranges
135 ax.quiver(src_t[:, 0], src_t[:, 1],
136 uv[:, 0], uv[:, 1], angles=
'xy', scale_units=
'xy', scale=1,
139 sm = matplotlib.cm.ScalarMappable(cmap=cm, norm=norm)
141 clb = plt.colorbar(sm, ax=ax)
144 if src_idx
is not None:
145 ax.set_title(f
"Interpolation for source index {src_idx}\nSum of weights: {sum(c):.8f}")
146 if tgt_idx
is not None:
147 ax.set_title(f
"Interpolation for target index {tgt_idx}\nSum of weights: {sum(c):.8f}")
149 if src_idx
is not None or tgt_idx
is not None:
151 for x, y, t
in zip(src_t[:, 0] + 0.5*uv[:, 0],
152 src_t[:, 1] + 0.5*uv[:, 1], c):
153 ax.text(x, y, f
"{t:.3}",
154 horizontalalignment=
'center', verticalalignment=
'center')
156 wlines = ax.plot([src_t[:,0], tgt_t[:,0]],
157 [src_t[:,1], tgt_t[:,1]],
158 color=
'white', alpha=0.)
159 wlabel = np.vstack([src_adr, tgt_adr, c.data])
160 annotation = ax.annotate(text=
'', xy=(0, 0), xytext=(15, 15),
161 textcoords=
'offset points',
162 bbox={
'boxstyle':
'round',
'fc':
'w'},
163 arrowprops={
'arrowstyle':
'->'},
165 annotation.set_visible(
'False')
167 def motion_hover(event):
168 annotation_visible = annotation.get_visible()
169 if event.inaxes == ax:
170 if idx := next((idx+1
for idx, wl
in enumerate(wlines)
if wl.contains(event)[0]),
None):
171 annotation.xy = (event.xdata, event.ydata)
172 text_label = f
"src: {wlabel[0, idx-1] + src_grid.idx_offset:.0f} tgt: {wlabel[1, idx-1] + tgt_grid.idx_offset:.0f}\nweight: {wlabel[2, idx-1]:.3f}"
173 annotation.set_text(text_label)
174 annotation.set_visible(
True)
175 ax.figure.canvas.draw_idle()
177 if annotation_visible:
178 annotation.set_visible(
False)
179 ax.figure.canvas.draw_idle()
181 ax.figure.canvas.mpl_connect(
'motion_notify_event', motion_hover)
183 src_mask = {f
"{src_locstr.lower()}_idx": src_adr}
184 tgt_mask = {f
"{tgt_locstr.lower()}_idx": tgt_adr}
186 return src_mask, tgt_mask