有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

梯度下降Logistic回归的机器学习在Java中的实现

我在Java中实现了带有梯度下降的逻辑回归。它似乎不起作用(它没有正确地对记录进行分类;y=1的概率很大)我不知道我的实现是否正确。我已经检查了代码好几次了,我找不到任何bug。我一直在关注Andrew Ng的课程时代机器学习教程。我的Java实现有3个类。即:

  1. 数据集。java:读取数据集
  2. 例如。java:有两个成员:1。双[]x和2。双标签
  3. 后勤。java:这是使用梯度下降实现逻辑回归的主要类

这是我的成本函数:

J(Θ)=(-1/m)[∑mi=1y(i)log(hΘ(x(i))+(1-y(i))log(1-hΘ(x(i))]

对于上述成本函数,这是我的梯度下降算法:

重复(

Θj:=Θj-α∑mi=1(hΘ(x(i))x(i)

(同时更新所有Θj

)

import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class Logistic {

    /** the learning rate */
    private double alpha;

    /** the weight to learn */
    private double[] theta;

    /** the number of iterations */
    private int ITERATIONS = 3000;

    public Logistic(int n) {
        this.alpha = 0.0001;
        theta = new double[n];
    }

    private double sigmoid(double z) {
        return (1 / (1 + Math.exp(-z)));
    }

    public void train(List<Instance> instances) {

    double[] temp = new double[3];

    //Gradient Descent algorithm for minimizing theta
    for(int i=1;i<=ITERATIONS;i++)
    {
       for(int j=0;j<3;j++)
       {      
        temp[j]=theta[j] - (alpha * sum(j,instances));
       }

       //simulataneous updates of theta  
       for(int j=0;j<3;j++)
       {
         theta[j] = temp[j];
       }
        System.out.println(Arrays.toString(theta));
    }

    }

    private double sum(int j,List<Instance> instances)
    {
        double[] x;
        double prediction,sum=0,y;


       for(int i=0;i<instances.size();i++)
       {
          x = instances.get(i).getX();
          y = instances.get(i).getLabel();
          prediction = classify(x);
          sum+=((prediction - y) * x[j]);
       }
         return (sum/instances.size());

    }

    private double classify(double[] x) {
        double logit = .0;
        for (int i=0; i<theta.length;i++)  {
            logit += (theta[i] * x[i]);
        }
        return sigmoid(logit);
    }


    public static void main(String... args) throws FileNotFoundException {

      //DataSet is a class with a static method readDataSet which reads the dataset
      // Instance is a class with two members: double[] x, double label y
      // x contains the features and y is the label.

        List<Instance> instances = DataSet.readDataSet("data.txt");
      // 3 : number of theta parameters corresponding to the features x 
      // x0 is always 1   
        Logistic logistic = new Logistic(3);
        logistic.train(instances);

        //Test data
        double[]x = new double[3];
        x[0]=1;
        x[1]=45;
        x[2] = 85;

        System.out.println("Prob: "+logistic.classify(x));


    }
}

谁能告诉我我做错了什么? 提前感谢!:)


共 (1) 个答案

  1. # 1 楼答案

    在我学习逻辑回归时,我花了时间详细检查了您的代码

    TLDR

    事实上,这个算法似乎是正确的

    我认为,之所以会出现如此多的假阴性或假阳性,是因为您选择了超参数

    模型训练不足,因此假设不符合要求

    详细信息

    我必须创建DataSetInstance类,因为您没有发布它们,并基于冷冻治疗数据集设置一个序列数据集和一个测试数据集。 见http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+

    然后,使用相同的精确代码(对于逻辑回归部分),通过选择0.001的alpha速率和100000的多次迭代,我得到了测试数据集80.64516129032258的准确率,这还不错

    我试图通过手动调整这些超参数来获得更好的精度,但没有得到更好的结果

    我想,在这一点上,一个改进就是实施正规化

    梯度下降公式

    在Andrew Ng关于成本函数和梯度下降的视频中,省略了1/m项是正确的。 一种可能的解释是1/m项包含在alpha项中。 或者这只是一个疏忽。 见第6m53s页https://www.youtube.com/watch?v=TTdcc21Ko9A&index=36&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=6m53s

    但是如果你看Andrew Ng关于正则化和逻辑回归的视频,你会注意到1/m一词显然出现在公式中。 见第2m19s页https://www.youtube.com/watch?v=IXPgm1e0IOo&index=42&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=2m19s