首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >矩阵乘法不起作用- Tensorflow

矩阵乘法不起作用- Tensorflow
EN

Stack Overflow用户
提问于 2017-08-25 20:35:57
回答 1查看 152关注 0票数 0

我是使用tensorflow的乞丐,也是学校项目的乞丐。在这里,我试图创建一个房屋标识符,其中我在excel表上创建了一些数据,并将其转换为csv文件,并测试数据是否会被读取。数据被读取了,但当我做矩阵乘法时,它会产生多个错误,并说."ValueError: Shape必须是2级,但是对于'MatMul‘(op:'MatMul')来说是0级,输入形状是:[],1,1。“非常感谢!

代码语言:javascript
复制
import tensorflow as tf
import os
dir_path = os.path.dirname(os.path.realpath(__file__))
filename = dir_path+ "\House Price Data .csv"
w1=tf.Variable(tf.zeros([1,1]))
w2=tf.Variable(tf.zeros([1,1])) #Feature 1's weight
w3=tf.Variable(tf.zeros([1,1])) #Feature 1's weight
b=tf.Variable(tf.zeros([1])) #bias for various features
x1= tf.placeholder(tf.float32,[None, 1])
x2= tf.placeholder(tf.float32,[None, 1])
x3= tf.placeholder(tf.float32,[None, 1])
Y= tf.placeholder(tf.float32,[None, 1])
y_=tf.placeholder(tf.float32,[None,1])
with tf.Session() as sess:
    sess.run( tf.global_variables_initializer())
    with open(filename) as inf:
        # Skip header
        next(inf)
        for line in inf:
            # Read data, using python, into our features
            housenumber, x1, x2, x3, y_ = line.strip().split(",")
            x1 = float(x1)
            product = tf.matmul(x1, w1)
            y = product + b
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-08-25 22:10:31

@Aaron是对的,当您从csv文件加载数据时,您正在覆盖这些变量。

您需要将加载的值保存到单独的变量中,比如_x1而不是x1,然后使用feed_dict将值提供给占位符。因为x1的形状是[None,1],所以需要将字符串标量_x1转换为具有相同形状的浮点数,在本例中是[1,1]

代码语言:javascript
复制
import tensorflow as tf
import os
dir_path = os.path.dirname(os.path.realpath(__file__))
filename = dir_path+ "\House Price Data .csv"
w1=tf.Variable(tf.zeros([1,1]))
b=tf.Variable(tf.zeros([1])) #bias for various features
x1= tf.placeholder(tf.float32,[None, 1])

y_pred = tf.matmul(x1, w1) + b

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer())
    with open(filename) as inf:
        # Skip header
        next(inf)
        for line in inf:
            # Read data, using python, into our features
            housenumber, _x1, _x2, _x3, _y_ = line.strip().split(",")
            sess.run(y_pred, feed_dict={x1:[[float(_x1)]]})
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45889064

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档