From e7205dfa274241cd1135c71fcde6a57da0a4e127 Mon Sep 17 00:00:00 2001
From: Alexey Pechnikov
Date: Wed, 28 Aug 2024 02:30:15 +0700
Subject: [PATCH] Code refactoring

---
 pygmtsar/pygmtsar/Stack_dem.py       |  35 +---
 pygmtsar/pygmtsar/Stack_phasediff.py | 303 +++------------------------
 pygmtsar/pygmtsar/Stack_topo.py      | 166 +++++----------
 pygmtsar/pygmtsar/utils.py           |  50 +++++
 4 files changed, 139 insertions(+), 415 deletions(-)

diff --git a/pygmtsar/pygmtsar/Stack_dem.py b/pygmtsar/pygmtsar/Stack_dem.py
index 8afc8a4..7bfea42 100644
--- a/pygmtsar/pygmtsar/Stack_dem.py
+++ b/pygmtsar/pygmtsar/Stack_dem.py
@@ -10,6 +10,7 @@
 from .Stack_reframe import Stack_reframe
 from .PRM import PRM
 from .tqdm_dask import tqdm_dask
+from .utils import utils
 
 class Stack_dem(Stack_reframe):
 
@@ -60,45 +61,13 @@ def get_geoid(self, grid=None):
         See EGM96 geoid heights on http://icgem.gfz-potsdam.de/tom_longtime
         """
         import xarray as xr
-        import dask.array as da
         import os
         import importlib.resources as resources
-        import warnings
-        # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
-        warnings.filterwarnings('ignore')
-        warnings.filterwarnings('ignore', module='dask')
-        warnings.filterwarnings('ignore', module='dask.core')
-
-        # use outer variable geoid
-        def interpolate_chunk(grid_chunk, grid_lat_chunk, grid_lon_chunk, method='cubic'):
-            dlat, dlon = float(geoid.lat.diff('lat')[0]), float(geoid.lon.diff('lon')[0])
-            geoid_chunk = geoid.sel(
-                lat=slice(grid_lat_chunk[0]-2*dlat, grid_lat_chunk[-1]+2*dlat),
-                lon=slice(grid_lon_chunk[0]-2*dlon, grid_lon_chunk[-1]+2*dlon)
-            ).compute()
-            #print ('geoid_chunk', geoid_chunk)
-            return geoid_chunk.interp({'lat': grid_lat_chunk, 'lon': grid_lon_chunk}, method=method)
 
         with resources.as_file(resources.files('pygmtsar.data') / 'geoid_egm96_icgem.grd') as geoid_filename:
             geoid = xr.open_dataarray(geoid_filename, engine=self.netcdf_engine, chunks=self.netcdf_chunksize).rename({'y': 'lat', 'x': 'lon'})
 
         if grid is not None:
-            # Xarray interpolation struggles with large grids
-            #geoid = geoid.interp_like(grid.coords, method='linear')
-            # grid.data is needed only to prevent excessive memory usage during interpolation
-            geoid_grid = da.blockwise(
-                interpolate_chunk,
-                'ij',
-                grid.data,
-                'ij',
-                grid.lat.data,
-                'i',
-                grid.lon.data,
-                'j',
-                dtype=geoid.dtype
-            )
-            return xr.DataArray(geoid_grid, coords=grid.coords).rename(geoid.name)
-
+            return utils.interp2d_like(geoid, grid)
         return geoid
 
     def set_dem(self, dem_filename):
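The Stack_dem.py hunk above collapses roughly thirty lines of inline dask.array.blockwise interpolation into a single call to the new utils.interp2d_like helper (added in utils.py at the end of this patch). A minimal usage sketch, assuming an initialized pygmtsar Stack object named `sbas`; the variable names here are illustrative, not part of the patch:

    # regrid the coarse EGM96 geoid onto an arbitrary DEM grid in one call
    dem = sbas.get_dem()                 # target 2D grid (lat/lon)
    geoid = sbas.get_geoid(grid=dem)     # now delegates to utils.interp2d_like
    dem_ellipsoidal = dem + geoid        # geoid-referenced heights -> ellipsoidal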
diff --git a/pygmtsar/pygmtsar/Stack_phasediff.py b/pygmtsar/pygmtsar/Stack_phasediff.py
index c2cd32b..dd4f7e2 100644
--- a/pygmtsar/pygmtsar/Stack_phasediff.py
+++ b/pygmtsar/pygmtsar/Stack_phasediff.py
@@ -10,6 +10,7 @@
 from .Stack_topo import Stack_topo
 from .tqdm_dask import tqdm_dask
 from .PRM import PRM
+from .utils import utils
 
 class Stack_phasediff(Stack_topo):
 
@@ -442,20 +443,10 @@ def correlation(self, phase, intensity, debug=False):
 #         return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).rename('phasediff')
 
     def phasediff(self, pairs, data='auto', topo='auto', phase=None, method='nearest', joblib_backend=None, debug=False):
-        import pandas as pd
-        import dask
-        import dask.dataframe
+        #import dask
+        import dask.array as da
         import xarray as xr
         import numpy as np
-        #from tqdm.auto import tqdm
-        import joblib
-        from joblib.externals import loky
-        loky.get_reusable_executor(kill_workers=True).shutdown(wait=True)
-        import warnings
-        # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
-        warnings.filterwarnings('ignore')
-        warnings.filterwarnings('ignore', module='dask')
-        warnings.filterwarnings('ignore', module='dask.core')
 
         if debug:
             print ('DEBUG: phasediff')
@@ -467,272 +458,38 @@ def phasediff(self, pairs, data='auto', topo='auto', phase=None, method='nearest
         pairs, dates = self.get_pairs(pairs, dates=True)
         pairs = pairs[['ref', 'rep']].astype(str).values
 
-        if isinstance(topo, str) and topo == 'auto':
-            topo = self.get_topo()
-
-        # calculate the combined earth curvature and topography correction
-        def calc_drho(rho, topo, earth_radius, height, b, alpha, Bx):
-            sina = np.sin(alpha)
-            cosa = np.cos(alpha)
-            c = earth_radius + height
-            # compute the look angle using equation (C26) in Appendix C
-            # GMTSAR uses long double here
-            #ret = earth_radius + topo.astype(np.longdouble)
-            ret = earth_radius + topo
-            cost = ((rho**2 + c**2 - ret**2) / (2. * rho * c))
-            #if (cost >= 1.)
-            #    die("calc_drho", "cost >= 0");
-            sint = np.sqrt(1. - cost**2)
-            # Compute the offset effect from non-parallel orbit
-            term1 = rho**2 + b**2 - 2 * rho * b * (sint * cosa - cost * sina) - Bx**2
-            drho = -rho + np.sqrt(term1)
-            del term1, sint, cost, ret, c, cosa, sina
-            return drho
-
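The removed calc_drho helper (it survives verbatim in the Stack_topo.py diff below) evaluates the combined earth-curvature and topography range correction. Restated from the code in LaTeX, with slant range $\rho$ (near_range), earth radius $R_e$, spacecraft height $h$, topography $z$, baseline $b$, baseline angle $\alpha$, along-track offset $B_x$, and writing $c = R_e + h$, $r = R_e + z$:

    \cos\theta = \frac{\rho^2 + c^2 - r^2}{2\,\rho\,c},
    \qquad
    \sin\theta = \sqrt{1 - \cos^2\theta},

    \Delta\rho = -\rho + \sqrt{\rho^2 + b^2
        - 2\,\rho\,b\,\left(\sin\theta\cos\alpha - \cos\theta\sin\alpha\right) - B_x^2}.

The applied correction is the phasor $\exp\!\left(i\,\mathrm{cnst}\,\Delta\rho\right)$ with $\mathrm{cnst} = -4\pi/\lambda$, the two-way propagation constant set a few lines further down.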
-        def block_phasediff(date1, date2, prm1, prm2, ylim, xlim):
-            # use outer variables date, stack_prm
-            # disable "distributed.utils_perf - WARNING - full garbage collections ..."
-            try:
-                from dask.distributed import utils_perf
-                utils_perf.disable_gc_diagnosis()
-            except ImportError:
-                from distributed.gc import disable_gc_diagnosis
-                disable_gc_diagnosis()
-            import warnings
-            # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
-            warnings.filterwarnings('ignore')
-            warnings.filterwarnings('ignore', module='dask')
-            warnings.filterwarnings('ignore', module='dask.core')
-
-            # for lazy Dask ddataframes
-            #prm1 = PRM(prm1)
-            #prm2 = PRM(prm2)
-            #prm1, prm2 = stack_prm[stack_idx]
-            #data1, data2 = stack_data[stack_idx]
-            data1 = data.sel(date=date1)
-            data2 = data.sel(date=date2)
-
-            # convert indices 0.5, 1.5,... to 0,1,... for easy calculations
-            block_data1 = data1.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1])).compute(n_workers=1)
-            block_data2 = data2.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1])).compute(n_workers=1)
-            del data1, data2
-
-            if abs(block_data1).sum() == 0:
-                intf = np.nan * xr.zeros_like(block_data1)
-                del block_data1, block_data2
-                return intf
-
-            ys = block_data1.y.astype(int)
-            xs = block_data1.x.astype(int)
-
-            block_data1 = block_data1.assign_coords(y=ys, x=xs)
-            block_data2 = block_data2.assign_coords(y=ys, x=xs)
-
-            if isinstance(topo, xr.DataArray):
-                dy, dx = topo.y.diff('y').item(0), topo.x.diff('x').item(0)
-
-            # use outer variables topo, data1, data2, prm1, prm2
-            # build topo block
-            if not isinstance(topo, xr.DataArray):
-                # topography is a constant, typically, zero
-                block_topo = topo * xr.ones_like(block_data1, dtype=np.float32)
-            elif dy == 1 and dx == 1:
-                # topography is already in the original resolution
-                block_topo = topo.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
-                    .compute(n_workers=1)\
-                    .fillna(0)\
-                    .assign_coords(y=ys, x=xs)
-            else:
-                # topography resolution is different, interpolation with extrapolation required
-                # convert indices 0.5, 1.5,... to 0,1,... for easy calculations
-                # fill NaNs by zero because typically DEM is missed outside of land areas
-                block_topo = topo.sel(y=slice(ys[0]-2*dy, ys[-1]+2*dy), x=slice(xs[0]-2*dx, xs[-1]+2*dx))\
-                    .compute(n_workers=1)\
-                    .fillna(0)\
-                    .interp({'y': block_data1.y, 'x': block_data1.x}, method=method, kwargs={'fill_value': 'extrapolate'})\
-                    .assign_coords(y=ys, x=xs)
-
-            if phase is not None:
-                dy, dx = phase.y.diff('y').item(0), phase.x.diff('x').item(0)
-                if dy == 1 and dx == 1:
-                    # phase is already in the original resolution
-                    block_phase = phase.sel(pair=f'{date1} {date2}').isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
-                        .compute(n_workers=1)\
-                        .assign_coords(y=ys, x=xs)
-                else:
-                    # phase resolution is different, interpolation with extrapolation required
-                    # convert indices 0.5, 1.5,... to 0,1,... for easy calculations
-                    block_phase = phase.sel(pair=f'{date1} {date2}').sel(y=slice(ys[0]-2*dy, ys[-1]+2*dy), x=slice(xs[0]-2*dx, xs[-1]+2*dx))\
-                        .compute(n_workers=1)\
-                        .interp({'y': block_data1.y, 'x': block_data1.x}, method=method, kwargs={'fill_value': 'extrapolate'})\
-                        .assign_coords(y=ys, x=xs)
-            # set dimensions
-            xdim = prm1.get('num_rng_bins')
-            ydim = prm1.get('num_patches') * prm1.get('num_valid_az')
-
-            # set heights
-            htc = prm1.get('SC_height')
-            ht0 = prm1.get('SC_height_start')
-            htf = prm1.get('SC_height_end')
-
-            # compute the time span and the time spacing
-            tspan = 86400 * abs(prm2.get('SC_clock_stop') - prm2.get('SC_clock_start'))
-            assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), 'Check sc_clock_start, sc_clock_end, or PRF'
-
-            from scipy import constants
-            # setup the default parameters
-            # constant from GMTSAR code for consistency
-            #SOL = 299792456.0
-            drange = constants.speed_of_light / (2 * prm2.get('rng_samp_rate'))
-            #drange = SOL / (2 * prm2.get('rng_samp_rate'))
-            alpha = prm2.get('alpha_start') * np.pi / 180
-            cnst = -4 * np.pi / prm2.get('radar_wavelength')
-
-            # calculate initial baselines
-            Bh0 = prm2.get('baseline_start') * np.cos(prm2.get('alpha_start') * np.pi / 180)
-            Bv0 = prm2.get('baseline_start') * np.sin(prm2.get('alpha_start') * np.pi / 180)
-            Bhf = prm2.get('baseline_end') * np.cos(prm2.get('alpha_end') * np.pi / 180)
-            Bvf = prm2.get('baseline_end') * np.sin(prm2.get('alpha_end') * np.pi / 180)
-            Bx0 = prm2.get('B_offset_start')
-            Bxf = prm2.get('B_offset_end')
-
-            # first case is quadratic baseline model, second case is default linear model
-            if prm2.get('baseline_center') != 0 or prm2.get('alpha_center') != 0 or prm2.get('B_offset_center') != 0:
-                Bhc = prm2.get('baseline_center') * np.cos(prm2.get('alpha_center') * np.pi / 180)
-                Bvc = prm2.get('baseline_center') * np.sin(prm2.get('alpha_center') * np.pi / 180)
-                Bxc = prm2.get('B_offset_center')
-
-                dBh = (-3 * Bh0 + 4 * Bhc - Bhf) / tspan
-                dBv = (-3 * Bv0 + 4 * Bvc - Bvf) / tspan
-                ddBh = (2 * Bh0 - 4 * Bhc + 2 * Bhf) / (tspan * tspan)
-                ddBv = (2 * Bv0 - 4 * Bvc + 2 * Bvf) / (tspan * tspan)
-
-                dBx = (-3 * Bx0 + 4 * Bxc - Bxf) / tspan
-                ddBx = (2 * Bx0 - 4 * Bxc + 2 * Bxf) / (tspan * tspan)
-            else:
-                dBh = (Bhf - Bh0) / tspan
-                dBv = (Bvf - Bv0) / tspan
-                dBx = (Bxf - Bx0) / tspan
-                ddBh = ddBv = ddBx = 0
-
-            # calculate height increment
-            dht = (-3 * ht0 + 4 * htc - htf) / tspan
-            ddht = (2 * ht0 - 4 * htc + 2 * htf) / (tspan * tspan)
-
-            # multiply xr.ones_like(topo) for correct broadcasting
-            near_range = xr.ones_like(block_topo)*(prm1.get('near_range') + \
-                block_topo.x * (1 + prm1.get('stretch_r')) * drange) + \
-                xr.ones_like(block_topo)*(block_topo.y * prm1.get('a_stretch_r') * drange)
-
-            # calculate the change in baseline and height along the frame if topoflag is on
-            time = block_topo.y * tspan / (ydim - 1)
-            Bh = Bh0 + dBh * time + ddBh * time**2
-            Bv = Bv0 + dBv * time + ddBv * time**2
-            Bx = Bx0 + dBx * time + ddBx * time**2
-            B = np.sqrt(Bh * Bh + Bv * Bv)
-            alpha = np.arctan2(Bv, Bh)
-            height = ht0 + dht * time + ddht * time**2
-
-            # calculate the combined earth curvature and topography correction
-            drho = calc_drho(near_range, block_topo, prm1.get('earth_radius'), height, B, alpha, Bx)
-
-            # make topographic and model phase corrections
-            # GMTSAR uses float32 complex operations with precision loss
-            #phase_shift = np.exp(1j * (cnst * drho).astype(np.float32))
-            if phase is not None:
-                phase_shift = np.exp(1j * (cnst * drho - block_phase))
-                # or the same expression in other way
-                #phase_shift = np.exp(1j * (cnst * drho)) * np.exp(-1j * block_phase)
-                del block_phase
-            else:
-                phase_shift = np.exp(1j * (cnst * drho))
-            del block_topo, near_range, drho, height, B, alpha, Bx, Bv, Bh, time
-
-            # calculate phase difference
-            intf = block_data1 * phase_shift * np.conj(block_data2)
-            del block_data1, block_data2, phase_shift
-            return intf.astype(np.complex64)
-
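The baseline and height evolution in the removed block (and in the new Stack_topo.py code below, which keeps the same formulas) is a parabola through the start, center, and end values. For a quantity with values $B_0$, $B_c$, $B_f$ at times $0$, $T/2$, $T$ (here $T$ is tspan):

    B(t) = B_0 + \dot{B}\,t + \ddot{B}\,t^2,
    \qquad
    \dot{B} = \frac{-3B_0 + 4B_c - B_f}{T},
    \qquad
    \ddot{B} = \frac{2B_0 - 4B_c + 2B_f}{T^2},

which satisfies $B(0)=B_0$, $B(T/2)=B_c$, $B(T)=B_f$. When no center value is available, the code falls back to the linear model $\dot{B} = (B_f - B_0)/T$, $\ddot{B} = 0$. The same three-point fit is applied to the spacecraft height (ht0, htc, htf).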
-        # # prepare lazy PRM
-        # # this is the "true way" but processing is ~40% slower due to additional Dask tasks
-        # def block_prms(date1, date2):
-        #     prm1 = self.PRM(date1)
-        #     prm2 = self.PRM(date2)
-        #     prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
-        #     prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
-        #     return (prm1.df, prm2.df)
-        # # Define metadata explicitly to match PRM dataframe
-        # prm_meta = pd.DataFrame(columns=['name', 'value']).astype({'name': 'str', 'value': 'object'}).set_index('name')
-
-        # immediately prepare PRM
-        # here is some delay on the function call but the actual processing is faster
-        def prepare_prms(pair):
-            date1, date2 = pair
-            prm1 = self.PRM(date1)
-            prm2 = self.PRM(date2)
-            prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
-            prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
-            return {(date1, date2): (prm1, prm2)}
-
-        #with self.tqdm_joblib(tqdm(desc=f'Pre-Processing PRM', total=len(pairs))) as progress_bar:
-        prms = joblib.Parallel(n_jobs=-1, backend=joblib_backend)(joblib.delayed(prepare_prms)(pair) for pair in pairs)
-        # convert the list of dicts to a single dict
-        prms = {k: v for d in prms for k, v in d.items()}
-
         if isinstance(data, str) and data == 'auto':
             # open datafiles required for all the pairs
             data = self.open_data(dates)
 
-        # define blocks
-        chunks = data.chunks
-        ychunks, xchunks = chunks[1], chunks[2]
-        ychunks = np.concatenate([[0], np.cumsum(ychunks)])
-        xchunks = np.concatenate([[0], np.cumsum(xchunks)])
-        ylims = [(y1, y2) for y1, y2 in zip(ychunks, ychunks[1:])]
-        xlims = [(x1, x2) for x1, x2 in zip(xchunks, xchunks[1:])]
-        #print ('ylims', ylims)
-        #print ('xlims', xlims)
-
-        stack = []
-        for stack_idx, pair in enumerate(pairs):
-            date1, date2 = pair
-
-            # Create a Dask DataFrame with provided metadata for each Dask block
-            #prms = dask.delayed(block_prms)(date1, date2)
-            #prm1 = dask.dataframe.from_delayed(dask.delayed(prms[0]), meta=prm_meta)
-            #prm2 = dask.dataframe.from_delayed(dask.delayed(prms[1]), meta=prm_meta)
-            prm1, prm2 = prms[(date1, date2)]
-
-            if topo is None:
-                # calculation is straightforward and does not require delayed wrappers
-                intf = (data.sel(date=date1) * np.conj(data.sel(date=date2)))
-            else:
-                blocks_total = []
-                for ylim in ylims:
-                    blocks = []
-                    for xlim in xlims:
-                        delayed = dask.delayed(block_phasediff)(date1, date2, prm1, prm2, ylim, xlim)
-                        block = dask.array.from_delayed(delayed,
-                                                        shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])),
-                                                        dtype=np.complex64)
-                        blocks.append(block)
-                        del block, delayed
-                    blocks_total.append(blocks)
-                    del blocks
-                intf = xr.DataArray(dask.array.block(blocks_total), coords={'y': data.y, 'x': data.x})
-                del blocks_total
-
-            # add to stack
-            stack.append(intf)
-            # cleanup
-            del intf, prm1, prm2
-        del prms
-
-        coord_pair = [' '.join(pair) for pair in pairs]
-        coord_ref = xr.DataArray(pd.to_datetime(pairs[:,0]), coords={'pair': coord_pair})
-        coord_rep = xr.DataArray(pd.to_datetime(pairs[:,1]), coords={'pair': coord_pair})
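The replacement code below drops the per-pair Python loop and the per-block dask.delayed plumbing in favor of plain xarray/dask broadcasting over a 'pair' dimension. A self-contained toy version of that selection-and-conjugation pattern (synthetic arrays, illustrative names):

    import numpy as np
    import xarray as xr

    # synthetic complex SLC stack: 3 dates on a 4x5 grid
    dates = np.array(['2024-01-01', '2024-01-13', '2024-01-25'])
    data = xr.DataArray(
        np.exp(1j * np.random.rand(3, 4, 5)).astype(np.complex64),
        coords={'date': dates, 'y': np.arange(4), 'x': np.arange(5)},
        dims=('date', 'y', 'x'),
    )
    pairs = np.array([['2024-01-01', '2024-01-13'],
                      ['2024-01-13', '2024-01-25']])

    # one vectorized expression instead of a loop over pairs
    data1 = data.sel(date=pairs[:, 0]).drop_vars('date').rename({'date': 'pair'})
    data2 = data.sel(date=pairs[:, 1]).drop_vars('date').rename({'date': 'pair'})
    intf = (data1 * np.conj(data2)).astype(np.complex64)
    print(intf.dims, intf.shape)  # ('pair', 'y', 'x') (2, 4, 5)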
+        # interpret the topo argument as topography; otherwise, use it as topography phase
+        if isinstance(topo, str) and topo == 'auto':
+            topo = utils.interp2d_like(self.get_topo(), data, method=method, kwargs={'fill_value': 'extrapolate'})
+        if (isinstance(topo, xr.DataArray) and topo.name=='topo'):
+            phase_topo = self.topo_phase(pairs, topo, grid=data, method=method)
+        else:
+            phase_topo = topo
 
-        return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).rename('phase')
+        if phase is not None:
+            phase_real = utils.interp2d_like(phase, grid=data, method=method, kwargs={'fill_value': 'extrapolate'})
+        else:
+            phase_real = 0
+        #phase_real = len(pairs)*[0]
+
+        # calculate phase difference
+        data1 = data.sel(date=pairs[:,0]).drop_vars('date').rename({'date': 'pair'})
+        data2 = data.sel(date=pairs[:,1]).drop_vars('date').rename({'date': 'pair'})
+        out = (data1 * phase_topo * np.exp(-1j * phase_real) * da.conj(data2)).astype(np.complex64)
+        del phase_topo, phase_real, data1, data2
+
+        # # calculate phase difference
+        # phase_dask = da.stack([(data.sel(date=pair[0]).drop_vars('date') \
+        #                         * phase_topo[idx] * np.exp(-1j * phase_real[idx]) \
+        #                         * da.conj(data.sel(date=pair[1]).drop_vars('date'))) for idx, pair in enumerate(pairs)], axis=0)
+        # out = xr.DataArray(phase_dask, coords=phase_topo.coords)
+        # del phase_topo, phase_real, phase_dask
+
+        return out.astype(np.complex64).rename('phase')
 
     def goldstein(self, phase, corr, psize=32, debug=False):
         import xarray as xr
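With the loop gone, phasediff delegates the geometric correction to topo_phase and all regridding to utils.interp2d_like. A usage sketch of the refactored method, assuming an initialized Stack object `sbas` and a `pairs` table as used elsewhere in pygmtsar; the decimator lines are taken from the docstring removed from topo_phase below:

    intf = sbas.phasediff(pairs)   # topo='auto': the correction is applied internally

    # equivalently, precompute the topography phase on a decimated grid and pass it in
    # (a non-'topo'-named DataArray is interpreted as a ready topography phase):
    decimator = sbas.decimator(resolution=15, grid=(1,1))
    topophase = sbas.topo_phase(pairs, topo=decimator(sbas.get_topo()))
    intf = sbas.phasediff(pairs, topo=topophase)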
diff --git a/pygmtsar/pygmtsar/Stack_topo.py b/pygmtsar/pygmtsar/Stack_topo.py
index 8d8e723..3a590d8 100644
--- a/pygmtsar/pygmtsar/Stack_topo.py
+++ b/pygmtsar/pygmtsar/Stack_topo.py
@@ -8,6 +8,7 @@
 # Licensed under the BSD 3-Clause License (see LICENSE for details)
 # ----------------------------------------------------------------------------
 from .Stack_trans_inv import Stack_trans_inv
+from .utils import utils
 
 class Stack_topo(Stack_trans_inv):
 
@@ -63,17 +64,12 @@ def plot_topo(self, data='auto', caption='Topography on WGS84 ellipsoid, [m]',
         plt.ylabel('Azimuth')
         plt.title(caption)
 
-    def topo_phase(self, pairs, topo='auto', debug=False):
-        """
-        decimator = sbas.decimator(resolution=15, grid=(1,1))
-        topophase = sbas.topo_phase(pairs, topo=decimator(sbas.get_topo()))
-        """
+    def topo_phase(self, pairs, topo='auto', grid=None, method='nearest', debug=False):
         import pandas as pd
         import dask
-        import dask.dataframe
+        import dask.array as da
         import xarray as xr
         import numpy as np
-        #from tqdm.auto import tqdm
         import joblib
         import warnings
         # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
@@ -90,6 +86,8 @@
         if isinstance(topo, str) and topo == 'auto':
             topo = self.get_topo()
+        if grid is not None:
+            topo = utils.interp2d_like(topo, grid, method=method)
 
         # calculate the combined earth curvature and topography correction
         def calc_drho(rho, topo, earth_radius, height, b, alpha, Bx):
@@ -109,39 +107,21 @@ def calc_drho(rho, topo, earth_radius, height, b, alpha, Bx):
             drho = -rho + np.sqrt(term1)
             del term1, sint, cost, ret, c, cosa, sina
             return drho
-
-        def block_phase(prm1, prm2, ylim, xlim):
-            # use outer variables date, stack_prm
-            # disable "distributed.utils_perf - WARNING - full garbage collections ..."
-            try:
-                from dask.distributed import utils_perf
-                utils_perf.disable_gc_diagnosis()
-            except ImportError:
-                from distributed.gc import disable_gc_diagnosis
-                disable_gc_diagnosis()
-            import warnings
-            # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
-            warnings.filterwarnings('ignore')
-            warnings.filterwarnings('ignore', module='dask')
-            warnings.filterwarnings('ignore', module='dask.core')
-
-            # for lazy Dask ddataframes
-            #prm1 = PRM(prm1)
-            #prm2 = PRM(prm2)
-            #prm1, prm2 = stack_prm[stack_idx]
-            #data1, data2 = stack_data[stack_idx]
-
-            # use outer variables topo, data1, data2, prm1, prm2
-            # build topo block
-            block_topo = topo.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
-                .compute(n_workers=1)\
-                .fillna(0)
-
-            # set dimensions
+
+        #def block_phase(prm1, prm2, ylim, xlim):
+        def block_phase_dask(block_topo, y_chunk, x_chunk, prm1, prm2):
+            from scipy import constants
+            #assert 0, f'block_topo.shape: {block_topo.shape}, {block_topo}'
+            # for 3d processing
+            #block_topo = block_topo[0]
+            #prm1 = prm1[0]
+            #prm2 = prm2[0]
+
+            # get full dimensions
             xdim = prm1.get('num_rng_bins')
             ydim = prm1.get('num_patches') * prm1.get('num_valid_az')
 
-            # set heights
+            # get heights
             htc = prm1.get('SC_height')
             ht0 = prm1.get('SC_height_start')
             htf = prm1.get('SC_height_end')
@@ -150,13 +130,10 @@
             # compute the time span and the time spacing
             tspan = 86400 * abs(prm2.get('SC_clock_stop') - prm2.get('SC_clock_start'))
             assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), 'Check sc_clock_start, sc_clock_end, or PRF'
 
-            from scipy import constants
             # setup the default parameters
-            # constant from GMTSAR code for consistency
-            #SOL = 299792456.0
             drange = constants.speed_of_light / (2 * prm2.get('rng_samp_rate'))
-            #drange = SOL / (2 * prm2.get('rng_samp_rate'))
             alpha = prm2.get('alpha_start') * np.pi / 180
+            # a constant that converts drho into a phase shift
             cnst = -4 * np.pi / prm2.get('radar_wavelength')
 
             # calculate initial baselines
@@ -190,13 +167,12 @@
             dht = (-3 * ht0 + 4 * htc - htf) / tspan
             ddht = (2 * ht0 - 4 * htc + 2 * htf) / (tspan * tspan)
 
-            # multiply xr.ones_like(topo) for correct broadcasting
-            near_range = xr.ones_like(block_topo)*(prm1.get('near_range') + \
-                block_topo.x * (1 + prm1.get('stretch_r')) * drange) + \
-                xr.ones_like(block_topo)*(block_topo.y * prm1.get('a_stretch_r') * drange)
+            near_range = (prm1.get('near_range') + \
+                x_chunk.reshape(1,-1) * (1 + prm1.get('stretch_r')) * drange) + \
+                y_chunk.reshape(-1,1) * prm1.get('a_stretch_r') * drange
 
-            # calculate the change in baseline and height along the frame if topoflag is on
-            time = block_topo.y * tspan / (ydim - 1)
+            # calculate the change in baseline and height along the frame
+            time = y_chunk * tspan / (ydim - 1)
             Bh = Bh0 + dBh * time + ddBh * time**2
             Bv = Bv0 + dBv * time + ddBv * time**2
             Bx = Bx0 + dBx * time + ddBx * time**2
@@ -205,27 +181,16 @@
             height = ht0 + dht * time + ddht * time**2
 
             # calculate the combined earth curvature and topography correction
-            drho = calc_drho(near_range, block_topo, prm1.get('earth_radius'), height, B, alpha, Bx)
+            drho = calc_drho(near_range, block_topo, prm1.get('earth_radius'),
+                             height.reshape(-1, 1), B.reshape(-1, 1), alpha.reshape(-1, 1), Bx.reshape(-1, 1))
 
-            # make topographic and model phase corrections
-            # GMTSAR uses float32 complex operations with precision loss
-            #phase_shift = np.exp(1j * (cnst * drho).astype(np.float32))
             phase_shift = np.exp(1j * (cnst * drho))
-            del block_topo, near_range, drho, height, B, alpha, Bx, Bv, Bh, time
+            del near_range, drho, height, B, alpha, Bx, Bv, Bh, time
 
+            # for 3d processing
+            #return np.expand_dims(phase_shift.astype(np.complex64), 0)
             return phase_shift.astype(np.complex64)
 
-        # # prepare lazy PRM
-        # # this is the "true way" but processing is ~40% slower due to additional Dask tasks
-        # def block_prms(date1, date2):
-        #     prm1 = self.PRM(date1)
-        #     prm2 = self.PRM(date2)
-        #     prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
-        #     prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
-        #     return (prm1.df, prm2.df)
-        # # Define metadata explicitly to match PRM dataframe
-        # prm_meta = pd.DataFrame(columns=['name', 'value']).astype({'name': 'str', 'value': 'object'}).set_index('name')
-
         # immediately prepare PRM
         # here is some delay on the function call but the actual processing is faster
         def prepare_prms(pair):
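block_phase_dask now operates on bare NumPy blocks rather than coordinate-carrying xarray objects, so the implicit xr.ones_like broadcasting is replaced by explicit reshape-based broadcasting: column vectors reshape(-1, 1) vary along azimuth (y), row vectors reshape(1, -1) along range (x). A standalone toy demonstration of the equivalence (synthetic values, not library code):

    import numpy as np

    y = np.arange(4, dtype=float)   # azimuth coordinates of a block
    x = np.arange(5, dtype=float)   # range coordinates of a block

    # outer combination via explicit broadcasting, as in block_phase_dask
    near_range = 100.0 + x.reshape(1, -1) * 2.0 + y.reshape(-1, 1) * 0.5

    # the same result computed against a full 2D grid of ones
    ones = np.ones((y.size, x.size))
    reference = 100.0 + ones * x * 2.0 + ones * y[:, None] * 0.5
    assert np.allclose(near_range, reference) and near_range.shape == (4, 5)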
@@ -234,59 +199,42 @@
             prm2 = self.PRM(date2)
             prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
             prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
-            return {(date1, date2): (prm1, prm2)}
+            return (prm1, prm2)
 
-        #with self.tqdm_joblib(tqdm(desc=f'Pre-Processing PRM', total=len(pairs))) as progress_bar:
         prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair) for pair in pairs)
-        # convert the list of dicts to a single dict
-        prms = {k: v for d in prms for k, v in d.items()}
-
-        # define blocks
-        chunks = topo.chunks
-        ychunks, xchunks = chunks[0], chunks[1]
-        ychunks = np.concatenate([[0], np.cumsum(ychunks)])
-        xchunks = np.concatenate([[0], np.cumsum(xchunks)])
-        ylims = [(y1, y2) for y1, y2 in zip(ychunks, ychunks[1:])]
-        xlims = [(x1, x2) for x1, x2 in zip(xchunks, xchunks[1:])]
-        #print ('ylims', ylims)
-        #print ('xlims', xlims)
-
-        stack = []
-        for stack_idx, pair in enumerate(pairs):
-            date1, date2 = pair
-            # Create a Dask DataFrame with provided metadata for each Dask block
-            #prms = dask.delayed(block_prms)(date1, date2)
-            #prm1 = dask.dataframe.from_delayed(dask.delayed(prms[0]), meta=prm_meta)
-            #prm2 = dask.dataframe.from_delayed(dask.delayed(prms[1]), meta=prm_meta)
-            prm1, prm2 = prms[(date1, date2)]
-
-            blocks_total = []
-            for ylim in ylims:
-                blocks = []
-                for xlim in xlims:
-                    delayed = dask.delayed(block_phase)(prm1, prm2, ylim, xlim)
-                    block = dask.array.from_delayed(delayed,
-                                                    shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])),
-                                                    dtype=np.complex64)
-                    blocks.append(block)
-                    del block, delayed
-                blocks_total.append(blocks)
-                del blocks
-            intf = xr.DataArray(dask.array.block(blocks_total), coords={'y': topo.y, 'x': topo.x})
-            del blocks_total
-
-            # add to stack
-            stack.append(intf)
-            # cleanup
-            del intf, prm1, prm2
-        del prms
+        # fill NaNs by 0 and expand to 3d
+        topo2d = da.where(da.isnan(topo.data), 0, topo.data)
+
+        # for 3d processing
+        # topo3d = da.repeat(da.expand_dims(topo2d, 0), len(pairs), axis=0).rechunk((1, 'auto', 'auto'))
+        # out = da.blockwise(
+        #     block_phase_dask,
+        #     'kyx',
+        #     topo3d, 'kyx',
+        #     topo.y, 'y',
+        #     topo.x, 'x',
+        #     [prm[0] for prm in prms], 'k',
+        #     [prm[1] for prm in prms], 'k',
+        #     dtype=np.complex64,
+        # )
+
+        out = da.stack([da.blockwise(
+            block_phase_dask,
+            'yx',
+            topo2d, 'yx',
+            topo.y, 'y',
+            topo.x, 'x',
+            prm1=prm[0],
+            prm2=prm[1],
+            dtype=np.complex64
+        ) for prm in prms], axis=0)
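The da.stack/da.blockwise construction above replaces the manual chunk bookkeeping (ychunks/xchunks, dask.delayed, dask.array.block) of the old loop. The key trick is passing the 1D coordinate vectors alongside the 2D array so that each block function receives the coordinate slices belonging to its block. A standalone miniature of the same pattern (synthetic data; note that explicitly chunked 1D inputs must align with their matching axes):

    import numpy as np
    import dask.array as da

    def block_fn(block, ys, xs):
        # per-block computation that needs the real coordinates of the block
        return block + ys.reshape(-1, 1) + xs.reshape(1, -1)

    data = da.random.random((6, 8), chunks=(3, 4))
    ys = da.from_array(np.arange(6, dtype=float), chunks=3)  # matches axis 0 chunks
    xs = da.from_array(np.arange(8, dtype=float), chunks=4)  # matches axis 1 chunks

    out = da.blockwise(block_fn, 'yx', data, 'yx', ys, 'y', xs, 'x', dtype=float)
    print(out.compute().shape)  # (6, 8)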
 
         coord_pair = [' '.join(pair) for pair in pairs]
         coord_ref = xr.DataArray(pd.to_datetime(pairs[:,0]), coords={'pair': coord_pair})
         coord_rep = xr.DataArray(pd.to_datetime(pairs[:,1]), coords={'pair': coord_pair})
-
-        return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).where(np.isfinite(topo)).rename('phase')
+        return xr.DataArray(out, coords={'pair': coord_pair, 'y': topo.y, 'x': topo.x})\
+               .assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).where(da.isfinite(topo)).rename('phase')
 
     def topo_slope(self, topo='auto', edge_order=1):
         import xarray as xr
diff --git a/pygmtsar/pygmtsar/utils.py b/pygmtsar/pygmtsar/utils.py
index b5b962d..54f3944 100644
--- a/pygmtsar/pygmtsar/utils.py
+++ b/pygmtsar/pygmtsar/utils.py
@@ -33,6 +33,56 @@ class utils():
 #             .predict(np.column_stack([topo_values])).reshape(phase.shape)
 #         return xr.DataArray(phase_topo, coords=phase.coords)
 
+    # Xarray's interpolation can be inefficient for large grids;
+    # this custom function handles the task more effectively.
+    @staticmethod
+    def interp2d_like(grid_in, grid_out, method='cubic', **kwargs):
+        import xarray as xr
+        import dask.array as da
+        import os
+        import warnings
+        # suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
+        warnings.filterwarnings('ignore')
+        warnings.filterwarnings('ignore', module='dask')
+        warnings.filterwarnings('ignore', module='dask.core')
+
+        # detect dimensions and coordinates for 2D or 3D grid
+        dims = grid_out.dims[-2:]
+        dim1, dim2 = dims
+        coords = {dim1: grid_out[dim1], dim2: grid_out[dim2]}
+        #print (f'dims: {dims}, coords: {coords}')
+
+        # use outer variable grid_in
+        def interpolate_chunk(out_chunk1, out_chunk2, dim1, dim2, method, **kwargs):
+            d1, d2 = float(grid_in[dim1].diff(dim1)[0]), float(grid_in[dim2].diff(dim2)[0])
+            #print ('d1, d2', d1, d2)
+            chunk = grid_in.sel({
+                dim1: slice(out_chunk1[0]-2*d1, out_chunk1[-1]+2*d1),
+                dim2: slice(out_chunk2[0]-2*d2, out_chunk2[-1]+2*d2)
+            }).compute(n_workers=1)
+            #print ('chunk', chunk)
+            out = chunk.interp({dim1: out_chunk1, dim2: out_chunk2}, method=method, **kwargs)
+            del chunk
+            return out
+
+        # fall back to a fixed chunk size when grid_out is not dask-backed;
+        # the original draft referenced self.chunksize, which is undefined inside a @staticmethod
+        chunk_sizes = grid_out.chunks[-2:] if hasattr(grid_out, 'chunks') else (2048, 2048)
+        # coordinates are numpy arrays
+        grid_out_y = da.from_array(grid_out[dim1].values, chunks=chunk_sizes[0])
+        grid_out_x = da.from_array(grid_out[dim2].values, chunks=chunk_sizes[1])
+
+        grid = da.blockwise(
+            interpolate_chunk,
+            'yx',
+            grid_out_y, 'y',
+            grid_out_x, 'x',
+            dtype=grid_in.dtype,
+            dim1=dim1,
+            dim2=dim2,
+            method=method,
+            **kwargs
+        )
+        return xr.DataArray(grid, coords=coords).rename(grid_in.name)
+
     @staticmethod
     def nanconvolve2d_gaussian(data,
                                weight=None,
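To make the new helper's behavior concrete, here is a self-contained miniature of the same chunk-wise interpolation pattern with synthetic grids (illustrative names and sizes only; the library helper defaults to cubic interpolation, while linear is used here to avoid a hard scipy version requirement). Each output chunk selects a slightly padded window of the source grid, so the per-chunk interpolation has support at chunk edges and no chunk ever needs the full source array in memory:

    import numpy as np
    import xarray as xr
    import dask.array as da

    # coarse source grid
    coarse = xr.DataArray(
        np.random.rand(50, 50),
        coords={'y': np.linspace(0, 1, 50), 'x': np.linspace(0, 1, 50)},
        dims=('y', 'x'),
    )
    # fine target coordinates
    yy = np.linspace(0.1, 0.9, 400)
    xx = np.linspace(0.1, 0.9, 400)

    def interpolate_chunk(yc, xc):
        # pad the selection by two source cells so interpolation
        # has support at the chunk edges
        dy = float(coarse.y.diff('y')[0]); dx = float(coarse.x.diff('x')[0])
        sel = coarse.sel(y=slice(yc[0]-2*dy, yc[-1]+2*dy),
                         x=slice(xc[0]-2*dx, xc[-1]+2*dx))
        return sel.interp(y=yc, x=xc, method='linear').values

    fine = da.blockwise(
        interpolate_chunk, 'yx',
        da.from_array(yy, chunks=100), 'y',
        da.from_array(xx, chunks=100), 'x',
        dtype=coarse.dtype,
    )
    print(fine.compute().shape)  # (400, 400)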