有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

java为什么当线程数增加时程序会变慢

我是java的初学者

最近,我在写一个计算矩阵乘法的程序。所以我写了一个类来做这个

public class MultiThreadsMatrixMultipy{
 public   int[][] multipy(int[][] matrix1,int[][] matrix2) {
     if(!utils.CheckDimension(matrix1,matrix2)){
         return null;
     }
     int row1 = matrix1.length;
     int col1 = matrix1[0].length;
     int row2 = matrix2.length;
     int col2 = matrix2[0].length;
     int[][] ans = new int[row1][col2];
     Thread[][]  threads = new SingleRowMultipy[row1][col2];

     for(int i=0;i<row1;i++){
         for(int j=0;j<col2;j++){
             threads[i][j] = new SingleRowMultipy(i,j,matrix1,matrix2,ans));
             threads[i][j].start();
         }
     }
     return ans;
 }
}
public class SingleRowMultipy extends Thread{
        private int row;
        private int col;
        private int[][] A;
        private int[][] B;
        private int[][] ans;
        public SingleRowMultipy(int row,int col,int[][] A,int[][] B,int[][] C){
            this.row = row;
            this.col = col;
            this.A = A;
            this.B = B;
            this.ans = C;
        }
        public void run(){
            int sum =0;
            for(int i=0;i<A[row].length;i++){
                 sum+=(A[row][i]*B[i][col]);
            }
            ans[row][col] = sum;
        }
}

我想用一个线程来计算matrix1[i][:] * matrix2[:][j],两个矩阵的大小分别是1000*5000和5000*1000,所以线程的数量是1000*1000。当我运行这个程序时,它的速度非常慢,耗时大约38s。如果我只使用单线程来计算结果,只需要17s。单线程代码如下所示:

/**
 * Single-threaded matrix multiplication using the plain triple loop.
 */
public class SimpleMatrixMultipy
{
    /**
     * Computes the matrix product {@code matrix1 x matrix2}.
     *
     * @param matrix1 left operand
     * @param matrix2 right operand
     * @return a new matrix with {@code matrix1.length} rows and
     *         {@code matrix2[0].length} columns
     */
    public int[][] multipy(int[][] matrix1, int[][] matrix2) {
        final int rows = matrix1.length;
        final int inner = matrix1[0].length;
        final int cols = matrix2[0].length;
        final int[][] product = new int[rows][cols];

        for (int r = 0; r < rows; r++) {
            final int[] leftRow = matrix1[r];
            final int[] outRow = product[r];
            for (int c = 0; c < cols; c++) {
                // Accumulate locally, then store the finished cell once.
                int acc = 0;
                for (int k = 0; k < inner; k++) {
                    acc += leftRow[k] * matrix2[k][c];
                }
                outRow[c] = acc;
            }
        }
        return product;
    }

}

我能做些什么来加速程序?


共 (1) 个答案

  1. # 1 楼答案

    正如@Turing85所说,需要控制线程的数量。有两种方法:一种是用Executors.newFixedThreadPool创建固定数量的线程,另一种是用Executors.newCachedThreadPool来复用现有的空闲线程

    另一个要点是避免直接继承Thread类,而是实现Runnable接口

    import java.util.ArrayList;
    import java.util.Date;
    import java.util.Iterator;
    import java.util.List;
    import java.util.concurrent.Executor;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ThreadFactory;
    import java.util.concurrent.TimeUnit;
    
    /**
     * Multiplies two integer matrices on a fixed-size thread pool.
     *
     * <p>One {@link SingleRowMultipy} task is submitted per result cell; the
     * pool bounds how many of them run at the same time.
     */
    public class MultiThreadsMatrixMultipy {

        /** Number of worker threads in the fixed pool. */
        private static final int POOL_SIZE = 20;

        public static void main(final String[] args) {
            // Intentionally empty: entry point placeholder.
        }

        /**
         * Computes the matrix product {@code matrix1 x matrix2}.
         *
         * @param matrix1 left operand
         * @param matrix2 right operand
         * @return the fully computed product, or {@code null} when
         *         {@code utils.CheckDimension} rejects the operands
         * @throws IllegalStateException if the calling thread is interrupted
         *         while waiting for the submitted tasks to finish
         */
        public int[][] multipy(final int[][] matrix1, final int[][] matrix2) {
            if (!utils.CheckDimension(matrix1, matrix2)) {
                return null;
            }
            final int row1 = matrix1.length;
            final int col2 = matrix2[0].length;
            final int[][] ans = new int[row1][col2];
            // Alternative: Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"))
            // reuses idle threads instead of capping their number.
            final ExecutorService executor =
                    Executors.newFixedThreadPool(POOL_SIZE, new CustomThreadFactory("Multiplier"));
            try {
                for (int i = 0; i < row1; i++) {
                    for (int j = 0; j < col2; j++) {
                        executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
                    }
                }
                // Stop accepting work, then block until every queued task has run,
                // so the caller never sees a partially filled result.
                executor.shutdown();
                executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers further up can observe it.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "Interrupted while waiting for matrix multiplication tasks", e);
            } finally {
                // No-op after a clean shutdown; on failure it cancels what is left
                // so the pool's non-daemon threads do not keep the JVM alive.
                executor.shutdownNow();
            }
            return ans;
        }
    }
    
    /**
     * Thread factory that gives each thread a numbered name and keeps a
     * human-readable log line for every thread it creates.
     */
    class CustomThreadFactory implements ThreadFactory {
        private final String name;
        private final List<String> stats = new ArrayList<>();
        // Sequence number used in the next thread name; starts at 1.
        private int counter = 1;

        /**
         * @param name prefix used for the names of the created threads
         */
        public CustomThreadFactory(final String name) {
            this.name = name;
        }

        /**
         * Creates a thread named {@code <name>-Thread_<n>} for the given task
         * and records its id, name and creation time.
         *
         * @param runnable task the new thread will run
         * @return the new, not yet started thread
         */
        @Override
        public Thread newThread(final Runnable runnable) {
            final String threadName = name + "-Thread_" + counter++;
            final Thread created = new Thread(runnable, threadName);
            final String entry = String.format(
                    "Created thread %d with name %s on %s \n",
                    created.getId(), created.getName(), new Date());
            stats.add(entry);
            return created;
        }

        /**
         * @return all recorded creation entries concatenated in creation order;
         *         empty when no thread has been created yet
         */
        public String getStats() {
            return String.join("", stats);
        }
    }
    
    class SingleRowMultipy implements Runnable {
        private final int row;
        private final int col;
        private final int[][] A;
        private final int[][] B;
        private final int[][] ans;
    
        public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
            this.row = row;
            this.col = col;
            this.A = A;
            this.B = B;
            this.ans = C;
        }
    
        @Override
        public void run() {
            int sum = 0;
            for (int i = 0; i < A[row].length; i++) {
                sum += (A[row][i] * B[i][col]);
            }
            ans[row][col] = sum;
        }
    }