梯度下降Logistic回归的机器学习在Java中的实现
我在Java中实现了带有梯度下降的逻辑回归。它似乎不起作用(它没有正确地对记录进行分类;y=1的概率很大)我不知道我的实现是否正确。我已经检查了代码好几次了,我找不到任何bug。我一直在关注Andrew Ng的课程时代机器学习教程。我的Java实现有3个类。即:
- 数据集。java:读取数据集
- 例如。java:有两个成员:1。双[]x和2。双标签
- 后勤。java:这是使用梯度下降实现逻辑回归的主要类李>
这是我的成本函数:
J(Θ)=(-1/m)[∑mi=1y(i)log(hΘ(x(i))+(1-y(i))log(1-hΘ(x(i))]
对于上述成本函数,这是我的梯度下降算法:重复(
Θj:=Θj-α∑mi=1(hΘ(x(i))x(i)
(同时更新所有Θj))
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Logistic {
/** the learning rate */
private double alpha;
/** the weight to learn */
private double[] theta;
/** the number of iterations */
private int ITERATIONS = 3000;
public Logistic(int n) {
this.alpha = 0.0001;
theta = new double[n];
}
private double sigmoid(double z) {
return (1 / (1 + Math.exp(-z)));
}
public void train(List<Instance> instances) {
double[] temp = new double[3];
//Gradient Descent algorithm for minimizing theta
for(int i=1;i<=ITERATIONS;i++)
{
for(int j=0;j<3;j++)
{
temp[j]=theta[j] - (alpha * sum(j,instances));
}
//simulataneous updates of theta
for(int j=0;j<3;j++)
{
theta[j] = temp[j];
}
System.out.println(Arrays.toString(theta));
}
}
private double sum(int j,List<Instance> instances)
{
double[] x;
double prediction,sum=0,y;
for(int i=0;i<instances.size();i++)
{
x = instances.get(i).getX();
y = instances.get(i).getLabel();
prediction = classify(x);
sum+=((prediction - y) * x[j]);
}
return (sum/instances.size());
}
private double classify(double[] x) {
double logit = .0;
for (int i=0; i<theta.length;i++) {
logit += (theta[i] * x[i]);
}
return sigmoid(logit);
}
public static void main(String... args) throws FileNotFoundException {
//DataSet is a class with a static method readDataSet which reads the dataset
// Instance is a class with two members: double[] x, double label y
// x contains the features and y is the label.
List<Instance> instances = DataSet.readDataSet("data.txt");
// 3 : number of theta parameters corresponding to the features x
// x0 is always 1
Logistic logistic = new Logistic(3);
logistic.train(instances);
//Test data
double[]x = new double[3];
x[0]=1;
x[1]=45;
x[2] = 85;
System.out.println("Prob: "+logistic.classify(x));
}
}
谁能告诉我我做错了什么? 提前感谢!:)
# 1 楼答案
在我学习逻辑回归时,我花了时间详细检查了您的代码
TLDR
事实上,这个算法似乎是正确的
我认为,之所以会出现如此多的假阴性或假阳性,是因为您选择了超参数
模型训练不足,因此假设不符合要求
详细信息
我必须创建
DataSet
和Instance
类,因为您没有发布它们,并基于冷冻治疗数据集设置一个序列数据集和一个测试数据集。 见http://archive.ics.uci.edu/ml/datasets/Cryotherapy+Dataset+然后,使用相同的精确代码(对于逻辑回归部分),通过选择
0.001
的alpha速率和100000
的多次迭代,我得到了测试数据集80.64516129032258
的准确率,这还不错我试图通过手动调整这些超参数来获得更好的精度,但没有得到更好的结果
我想,在这一点上,一个改进就是实施正规化
梯度下降公式
在Andrew Ng关于成本函数和梯度下降的视频中,省略了
1/m
项是正确的。 一种可能的解释是1/m
项包含在alpha
项中。 或者这只是一个疏忽。 见第6m53s页https://www.youtube.com/watch?v=TTdcc21Ko9A&index=36&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=6m53s但是如果你看Andrew Ng关于正则化和逻辑回归的视频,你会注意到
1/m
一词显然出现在公式中。 见第2m19s页https://www.youtube.com/watch?v=IXPgm1e0IOo&index=42&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&t=2m19s