进行类不平衡正则化的正确位置(数据级或批处理级)

2024-10-04 11:25:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个二元不平衡数据集,其中标签是0或1,预测输出在0和1之间。正例有10000个样本,而负例有90000个样本。我训练的时候用的是一批100个

在计算二进制交叉熵时(在pytorch中),可以提供每批元素的正则化权重

我的问题是: 为了计算一般等级的重量剂量,在开始时计算1次(阳性病例为1/(10000/(100000)),并用该值衡量每个样本的损失更为合理

或:

通过首先发现批次类别不平衡来计算批次级别的重量(例如,在批次中,可能有25个阳性和75个阴性,因此阳性情况下为1/(25/(25+75))

我问这个是因为损失是整个批次的平均值


Tags: 数据元素二进制pytorch标签阳性交叉权重
1条回答
网友
1楼 · 发布于 2024-10-04 11:25:16

如果您想这样做,您应该计算每批类的不平衡

另一方面,您可能应该确保每个批次保留标签统计信息(例如,对于批次64和您的案例,您应该有6阳性样本,其余为阴性样本)。这样,只需计算一次类不平衡,然后按批将其添加到torch.nn.BCELoss就足够了

不过,我建议使用另一种方法,例如使用PyTorch的Sampler类进行过采样或欠采样(不要复制示例,这样会完全不必要地浪费空间)。您可以手动实现它或使用为您实现它的第三方库,例如torchdata(披露:我是作者)和^{}

相关问题 更多 >