首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch中的model.training是什么?

pytorch中的model.training是什么?
EN

Stack Overflow用户
提问于 2021-09-29 06:50:53
回答 2查看 83关注 0票数 0

嗨,我正在学习关于迁移学习的pytorch教程。(https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

model.training是用来干什么的??

代码语言:javascript
复制
enter def visualize_model(model,num_images=6):
was_training=model.training
model.eval()
images_so_far=0
fig=plt.figure()

with torch.no_grad():
    for i, (inputs,labels) in enumerate(dataloaders['val']):
        inputs=inputs.to(device)
        labels=labels.to(device)
        
        outputs=model(inputs)
        _,pred=torch.max(outputs,1)
        
        for j in range(inputs.size()[0]):
            images_so_far+=1
            ax=plt.subplot(num_images//2,2,images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])
            
            if images_so_far==num_images:
                model.train(mode=was_training)
                return
    model.train(mode=was_training)code here

我不明白"model.train(model=was_training)“的意思。有什么帮助吗??非常感谢

EN

回答 2

Stack Overflow用户

发布于 2021-09-29 07:12:20

我认为这会有所帮助(link)

所有的训练都有一个内部的nn.Modules属性,通过调用model.train()和model.eval()来改变模型的行为。

was_training变量存储模型的当前训练状态,调用model.eval(),并在最后使用model.train(training=was_training)重置状态。

您可以在pytorch讨论论坛中找到很好的答案;)

票数 0
EN

Stack Overflow用户

发布于 2021-09-29 08:01:37

我想知道为什么他们在测试会话中使用model.train。为什么他们要把这些代码放在with torch.no_grad()中?这不是很明显吗,was_training=false

train的用法有点误导,因为train 也可用于将模型置于推理(评估)模式

代码语言:javascript
复制
>>> model.train(mode=True)
>>> model.training 
True   # <- train mode

>>> model.train(mode=False)
False  # <- eval mode

我同意这并不理想,更恰当的表述应该是:

代码语言:javascript
复制
>>> model.eval()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69371652

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档