首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >微调后如何使用“scikit学习校准”

微调后如何使用“scikit学习校准”
EN

Stack Overflow用户
提问于 2022-05-08 17:36:28
回答 1查看 187关注 0票数 0

我精调LGBM和应用校准,但应用校准有困难。

我有训练,有效,测试数据。

I使用1)训练数据和2)有效数据对LGBM进行了训练和微调。然后,得到了LGBM.的最佳参数。

在此之后,我要进行校准,以使我的模型的输出可以直接解释为一个信心水平。但是我对使用CalibratedClassifierCV感到困惑。

在我的情况下,是使用cv='prefit‘还是使用cv=5?另外,我应该使用列车数据还是适合CalibratedClassifierCV?的有效数据

1) uncalibrated_clf但训练后

代码语言:javascript
复制
clf = lgb.LGBMClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=True, early_stopping_rounds=20)

2-1) Calibrated_clf

代码语言:javascript
复制
cal_clf = CalibratedClassifierCV(clf, cv='prefit', method='isotonic')
cal_clf.fit(X_valid, y_valid)

2-2) Calibrated_clf

代码语言:javascript
复制
cal_clf = CalibratedClassifierCV(clf, cv=5, method='isotonic')
cal_clf.fit(X_train, y_train)

2-3) Calibrated_clf

代码语言:javascript
复制
cal_clf = CalibratedClassifierCV(clf, cv=5, method='isotonic')
cal_clf.fit(X_valid, y_valid)

哪一个是对的?一切都是对的,或者只有一两个是对的?

下面是密码。

代码语言:javascript
复制
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.calibration import calibration_curve
from sklearn.calibration import CalibratedClassifierCV
import lightgbm as lgb
import matplotlib.pyplot as plt

np.random.seed(0)
n_samples = 10000
X, y = make_classification(
    n_samples=3*n_samples, n_features=20, n_informative=2,
    n_classes=2, n_redundant=2, random_state=32)
#n_samples = N_SAMPLES//10

X_train, y_train = X[:n_samples], y[:n_samples]
X_valid, y_valid = X[n_samples:2*n_samples], y[n_samples:2*n_samples] 
X_test, y_test = X[2*n_samples:], y[2*n_samples:]

plt.figure(figsize=(12, 9))
plt.plot([0, 1], [0, 1], '--', color='gray')

# 1) Uncalibrated_clf but fine-tuned on training data
clf = lgb.LGBMClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=True, early_stopping_rounds=20)

y_prob = clf.predict_proba(X_test)[:, 1]
fraction_of_positives, mean_predicted_value = calibration_curve(y_test, y_prob, n_bins=10)

plt.plot(
    fraction_of_positives,
    mean_predicted_value,
    'o-', label='uncalibrated_clf')

# 2-1) Calibrated_clf
cal_clf = CalibratedClassifierCV(clf, cv='prefit', method='isotonic')
cal_clf.fit(X_valid, y_valid)

y_prob1 = cal_clf.predict_proba(X_test)[:, 1]
fraction_of_positives1, mean_predicted_value1 = calibration_curve(y_test, y_prob1, n_bins=10)

plt.plot(
    fraction_of_positives1,
    mean_predicted_value1,
    'o-', label='calibrated_clf1')


# 2-2) Calibrated_clf
cal_clf = CalibratedClassifierCV(clf, cv=5, method='isotonic')
cal_clf.fit(X_train, y_train)

y_prob2 = cal_clf.predict_proba(X_test)[:, 1]
fraction_of_positives2, mean_predicted_value2 = calibration_curve(y_test, y_prob2, n_bins=10)

plt.plot(
    fraction_of_positives2,
    mean_predicted_value2,
    'o-', label='calibrated_clf2')

plt.legend()

# 2-3) Calibrated_clf
cal_clf = CalibratedClassifierCV(clf, cv=5, method='isotonic')
cal_clf.fit(X_valid, y_valid)

y_prob3 = cal_clf.predict_proba(X_test)[:, 1]
fraction_of_positives3, mean_predicted_value3 = calibration_curve(y_test, y_prob3, n_bins=10)

plt.plot(
    fraction_of_positives2,
    mean_predicted_value2,
    'o-', label='calibrated_clf3')

plt.legend()
EN

回答 1

Stack Overflow用户

发布于 2022-10-03 13:47:06

这样做的方法是:

( a)拟合模型并在保持装置上进行校准。

代码语言:javascript
复制
model.fit(X_train, y_train)
calibrated = CalibratedClassifierCV(model, cv='prefit').fit(X_val, y_val)
y_pred = calibrated.predict(X_test)

(这实际上是预置的意思:模型已经安装好了,现在拿一个新的相关集并校准输出)。

b)对模型进行拟合,并在训练集上进行交叉验证。

代码语言:javascript
复制
model.fit(X_train, y_train)
calibrated = CalibratedClassifierCV(model, cv=5).fit(X_train, y_train)
y_pred_val = calibrated.predict(X_val)

就像通常的情况一样,交叉验证的数量和方法(在scikit-learn的行话中,等元回归相对于Platt比例或sigmoid )严格地取决于您的数据和设置。因此,我建议把这些放在网格搜索,看看什么会产生最好的结果。

最后,在这里可以找到更深层次的潜水:https://machinelearningmastery.com/calibrated-classification-model-in-scikit-learn/

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72163596

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档