有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

java如何将DataSetitor拆分为测试和训练迭代器?

我使用的是Deeplearning4j和datavec,我有一个DataSetIterator对象,它代表我的所有数据,这是一个时间序列。我如何将其分为训练和测试迭代器?我检查了一下,DataSetIterator类的方法被弃用了。谢谢


共 (1) 个答案

  1. # 1 楼答案

    迭代你的DataSetIterator,为每个DataSet条目创建两个新的DataSets,每个都用于训练和测试

    关键是使用splitTestAndTrain方法,该方法接受一个double fractionTrain,该方法将指定要训练的数据量(其余要测试)。该方法有不同的重载,因此您可以选择最适合您需要的重载。如果希望将所有训练和测试数据集添加到一个公共迭代器中,可以将它们存储在两个不同的列表中,稍后再获取它们相应的迭代器。比如:

    List<DataSet> trainList = new ArrayList<>();
    List<DataSet> testList= new ArrayList<>();
    
    while (yourDataSetIterator.hasNext())
    {
        DataSet ds = yourDataSetIterator.next();
        SplitTestAndTrain splData = ds.splitTestAndTrain(0.5); //half for each         
        DataSet trainDs = splData.getTrain();
        trainList.add(trainDs);
        DataSet testDs  = splData.getTest();
        testList.add(testDs);
        (...)
    }
    
    Iterator<DataSet> trainIterator = trainList.iterator(); 
    Iterator<DataSet> testIterator  = testList.iterator(); 
    

    由于我不知道这个库的具体细节,这个例子只创建了“basic”iterators。这可能是定制的,因此您可以创建DataSetIterators

    请注意,在拆分数据集之前,可能还需要对其进行洗牌(ds.shuffle())。你可以找到一些例子here


    如果希望以特定的方式拆分它,可以标记不同的条目,并找到测试数据集的最大索引;然后,调用^{}方法,该方法专门针对max参数拆分数据集。这里的^{}方法也会有所帮助


    ^{}对有关其他机制的评论提出了一个很好的建议,以分割DataSetIterator,这似乎也是一种“更自然”的方式,即^{}

    它提供了getTrainIterator()getTestIterator()方法,它们返回库的特定迭代器DataSetIterator