首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从插入符号10折简历中提取训练和测试AUROC

从插入符号10折简历中提取训练和测试AUROC
EN

Stack Overflow用户
提问于 2018-01-07 03:27:10
回答 1查看 1.8K关注 0票数 1

假设我正在做如下分类:

代码语言:javascript
复制
library(mlbench)
data(Sonar)

library(caret)
set.seed(998)

my_data <- Sonar

fitControl <-
  trainControl(
    method = "cv",
    number = 10,
    classProbs = T,
    savePredictions = T,
    summaryFunction = twoClassSummary
  )


model <- train(
  Class ~ .,
  data = my_data,
  method = "xgbTree",
  trControl = fitControl,
  metric = "ROC"
)

对于10次折叠中的每一次,10%的数据用于验证。对于插入符号确定的最佳参数,我可以使用getTrainPerf(model)来查找所有10个折叠的平均值验证AUC值,或者使用model$resample来获取每个折叠的AUC值。然而,如果将训练数据放回相同的模型中,我就无法获得AUC。如果我能得到每个训练集的AUC值,那就太好了。

如何提取这些信息?我想确保我的模型不是过拟合的(我正在处理的数据集非常小)。

谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-01-07 08:31:42

根据注释中的要求,这里提供了一个自定义函数,用于评估交叉验证测试错误。我不确定它是否可以从脱字符序列对象中提取出来。

在运行插入符号训练之后,提取折叠符以获得最佳曲调:

代码语言:javascript
复制
library(tidyverse)
model$bestTune %>%
  left_join(model$pred) %>%
  select(rowIndex, Resample) %>%
  mutate(Resample = as.numeric(gsub(".*(\\d$)", "\\1", Resample)),
         Resample = ifelse(Resample == 0, 10, Resample)) %>%
  arrange(rowIndex) -> resamples

构造一个交叉验证函数,它将使用与插入符号相同的折叠:

代码语言:javascript
复制
library(xgboost)
train <- my_data[,!names(my_data)%in% "Class"]
label <- as.numeric(my_data$Class) - 1

test_auc <- lapply(1:10, function(x){
  model <- xgboost(data = data.matrix(train[resamples[,2] != x,]),
                   label = label[resamples[,2] != x],
                   nrounds = model$bestTune$nrounds,
                   max_depth = model$bestTune$max_depth,
                   gamma = model$bestTune$gamma,
                   colsample_bytree = model$bestTune$colsample_bytree,
                   objective = "binary:logistic",
                   eval_metric= "auc" ,
                   print_every_n = 50)
  preds_train <- predict(model, data.matrix(train[resamples[,2] != x,]))
  preds_test <- predict(model, data.matrix(train[resamples[,2] == x,]))
  auc_train <- pROC::auc(pROC::roc(response = label[resamples[,2] != x], predictor = preds_train, levels = c(0, 1)))
  auc_test <- pROC::auc(pROC::roc(response = label[resamples[,2] == x], predictor = preds_test, levels = c(0, 1)))
  return(data.frame(fold = unique(resamples[resamples[,2] == x, 2]), auc_train, auc_test))
  })

do.call(rbind, test_auc)
#output
   fold auc_train  auc_test
1     1         1 0.9909091
2     2         1 0.9797980
3     3         1 0.9090909
4     4         1 0.9629630
5     5         1 0.9363636
6     6         1 0.9363636
7     7         1 0.9181818
8     8         1 0.9636364
9     9         1 0.9818182
10   10         1 0.8888889

arrange(model$resample, Resample)
#output
         ROC      Sens      Spec Resample
1  0.9909091 1.0000000 0.8000000   Fold01
2  0.9898990 0.9090909 0.8888889   Fold02
3  0.9909091 0.9090909 1.0000000   Fold03
4  0.9444444 0.8333333 0.8888889   Fold04
5  0.9545455 0.9090909 0.8000000   Fold05
6  0.9272727 1.0000000 0.7000000   Fold06
7  0.9181818 0.9090909 0.9000000   Fold07
8  0.9454545 0.9090909 0.8000000   Fold08
9  0.9909091 0.9090909 0.9000000   Fold09
10 0.8888889 0.9090909 0.7777778   Fold10

为什么测试折叠AUC从我的函数和插入符号是不一样的,我不能说。我非常确定使用了相同的参数和折叠。我可以假设这与随机种子有关。当我检查插入符号测试预测的auc时,我得到了与插入符号相同的输出:

代码语言:javascript
复制
model$bestTune %>%
  left_join(model$pred) %>%
  arrange(rowIndex) %>%
  select(M, Resample, obs) %>%
  mutate(Resample = as.numeric(gsub(".*(\\d$)", "\\1", Resample)),
                             Resample = ifelse(Resample == 0, 10, Resample),
         obs = as.numeric(obs) - 1) %>%
  group_by(Resample) %>%
  do(auc = as.vector(pROC::auc(pROC::roc(response = .$obs, predictor = .$M)))) %>%
  unnest()
#output
   Resample   auc
      <dbl> <dbl>
 1     1.00 0.991
 2     2.00 0.990
 3     3.00 0.991
 4     4.00 0.944
 5     5.00 0.955
 6     6.00 0.927
 7     7.00 0.918
 8     8.00 0.945
 9     9.00 0.991
10    10.0  0.889

但我再强调一次,测试错误会告诉你很少,你应该依靠训练错误。如果您想让这两个参数更接近,可以考虑使用gammaalphalambda参数。

对于一个小的数据集,我仍然会尝试拆分train : test = 80 : 20,并使用独立的测试集来验证CV错误是否接近测试错误。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48131050

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档