首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >梯度升压机的交叉验证

梯度升压机的交叉验证
EN

Code Review用户
提问于 2015-11-12 16:15:36
回答 1查看 2.4K关注 0票数 2

我对Python相当陌生。我为梯度增强方法实现了一个简短的交叉验证工具。

代码语言:javascript
复制
import numpy as np

from sklearn.metrics import roc_auc_score as auc
from sklearn import cross_validation
from time import time

def heldout_auc(model, X_test, y_test):
    score = np.zeros((model.get_params()["n_estimators"],), dtype=np.float64)
    for i, y_pred in enumerate(model.staged_decision_function(X_test)):
        score[i] = auc(y_test, y_pred)
    return score

def cv_boost_estimate(X,y,model,n_folds=3):
    cv = cross_validation.StratifiedKFold(y, n_folds=n_folds, shuffle=True, random_state=11)
    val_scores = np.zeros((model.get_params()["n_estimators"],), dtype=np.float64)
    t = time()
    i = 0
    for train, test in cv:
        i = i + 1
        print('FOLD : ' + str(i) + '-' + str(n_folds))
        model.fit(X.iloc[train,], y.iloc[train])
        val_scores += heldout_auc(model, X.iloc[test,], y.iloc[test])
    val_scores /= n_folds
    return val_scores,(time()-t)

然后,我可以通过以下方法寻找最佳树数:

代码语言:javascript
复制
print('AUC : ' + str(max(auc)) + ' - index : ' + str(auc.tolist().index(max(auc))))

一切都在工作,但语法感觉不太稳定,而且“不是毕达克”。有人有改进的建议吗?

EN

回答 1

Code Review用户

回答已采纳

发布于 2015-11-16 09:29:52

您没有在您的enumerate循环中使用cv,我想您尝试过这样做,并发现它不起作用:

代码语言:javascript
复制
for i, train, test in enumerate(cv):

要理解这一点,您需要知道enumerate实际上在做什么。它将cv的每个元素包装成一个2项元组,其中索引作为第一项,另一项为可迭代的元素。基本上,它使用的是这样的结构:(i, (train, test))

幸运的是,这意味着您只需进行一次修改就可以获得所需的结果:

代码语言:javascript
复制
for i, (train, test) in enumerate(cv):

现在,它可以正确地提取所有三个值,没有错误。尽管您希望i从1开始,但我发现它比手动递增的值更清晰和可读性更强。您只需考虑将i值提高1。

另外,当您应该使用str.format时,您正在手动连接字符串。它允许您将值插入字符串中,如下所示:

代码语言:javascript
复制
    print('FOLD : {}-{}'.format(i, n_folds))

format将用传递给它的参数替换任何{}。它还将自动尝试对任何可以转换为字符串的对象调用str,这样您就不再需要手动调用了。

最后,您不需要将您的时间表达式封装在括号中,如果没有它,它会返回很好的结果,省略括号更像Pythonic。

代码语言:javascript
复制
return val_scores, time() - t

您应该阅读Python风格指南,它有很多关于Pythonic风格编码的信息。你基本上是正确的,但是有些行有点长,有时你应该在里面放更多的空格(例如。time()-t -> time() - t)。

票数 1
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/110573

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档