首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >`from_numpy()`导致预期的vec.is_mps()为真,但最终为假

`from_numpy()`导致预期的vec.is_mps()为真,但最终为假
EN

Stack Overflow用户
提问于 2022-07-08 10:31:17
回答 1查看 98关注 0票数 0

当我试图将数据(numpy nd array)转换为与from_numpy()一起使用mps后端时的张量时。

我将模型初始化如下:

代码语言:javascript
复制
device = "mps" if torch.has_mps else "cpu"
model = NeuralNetwork().to(device)

它使用的是mps后端:

代码语言:javascript
复制
Using mps device
NeuralNetwork(...)

然后按以下方式使用:

代码语言:javascript
复制
observations = env.reset()
X = torch.from_numpy(observations)
logits = model(X)

模型抛出错误

代码语言:javascript
复制
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [9], in <cell line: 3>()
      1 observations = env.reset()
      2 X = torch.from_numpy(observations)
----> 3 logits = model(X)

File lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

Input In [2], in NeuralNetwork.forward(self, x)
     13 def forward(self, x):
---> 14     logits = self.linear_relu_stack(x)
     15     return logits

File lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File lib/python3.8/site-packages/torch/nn/modules/container.py:139, in Sequential.forward(self, input)
    137 def forward(self, input):
    138     for module in self:
--> 139         input = module(input)
    140     return input

File lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File lib/python3.8/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected vec.is_mps() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

如果我把设备改为cpu而不是mps,它就能工作。如何将numpy数组与mps后端一起使用?

我正在M1芯片上运行它,torch.has_mpsTrue

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-14 22:50:44

您必须首先将模型发送到mps设备,然后将您的输入显式地发送到mps设备。代码:

代码语言:javascript
复制
model.to('mps')
logits = model(X.to('mps'))

类似这样的东西,我每晚在M1 pro上使用火炬,使用一个不包含运维数据类型的模型,mps目前不支持。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72910108

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档