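Today I ran into a problem applying xr.apply_ufunc correctly to a three-dimensional xarray DataArray.
I want to filter the array along the "time" dimension. The result should therefore have the same dimensions as the input array, but contain the filtered data instead of the raw monthly data. To do this, I wrote the following two functions: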
import xarray as xr
from scipy import signal
import numpy as np
def butter_filt(x, filt_year, fs, order_butter):
    # filt_year = 1  # 1 year
    # fs = 12        # monthly data
    # fn = fs/2      # Nyquist frequency
    fc = (1/filt_year)/2  # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
    # fc = (1/2)/2  # cut off frequency 1sample/ 2year = (1/2)/2 equals 2 year filter (two half cycles/sample)
    # fc = (1/4)/2  # cut off frequency 1sample/ 4year = (1/4)/2 equals 4 year filter (two half cycles/sample)
    b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')
    # Check NA values
    co = np.count_nonzero(~np.isnan(x))
    if co < 4:
        # Fewer than 4 observations: return -9999 everywhere
        # (np.empty would return uninitialised memory instead)
        return np.full(x.shape, -9999.0)
    else:
        return signal.filtfilt(b, a, x)
def filtfilt_butter(x, filt_year, fs, order_butter, dim='time'):
    # x ...... xr data array
    # dim .... dimension along which to apply function
    filt = xr.apply_ufunc(butter_filt, x, filt_year, fs, order_butter,
                          input_core_dims=[[dim], [], [], []],
                          dask='parallelized')
    return filt
x_uv = filtfilt_butter(ds.x,
                       filt_year=1,
                       fs=12,
                       order_butter=2,
                       dim='time')
I tested butter_filt on its own and it works fine, so the problem must be in filtfilt_butter. Trying to compute x_uv raises the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<timed exec> in <module>
<ipython-input-86-c470080a0817> in filtfilt_butter(x, filt_year, fs, order_butter, dim)
20 # x ...... xr data array
21 # dims .... dimension aong which to apply function
---> 22 filt= xr.apply_ufunc(butter_filt, x,filt_year,fs,order_butter,
23 input_core_dims=[[dim], [], [], []],
24 dask='parallelized')
~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, *args)
1056 )
1057 elif any(isinstance(a, DataArray) for a in args):
-> 1058 return apply_dataarray_vfunc(
1059 variables_vfunc,
1060 *args,
~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
231
232 data_vars = [getattr(a, "variable", a) for a in args]
--> 233 result_var = func(*data_vars)
234
235 if signature.num_outputs > 1:
~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, output_sizes, keep_attrs, meta, *args)
621 data = as_compatible_data(data)
622 if data.ndim != len(dims):
--> 623 raise ValueError(
624 "applied function returned data with unexpected "
625 "number of dimensions: {} vs {}, for dimensions {}".format(
ValueError: applied function returned data with unexpected number of dimensions: 3 vs 2, for dimensions ('deptht', 'km')
How can I fix this?
Posted on 2020-07-09 21:06:30
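apply_ufunc expects the output to have the input's dimensions minus those listed in input_core_dims, plus those listed in output_core_dims. In your example you correctly pass time as an input core dimension: you want to make sure it is moved to the last axis, so that filtfilt's default axis=-1 filters along time.
You therefore also need output_core_dims, so that xarray expects a 3-D output array. You can use time there as well, i.e. output_core_dims=[[dim]].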
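The shape rule above can be reproduced on a small hypothetical array (random names and sizes chosen only for illustration). A reduction that removes the core dimension satisfies the default expectation, while a shape-preserving function without output_core_dims triggers the same kind of ValueError as in the question. The exact error message text varies between xarray versions, so only the exception type is relied on here.

```python
import numpy as np
import xarray as xr

# Hypothetical 3-D array using the dimension names from the question.
da = xr.DataArray(np.ones((5, 2, 3)), dims=("time", "deptht", "km"))

# A reduction removes the core dim "time": xarray expects ("deptht", "km"),
# which matches what np.mean(..., axis=-1) returns, so this succeeds.
reduced = xr.apply_ufunc(
    np.mean, da, input_core_dims=[["time"]], kwargs={"axis": -1}
)
print(reduced.dims)  # ('deptht', 'km')

# A shape-preserving function with no output_core_dims: xarray still expects
# ("deptht", "km") but receives a 3-D array, so it raises a ValueError.
raised = False
try:
    xr.apply_ufunc(lambda arr: arr * 2, da, input_core_dims=[["time"]])
except ValueError as err:
    raised = True
    print(err)
```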
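A minimal sketch of that fix, assuming a NumPy-backed DataArray and reusing the question's filter design (the NaN check is left out for brevity). The test data (120 monthly steps on a 3 x 4 grid) is hypothetical. Note that apply_ufunc returns the core dimension last, so `.transpose(*da.dims)` can restore the input order if needed.

```python
import numpy as np
import xarray as xr
from scipy import signal


def butter_filt(x, filt_year, fs, order_butter):
    # Low-pass Butterworth filter; filtfilt works on the last axis by default.
    fc = (1 / filt_year) / 2
    b, a = signal.butter(order_butter, fc, "low", fs=fs, output="ba")
    return signal.filtfilt(b, a, x)


def filtfilt_butter(x, filt_year, fs, order_butter, dim="time"):
    # `dim` is an input core dim (moved to the last axis) and an output
    # core dim (so xarray expects it back in the result).
    return xr.apply_ufunc(
        butter_filt,
        x,
        filt_year,
        fs,
        order_butter,
        input_core_dims=[[dim], [], [], []],
        output_core_dims=[[dim]],
        dask="parallelized",  # only relevant for dask-backed inputs
    )


# Hypothetical data: 10 years of monthly values on a 3 x 4 grid.
rng = np.random.default_rng(0)
da = xr.DataArray(rng.standard_normal((120, 3, 4)), dims=("time", "deptht", "km"))

filtered = filtfilt_butter(da, filt_year=1, fs=12, order_butter=2, dim="time")
print(filtered.dims)  # ('deptht', 'km', 'time'): the core dim comes last
restored = filtered.transpose(*da.dims)  # back to ('time', 'deptht', 'km')
```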
For a more detailed explanation of the apply_ufunc arguments, see this other answer.
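For an array that is already loaded in memory, a possible alternative (not from either answer) is to skip apply_ufunc: scipy's filtfilt accepts an `axis` argument, and `DataArray.get_axis_num` looks up the position of "time" by name. The sketch below assumes NumPy-backed data and uses hypothetical random values; `DataArray.copy(data=...)` keeps the original dimensions and coordinates.

```python
import numpy as np
import xarray as xr
from scipy import signal

# Hypothetical in-memory data: 120 monthly steps on a 3 x 4 grid.
rng = np.random.default_rng(1)
da = xr.DataArray(rng.standard_normal((120, 3, 4)), dims=("time", "deptht", "km"))

# Same filter design as in the question: order 2, fc = (1/1)/2, fs = 12.
b, a = signal.butter(2, 0.5, "low", fs=12, output="ba")

# Filter along whichever axis "time" occupies; dims and coords are preserved.
axis = da.get_axis_num("time")
filtered = da.copy(data=signal.filtfilt(b, a, da.values, axis=axis))
print(filtered.dims)  # ('time', 'deptht', 'km')
```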
Posted on 2020-07-09 16:07:17
I found a way to solve the problem by defining the input and output dimensions:
def butter_filt(x, filt_year, fs, order_butter):
    # filt_year = 1  # 1 year
    # fs = 12        # monthly data
    # fn = fs/2      # Nyquist frequency
    fc = (1/filt_year)/2  # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
    # fc = (1/2)/2  # cut off frequency 1sample/ 2year = (1/2)/2 equals 2 year filter (two half cycles/sample)
    # fc = (1/4)/2  # cut off frequency 1sample/ 4year = (1/4)/2 equals 4 year filter (two half cycles/sample)
    b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')
    return signal.filtfilt(b, a, x)
def filtfilt_butter(x, filt_year, fs, order_butter, dim='time'):
    # x ...... xr data array
    # dim .... dimension along which to apply function
    filt = xr.apply_ufunc(
        butter_filt,   # first the function
        x,             # now arguments in the order expected by 'butter_filt'
        filt_year,     # as above
        fs,            # as above
        order_butter,  # as above
        input_core_dims=[["deptht", "km", "time"], [], [], []],  # list with one entry per arg
        output_core_dims=[["deptht", "km", "time"]],  # returned data has 3 dimensions
        exclude_dims=set(("time",)),  # dimensions allowed to change size. Must be a set!
        vectorize=True,  # loop over non-core dims
    )
    return filt
https://stackoverflow.com/questions/62797088
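In the self-answer every dimension is a core dimension, so `vectorize=True` has no non-core dimensions to loop over and butter_filt receives the whole 3-D array in one call. A variant worth considering, sketched here under the assumption of NumPy-backed data, keeps only "time" as the core dimension and lets `vectorize=True` pass one 1-D series per (deptht, km) point. That way the question's "fewer than 4 observations" check applies per series. The names `butter_filt_1d`, `filtfilt_butter_1d` and `min_valid` are hypothetical, and returning NaN (rather than -9999) for sparse series is a choice made for this sketch.

```python
import numpy as np
import xarray as xr
from scipy import signal


def butter_filt_1d(x, filt_year, fs, order_butter, min_valid=4):
    # With vectorize=True, x is a single 1-D time series.
    if np.count_nonzero(~np.isnan(x)) < min_valid:
        # Too few valid observations: return an all-NaN series of the same length.
        return np.full(x.shape, np.nan)
    fc = (1 / filt_year) / 2
    b, a = signal.butter(order_butter, fc, "low", fs=fs, output="ba")
    return signal.filtfilt(b, a, x)


def filtfilt_butter_1d(x, filt_year, fs, order_butter, dim="time"):
    return xr.apply_ufunc(
        butter_filt_1d,
        x,
        filt_year,
        fs,
        order_butter,
        input_core_dims=[[dim], [], [], []],
        output_core_dims=[[dim]],
        vectorize=True,  # np.vectorize loops over every (deptht, km) point
    )


# Hypothetical data: 120 monthly steps on a 3 x 4 grid, where point (0, 0)
# has only two valid observations.
rng = np.random.default_rng(2)
values = rng.standard_normal((120, 3, 4))
values[2:, 0, 0] = np.nan
da = xr.DataArray(values, dims=("time", "deptht", "km"))

filtered = filtfilt_butter_1d(da, filt_year=1, fs=12, order_butter=2)
print(filtered.dims)  # ('deptht', 'km', 'time')
```

Looping point by point in Python is slower than filtering the full array at once, so the self-answer's single call is preferable when no per-series handling is needed.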