首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >InvalidArgumentError在tensorflow 2中是什么意思?

InvalidArgumentError在tensorflow 2中是什么意思?
EN

Stack Overflow用户
提问于 2020-06-07 10:53:40
回答 1查看 137关注 0票数 0

我是新来的。我试图实现线性回归与自定义培训,遵循这个教程

但是当我试图计算W*x + b时,我得到了这个错误

代码语言:javascript
复制
tf.add(tf.matmul(W,x),b)

InvalidArgumentError:无法将Add计算为输入#1(以零为基础)是一个双张量,但它是一个浮动张量

I初始化W和b

W = tf.Variable(np.random.rand(1,9))

b = tf.Variable([1],dtype = tf.float32)

x = tf.Variable(np.random.rand(9,100))

但是当我把b的初始化改为

b = tf.Variable(np.random.rand(1))

我没有发现任何错误。原因是什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-07 16:06:13

np.random.rand(1,9) (和其他初始化)的结果是np.float64类型。与tf.Variable一起使用它可以得到一个类型为tf.float64的张量。

Tensorflow的add的参数必须是相同类型的。matmul结果为tf.float64型,btf.float32型。你需要把其中一个投给另一个类型。

在Tensorflow中,您可以这样做(按惯例推荐):

代码语言:javascript
复制
# Can be done in a single line too
matmul_result = tf.matmul(W,x)
matmul_result = tf.cast(matmul_result, tf.float32)
tf.add(matmul_result, b)

或者你可以这么做:

代码语言:javascript
复制
tf.add(tf.matmul(W,x), tf.cast(b, tf.float64))

还可以直接更改numpy数组的类型:

代码语言:javascript
复制
W = tf.Variable(np.random.rand(1,9).astype(np.float32))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62244261

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档