首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >xr.apply_ufunc过滤3Dx阵列

xr.apply_ufunc过滤3Dx阵列
EN

Stack Overflow用户
提问于 2020-07-08 14:28:00
回答 2查看 1.1K关注 0票数 0

今天,我遇到了在三维x数组数据数组上正确应用xr.apply_ufunc的问题。

我希望沿着“时间”维度过滤数组。因此,产品最终应该具有与输入数组相同的尺寸,并且应该只包含过滤后的数据,而不是每月的数据。为此,我编写了以下两个函数:

代码语言:javascript
复制
import xarray as xr
from scipy import signal
import numpy as np

def butter_filt(x,filt_year,fs,order_butter):

    #filt_year = 1 #1 year
    #fs = 12 #monthly data
    #fn = fs/2; # Nyquist Frequency
    fc = (1/filt_year)/2 # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
    #fc = (1/2)/2 # cut off frequency 1sample/ 2year = (1/1)/2 equals 2 year filter (two half cycles/sample)
    #fc = (1/4)/2 # cut off frequency 1sample/ 4year = (1/1)/2 equals 4 year filter (two half cycles/sample)
    b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')

    # Check NA values
    co = np.count_nonzero(~np.isnan(x))
    if co < 4: # If fewer than 4 observations return -9999
        return np.empty(x.shape)
    else:
        return signal.filtfilt(b, a, x)

def filtfilt_butter(x,filt_year,fs,order_butter,dim='time'):
    # x ...... xr data array
    # dims .... dimension along which to apply function    
    filt= xr.apply_ufunc(butter_filt, x,filt_year,fs,order_butter,
                     input_core_dims=[[dim], [], [], []],
                       dask='parallelized')

    return filt

x_uv = filtfilt_butter(ds.x,
                   filt_year=1,
                   fs=12,
                   order_butter=2,
                   dim='time')

我尝试了butter_filt函数,它本身工作得很好,所以在filtfilt_butter中存在一些问题。试图计算x_uv会产生以下错误:

代码语言:javascript
复制
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<timed exec> in <module>

<ipython-input-86-c470080a0817> in filtfilt_butter(x, filt_year, fs, order_butter, dim)
     20     # x ...... xr data array
     21     # dims .... dimension aong which to apply function
---> 22     filt= xr.apply_ufunc(butter_filt, x,filt_year,fs,order_butter,
     23                          input_core_dims=[[dim], [], [], []],
     24                            dask='parallelized')

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, 
dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, *args)
   1056         )
   1057     elif any(isinstance(a, DataArray) for a in args):
-> 1058         return apply_dataarray_vfunc(
   1059             variables_vfunc,
   1060             *args,

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    231 
    232     data_vars = [getattr(a, "variable", a) for a in args]
--> 233     result_var = func(*data_vars)
    234 
    235     if signature.num_outputs > 1:

~/miniconda3/envs/pyt3_11102018/lib/python3.8/site-packages/xarray/core/computation.py in 
 apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, output_sizes, keep_attrs, 
meta, *args)
    621         data = as_compatible_data(data)
    622         if data.ndim != len(dims):
--> 623             raise ValueError(
    624                 "applied function returned data with unexpected "
    625                 "number of dimensions: {} vs {}, for dimensions {}".format(

ValueError: applied function returned data with unexpected number of dimensions: 3 vs 2, for 
dimensions ('deptht', 'km')

我怎么才能解决这个问题?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-07-09 21:06:30

apply_ufunc期望输出形状是初始形状减去input_core_dims形状加上output_core_dims形状。在您的示例中,您正确地将time作为输入核dim传递,因为您希望确保它被移动到最后一个维度,因此使用axis=-1可以正确工作。

因此,您需要使用output_core_dims来获取xarray,以期望得到一个3d输出数组。您也可以使用time

有关apply_ufunc参数的更详细说明,请参见this other answer

票数 1
EN

Stack Overflow用户

发布于 2020-07-09 16:07:17

通过定义输入和输出维度,我找到了解决这个问题的方法:

def butter_filt(x,filt_year,fs,order_butter):

代码语言:javascript
复制
#filt_year = 1 #1 year
#fs = 12 #monthly data
#fn = fs/2; # Nyquist Frequency
fc = (1/filt_year)/2 # cut off frequency 1sample/ 1year = (1/1)/2 equals 1 year filter (two half cycles/sample)
#fc = (1/2)/2 # cut off frequency 1sample/ 2year = (1/1)/2 equals 2 year filter (two half cycles/sample)
#fc = (1/4)/2 # cut off frequency 1sample/ 4year = (1/1)/2 equals 4 year filter (two half cycles/sample)
b, a = signal.butter(order_butter, fc, 'low', fs=fs, output='ba')

return signal.filtfilt(b, a, x)




def filtfilt_butter(x,filt_year,fs,order_butter,dim='time'):
# x ...... xr data array
# dims .... dimension aong which to apply function    
filt= xr.apply_ufunc(
    butter_filt,  # first the function
    x,# now arguments in the order expected by 'butter_filt'
    filt_year,  # as above
    fs,  # as above
    order_butter,  # as above
    input_core_dims=[["deptht","km","time"], [], [],[]],  # list with one entry per arg
    output_core_dims=[["deptht","km","time"]],  # returned data has 3 dimension
    exclude_dims=set(("time",)),  # dimensions allowed to change size. Must be a set!
    vectorize=True,  # loop over non-core dims
)

return filt
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62797088

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档