首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >规范化tf.data.Dataset

规范化tf.data.Dataset
EN

Stack Overflow用户
提问于 2021-07-30 12:01:11
回答 1查看 904关注 0票数 2

我有一个输入形状(批处理大小,128,128,2)和目标形状(批大小,128,128,1)的图像的tf.data.Dataset,其中输入的是2通道图像(具有两个通道表示实和虚部分的复值图像),而目标是1通道图像(实值图像)。我需要标准化输入图像和目标图像,首先从它们中删除它们的平均图像,然后将它们缩放到(0,1)范围。如果我没有错,tf.data.Dataset一次只能处理一个批处理,而不能处理整个数据集。因此,我从remove_mean py_function中的批处理中的每个图像中删除批的平均图像,然后通过减去它的最小值并除以py_function linear_scaling中的最大值和最小值的差,将每幅图像缩放到(0,1)。但是,在应用这些函数之前和之后,从数据集中在输入图像中打印min和max值后,图像值不会发生变化。有人能说出这件事可能出了什么问题吗?

代码语言:javascript
复制
def remove_mean(image, target):
    image_mean = np.mean(image, axis=0)
    target_mean = np.mean(target, axis=0)
    image = image - image_mean
    target = target - target_mean
    return image, target

def linear_scaling(image, target):
    image_min = np.ndarray.min(image, axis=(1,2), keepdims=True)
    image_max = np.ndarray.max(image, axis=(1,2), keepdims=True)
    image = (image-image_min)/(image_max-image_min)

    target_min = np.ndarray.min(target, axis=(1,2), keepdims=True)
    target_max = np.ndarray.max(target, axis=(1,2), keepdims=True)
    target = (target-target_min)/(target_max-target_min)
    return image, target

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))


Output -

tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-30 12:26:53

map不是一个内部操作,所以您的train_dataset在执行train_dataset.map(....)时不会改变。

train_dataset = train_dataset.map(...)

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68590624

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档