如何在pytest单元测试中比较XGBoost模型对象(一个已初始化/安装,另一个从文件读取)?

2024-09-30 03:22:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我对用于封装XGBoost模型的类进行了简单测试。为了测试这个类,我训练了一个XGBoost模型并将其保存到文件中,我想使用这个训练过的模型,我将从文件中读取它来测试我的模型训练代码。我不确定如何才能最好地将我将使用已知参数/数据训练的XGBoost模型与我保存到文件中的模型进行比较。例如,我训练并保存了一个XGBoost模型,如下所示:

# specify parameters to use for training the XGBoost model
params = {
    'max_depth': 6,  # the maximum depth of each tree
    'eta': 0.25,  # the training step for each iteration
    'silent': 1,  # logging mode - quiet
    'objective': 'reg:tweedie',
    'booster': 'gbtree',
    'subsample': 0.7,
    'gamma': 0.3,  # regularization parameter
    'colsample_bytree': 0.2,
    'rate_drop': 0.3,
    'skip_drop': 0.2,
    'early_stopping_rounds': 10,
    'eval_metric': ['rmse', 'mae'],  # error evaluation for multiclass training
}

# split X and y into train and test sets
features_train, features_test, target_train, target_test = \
    train_test_split(features, target, test_size=test_percentage, random_state=31)

# package the dataset splits as input for XGBoost
dtrain = xgb.DMatrix(features_train, label=target_train)
dtest = xgb.DMatrix(features_test, label=target_test)
evallist = [(dtest, 'eval'), (dtrain, 'train')]

# train the XGBoost model
xgbooster = xgb.train(params, dtrain, training_iterations, evallist, verbose_eval=0)
pickle.dump(xgbooster, open("/path/to/fitted_model.dat", "wb"))

在模型类的(pytest)单元测试中,我想测试我是否按照预期训练模型,因此我从文件中读取此保存的模型,以便与应该匹配的模型进行比较:

def test_xgboost_fit():

    features_train_df = pd.read_csv("/path/to/features_train.csv"))
    labels_train_df = pd.read_csv("/path/to/labels_train.csv"))
    fixture_xgbooster = pickle.load(open("/path/to/fitted_model.dat", "rb"))

    # train/fit the model
    xgbooster = mymodelclass.XGBoostModel()
    xgbooster.fit(features_train_df, labels_train_df)

    # compare the trained model against the expected model read from file
    assert xgbooster.model == fixture_xgbooster

在这里使用双等号似乎不足以进行比较(否则我还有其他问题,因为它显示了具有相同参数和拟合相同训练数据的两个模型是不相等的)

我应该如何在测试中进行比较?还是有更好的方法来测试这段代码


Tags: csvthetopath模型testtargetdf

热门问题