有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

使用带有随机轴的quickselect的java第k个最小元素

我正在尝试使用quickselect算法查找数组中的第k个最小值。但是,当我尝试随机选择轴时,输出也是随机的

下面是我的方法实现,

    static int findKthMin(int[]arr, int n, int k) {
        int l=0 , r=n-1;
        Random random = new Random();
        while(true) {
            int x = random.nextInt(r+1-l) + l; // When using x = r (works correctly)
            int pivot = arr[x];
            int idx = l;
            for(int i=l;i<=r;i++) {
                if(arr[i] < pivot) {
                    int temp = arr[idx];
                    arr[idx] = arr[i];
                    arr[i] = temp;
                    
                    idx++;
                }
            }
            arr[x] = arr[idx];
            arr[idx] = pivot;
            
            if(idx == k-1) return pivot;
            
            if(idx > k-1) {
                r = idx-1;
            } else {
                l = idx;
            }
        }
    }

这里,n是数组的大小,k是要找到的第k个最小元素
当我使用x=r时,代码运行良好

我的猜测是,在这种情况下出现了问题

   for(int i=l;i<=r;i++) {
       if(arr[i] < pivot) {
            int temp = arr[idx];
            arr[idx] = arr[i];
            arr[i] = temp;

            idx++;
       }
   }          

但我不知道哪里出了问题以及如何解决。我花了数小时调试和修改代码,但我能找出问题所在

这是我正在尝试的测试用例

6               // n
7 10 4 3 20 15  //arr
3               // k

和,

5             // n
7 10 4 20 15  // arr
4             // k

在这些测试用例中,random pivot将任何数组元素作为输出
任何可能是错误的提示都会非常有用


共 (1) 个答案

  1. # 1 楼答案

    根据@Nico的建议,我只需要将pivot元素与最后一个元素交换
    以下是完整的工作片段

        static int findKthMin(int[]arr, int n, int k) {
            int l=0 , r=n-1;
            Random random = new Random();
            while(true) {
                int x = random.nextInt(r+1-l) + l; // When using x = r (works correctly)
    
                //Swap random pivot with the last index element
                int temp = arr[x];
                arr[x] = arr[r];
                arr[r] = temp;
    
                int pivot = arr[r];
    
                int idx = l;
                for(int i=l;i<=r;i++) {
                    if(arr[i] < pivot) {
                        temp = arr[idx];
                        arr[idx] = arr[i];
                        arr[i] = temp;
    
                        idx++;
                    }
                }
                arr[r] = arr[idx];
                arr[idx] = pivot;
    
                if(idx == k-1) return pivot;
    
                if(idx > k-1) {
                    r = idx-1;
                } else {
                    l = idx;
                }
            }
        }