有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

Java:能否只用 Executors.newWorkStealingPool() 编写递归的 fork-join 解决方案?

下面的代码旨在演示递归 fork/join(求最大值)的简单用法。我知道 Java JIT 在一个简单的单线程循环中能更快地完成这件事,不过这里只是为了演示。

最初,我使用 ForkJoin 框架实现了求最大值(find max),它在大型 double 数组(1024*1024 个元素)上运行良好。

我觉得应该可以不使用 ForkJoin 框架,只用 Executors.newWorkStealingPool() 加上 Callable/Future 来实现同样的功能。

这可能吗?

我的尝试如下:

/**
 * Recursively finds the maximum of an array by splitting it in half and handing the left
 * half to the supplied executor while the current thread processes the right half.
 *
 * <p>The array is copied at every split. This keeps the example simple, but the total copying
 * is O(n log n); an index-range version avoids the copies.
 *
 * <p>NOTE: each non-leaf task blocks on {@link Future#get()} while its left half runs. In a
 * bounded pool that does not compensate for blocked threads, this can starve or deadlock once
 * more tasks are waiting than there are threads.
 */
class MaxTask implements Callable<Double> {

    private final double[] array;
    private final ExecutorService executorService;

    /**
     * @param array the values to search; must be non-empty by the time {@link #call()} runs
     * @param es    executor that runs the left half of every split
     */
    public MaxTask(double[] array, ExecutorService es){
        this.array = array;
        this.executorService = es;
    }

    /**
     * @return the largest element of the array
     * @throws IllegalArgumentException if the array is empty
     * @throws Exception if a sub-task fails or waiting for it is interrupted
     */
    @Override
    public Double call() throws Exception {
        int n = array.length;
        if (n == 0) {
            // The old "length != 2" test recursed forever on empty (and single-element) arrays.
            throw new IllegalArgumentException("array must not be empty");
        }
        if (n <= 2) {
            // Base case: one or two elements, nothing left to split.
            return n == 1 ? array[0] : Math.max(array[0], array[1]);
        }
        // Split at mid so the extra element of an odd-length array goes into the right half
        // instead of being dropped.
        int mid = n / 2;
        double[] left = java.util.Arrays.copyOfRange(array, 0, mid);
        double[] right = java.util.Arrays.copyOfRange(array, mid, n);
        Future<Double> leftResult = executorService.submit(new MaxTask(left, executorService));
        // Do the right half in this thread rather than submitting it and idling on get().
        double rightMax = new MaxTask(right, executorService).call();
        return Math.max(leftResult.get(), rightMax);
    }

}

// Demo driver: fills a 1024*1024-element array with random values and prints its maximum.
ExecutorService es = Executors.newWorkStealingPool();
try {
    double[] x = new double[1024*1024];
    for (int i = 0; i < x.length; i++) {
        x[i] = Math.random();
    }

    MaxTask mt = new MaxTask(x, es);

    // Print the result instead of silently discarding it.
    System.out.println("max = " + es.submit(mt).get());
} finally {
    // Release the pool's threads once the demo has finished.
    es.shutdown();
}

共 (1) 个答案

  1. # 1 楼答案

    似乎不使用 ForkJoin 框架也可以编写“fork/join”类型的计算(请参阅下面对 Callable 的用法)。使用 ForkJoin 框架本身似乎没有带来性能差异,只是代码可能更整洁一些;我个人更喜欢只使用 Callable。

    我还修复了最初的尝试。最初的尝试之所以慢,似乎是因为阈值设置得太小;我认为它至少需要和 CPU 核心数一样大。

    我不确定使用 ForkJoinPool 是否会更快,还需要收集更多统计数据;不过我认为不会,因为这里没有任何长时间阻塞的操作。

    /**
     * Compares three ways of finding the maximum of a large array recursively:
     * a {@link Callable} on a fixed thread pool, the same {@link Callable} on a
     * {@link ForkJoinPool}, and a {@link RecursiveTask} using fork/join.
     *
     * <p>In all variants {@code threshold} is the approximate number of leaf chunks the whole
     * array is split into; ranges of at most {@code data.length / threshold} elements (never
     * less than 1) are scanned sequentially.
     */
    public class Main {

        /**
         * Returns the largest value in {@code data[from..to]} (both inclusive), or
         * {@link Double#NEGATIVE_INFINITY} when the range is empty.
         */
        private static double scanMax(double[] data, int from, int to) {
            // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
            // double, so starting from it gives a wrong answer for all-negative data.
            double max = Double.NEGATIVE_INFINITY;
            for (int i = from; i <= to; i++) {
                if (data[i] > max) {
                    max = data[i];
                }
            }
            return max;
        }

        /**
         * Largest range length that is scanned sequentially. It is at least 1 and the recursion
         * stops on "<=" rather than "==". Both are needed because halving does not always hit
         * {@code length / threshold} exactly, and an exact-equality test then recurses forever.
         */
        private static int leafSize(int length, int threshold) {
            return Math.max(1, length / threshold);
        }

        /** Rejects non-positive thresholds, which would otherwise divide by zero or never terminate. */
        private static int requirePositive(int threshold) {
            if (threshold < 1) {
                throw new IllegalArgumentException("threshold must be positive: " + threshold);
            }
            return threshold;
        }

        /** Fork/join variant: forks the left half and computes the right half in this thread. */
        static class FindMaxTask extends RecursiveTask<Double> {

            private final int threshold;
            private final double[] data;
            private final int startIndex;
            private final int endIndex;

            /**
             * @param data       the array to search (shared, not copied)
             * @param startIndex first index of the range, inclusive
             * @param endIndex   last index of the range, inclusive
             * @param threshold  approximate number of leaf chunks for the whole array; must be positive
             */
            public FindMaxTask(double[] data, int startIndex, int endIndex, int threshold) {
                super();
                this.data = data;
                this.startIndex = startIndex;
                this.endIndex = endIndex;
                this.threshold = requirePositive(threshold);
            }

            @Override
            protected Double compute() {
                int diff = endIndex - startIndex + 1;
                if (diff <= leafSize(data.length, threshold)) {
                    return scanMax(data, startIndex, endIndex);
                }
                int mid = startIndex + diff / 2;
                FindMaxTask left = new FindMaxTask(data, startIndex, mid - 1, threshold);
                left.fork();
                // Compute the right half before joining the left one. With
                // Math.max(f1.join(), f2.compute()) the arguments are evaluated left to right,
                // so this thread would not start the right half until the left one finished.
                double rightMax = new FindMaxTask(data, mid, endIndex, threshold).compute();
                return Math.max(left.join(), rightMax);
            }

        }

        /**
         * Callable variant: submits the left half to {@code executorService} and computes the
         * right half in the calling thread.
         *
         * <p>NOTE: each split blocks on {@link Future#get()} while the left half runs. In a
         * bounded pool that does not compensate for blocked threads (such as a fixed thread
         * pool), this can starve or deadlock once more tasks are waiting than there are threads.
         */
        static class FindMax implements Callable<Double> {

            private final double[] data;
            private final int startIndex;
            private final int endIndex;
            private final int threshold;

            private final ExecutorService executorService;

            /**
             * @param data            the array to search (shared, not copied)
             * @param startIndex      first index of the range, inclusive
             * @param endIndex        last index of the range, inclusive
             * @param threshold       approximate number of leaf chunks for the whole array; must be positive
             * @param executorService executor that runs the left half of every split
             */
            public FindMax(double[] data, int startIndex, int endIndex, int threshold, ExecutorService executorService) {
                super();
                this.data = data;
                this.startIndex = startIndex;
                this.endIndex = endIndex;
                this.executorService = executorService;
                this.threshold = requirePositive(threshold);
            }

            @Override
            public Double call() throws Exception {
                int diff = endIndex - startIndex + 1;
                if (diff <= leafSize(data.length, threshold)) {
                    return scanMax(data, startIndex, endIndex);
                }
                int mid = startIndex + diff / 2;
                Future<Double> left = executorService.submit(
                        new FindMax(data, startIndex, mid - 1, threshold, executorService));
                // Mirror fork()/compute(): do the right half here instead of submitting it too
                // and then only waiting, which halves the number of submitted (and blocked) tasks.
                double rightMax = new FindMax(data, mid, endIndex, threshold, executorService).call();
                return Math.max(left.get(), rightMax);
            }

        }

        /**
         * Fills a 64M-element double array (about 512 MB) with random values and times the
         * three variants on it, printing each maximum and the elapsed milliseconds.
         */
        public static void main(String[] args) throws InterruptedException, ExecutionException {

            double[] data = new double[1024*1024*64];
            for (int i = 0; i < data.length; i++) {
                data[i] = Math.random();
            }

            int p = Runtime.getRuntime().availableProcessors();
            // Roughly one leaf chunk per core, and one worker thread per core.
            int threshold = p;
            int threads = p;

            // Every pool is shut down after use. newFixedThreadPool creates non-daemon threads,
            // so leaving it running would keep the JVM alive after main returns.
            ExecutorService es = Executors.newFixedThreadPool(threads);
            try {
                System.out.println("1. started..");
                Instant start = Instant.now();
                System.out.println("max = "+es.submit(new FindMax(data,0,data.length-1,threshold,es)).get());
                Instant end = Instant.now();
                System.out.println("Callable (recrusive), with fixed pool, Find Max took ms = "+ Duration.between(start, end).toMillis());
            } finally {
                es.shutdown();
            }

            es = new ForkJoinPool();
            try {
                System.out.println("2. started..");
                Instant start = Instant.now();
                System.out.println("max = "+es.submit(new FindMax(data,0,data.length-1,threshold,es)).get());
                Instant end = Instant.now();
                System.out.println("Callable (recursive), with fork join pool, Find Max took ms = "+ Duration.between(start, end).toMillis());
            } finally {
                es.shutdown();
            }

            ForkJoinPool fj = new ForkJoinPool(threads);
            try {
                System.out.println("3. started..");
                Instant start = Instant.now();
                System.out.println("max = "+fj.invoke(new FindMaxTask(data,0,data.length-1,threshold)));
                Instant end = Instant.now();
                System.out.println("RecursiveTask (fork/join framework),with fork join pool, Find Max took ms = "+ Duration.between(start, end).toMillis());
            } finally {
                fj.shutdown();
            }
        }

    }