pytorch中的序数回归模型
spacecutter的Python项目详细描述
太空切割机
spacecutter
是在pytorch中实现顺序回归模型的库。图书馆由模型和损失函数组成。建议使用skorch包装模型,使其与scikit learn兼容。
安装
通过
pip install -e .
用法
型号
定义任何需要生成单个标量预测值的pytorch模型。这将是我们的predictor
模型。然后,这个模型可以用spacecutter.models.OrdinalLogisticModel
包装,它将predictor
的输出从单个数字转换为有序类概率数组。下面的示例演示了如何对两层神经网络predictor
执行此操作,以解决具有三个有序类的问题。
importnumpyasnpimporttorchfromtorchimportnnfromspacecutter.modelsimportOrdinalLogisticModelX=np.array([[0.5,0.1,-0.1],[1.0,0.2,0.6],[-2.0,0.4,0.8]],dtype=np.float32)y=np.array([0,1,2]).reshape(-1,1)num_features=X.shape[1]num_classes=len(np.unique(y))predictor=nn.Sequential(nn.Linear(num_features,num_features),nn.ReLU(),nn.Linear(num_features,1))model=OrdinalLogisticModel(predictor,num_classes)y_pred=model(torch.as_tensor(X))print(y_pred)# tensor([[0.2325, 0.2191, 0.5485],# [0.2324, 0.2191, 0.5485],# [0.2607, 0.2287, 0.5106]], grad_fn=<CatBackward>)
培训
建议使用skorch来训练spacecutter
模型。下面演示如何使用skorch
的累积链路损耗来训练上一节中的模型:
fromskorchimportNeuralNetfromspacecutter.callbacksimportAscensionCallbackfromspacecutter.lossesimportCumulativeLinkLossskorch_model=NeuralNet(module=OrdinalLogisticModel,module__predictor=predictor,module__num_classes=num_classes,criterion=CumulativeLinkLoss,train_split=None,callbacks=[('ascension',AscensionCallback()),],)skorch_model.fit(X,y)
注意,我们必须添加AscensionCallback
。这样可以确保顺序剪切点保持升序。虽然理想情况下,这个约束将被直接考虑到模型优化中,spacecutter
目前通过使用post backward pass回调来剪裁剪切点值来破解与sgd兼容的解决方案。