JDK1.8 Arrays.sort 源码浅析

写在前面

最近复习完10个常用的排序算法之后, 想起来Java是有内置的排序算法的,之前只是用过,没有深入研究,今天就尝试阅读并理解Java的排序算法。 在之前学习Java多线程的时候,偶然了解到多线程排序,后来查找了相关的资料发现Arrays类里面其实也内置了paralellSort的多线程排序,理解完Array.sort之后也会尝试解读一下这个类。 那么就不多说废话了,开始吧。

万万没想到,这篇文章刚开始写没多久,晚上就因为长时间久坐+坐姿问题肩附近肌肉劳损饱受折磨,广大程序猿还是多注意身体。

Arrays.sort概述以及DualPivotQuicksort.sort用到的排序算法

array_sort

JDK对Arrays类的概述

Arrays这个类包含了一些操作数组的方法比如查找和排序,同时有一个可以把数组看成list的静态工厂(这个我似乎没用过)。所有的方法都会抛出空指针错误。

private static final int MIN_ARRAY_SORT_GRAN = 1 << 13;

MIN_ARRAY_SORT_GRAN 这个参数是执行paralellSort的门槛,也就是说,当排序的数组的元素个数少于2的13次方也就是8192的时候,将不会执行paralellSort而是执行非多线程的排序来提高效率,否则使用多线程归并排序。更具体的以后到了并行排序在进行具体分析吧。

static final class NaturalOrder implements Comparator

这个类简单来说是当数组的元素没有实现Comparable接口(可能是自定义实现的一些类)无法比较的时候作为一个默认的实现Comparable的类。

private static void rangeCheck(int arrayLength, int fromIndex, int toIndex)

这个类是范围检查的类,如果输入的起始Index和终点Index超出范围或者不合法就会抛出异常。

接下去就是开始介绍JDK1.8里使用的快速排序算法– Dual-Pivot Quicksort。可以参考xumingmingv的这篇文章了解一些这个算法。这个算法基于普通的快速排序算法改进的,在当前的计算机环境下更加高效的算法。在源码中可以发现有很多的重载的sort方法对应不同类型的数组。

DualPivotQuicksort

这个类开头定义了很多的threshHold,也就是执行哪个算法的阈值,有quickSort, countingSort等等,我们现在先关注DualPivotQuicksort可能相关的quickSort的值:286

  static void sort(int[] a, int left, int right, int[] work, int workBase, int workLen) 

这是DualPivotQuicksort提供给外部类的一个排序的接口

上面这几个参数是DualPivotQuicksort的sort的方法的参数。 a 是要排序的数组 left 是要排序的数组片段的起始坐标 right 是要排序的数组片段的终止坐标

以下都是在归并排序用到的参数,快排没用到 workspace是用来做归并排序的工作空间 workBase是初始已用的工作数组空间 workLen是可用的工作数组的长度

        // Use Quicksort on small arrays
        if (right - left < QUICKSORT_THRESHOLD) {
            sort(a, left, right, true);
            return;
        }

可以看到,当需要排序的数组的长度小于QUICKSORT_THRESHOLD=286的时候,会执行sort(a, left, right, true),sort就是JDK改进过的快排。

JDK用的快排-sort

private static void sort(int[] a, int left, int right, boolean leftmost)

    /**
     * @param leftmost indicates if this part is the leftmost in the range
     */

这个类的核心思想也是双轴快排,其他几个参数和上一步讲的一样,这个参数leftmost意思是需要排序的数组片段是不是位于整个数组的最左边,也就是从0开始。往下继续看。


 if (length < INSERTION_SORT_THRESHOLD) {
            if (leftmost) {
			执行简单的插入排序
			}else{
			执行成对插入排序
			}
			

length是要排序的数组片段的长度,当length小于插入排序的insertion_thresh_hold=47的时候,就对小的数组片段执行插入排序,插入排序对小数组的效果好。


如果leftmost=true,也就是说排序的片段是处于a的最左边,就用传统的插入排序算法(上一篇常用的排序算法有实现过)。 如果leftmost=false,就通过一个while循环跳过已经是有序的递增的部分,这是一个小优化。接下去使用了一个叫做pair inserion sort的改进版本的插入排序提高效率。简单来说这个算法就是每次插入都成对插入两个元素,先把大的元素插入进去有序数组片段(前面的while循环所确定好的有序片段),然后再把小的那个一插入到大的元素前面的有序数组片段这样就减少了一些无效的比较和数组元素的移动。


关于pair insertion sort,JDK源码的注释解释说比简单插入排序的效率来得快。我在看到这里的时候有疑问:那为什么不统一用pair insertion sort,而是当lestmost=false的时候才用成对插入(看到后面发现,其实就只有当要数组已经不大的时候,一小部分最左边才用简单插入,大部分还是用的成对插入)。我尝试把pair insertion sort的代码拷贝到测试类进行测试,发现报了超出ArrayIndexOutOfBoundsException的错误,分析了一下代码,发现while (a1 < a[–k])和while (a2 < a[–k])这两个循环的条件存在问题,如果k=0的时候还进行减一这个操作,就会出现a[-1],导致数组下标不合法。

		 * Every element from adjoining part plays the role
                 * of sentinel, therefore this allows us to avoid the
                 * left range check on each iteration. 

我又回头看了一下注释,发现当leftmost=true的时候,默认把下标为left-1的元素作为sentinel,按照这个思路,意思是认为a[left-1]的值一定小于a[left]到a[right]的每个元素,因此可以把它作为边界,但是要如何作这样的保证呢?是通过递归调用普通快速排序或者双轴快速排序已经把要排序的片段分成相对有序的一段一段的小片段了,再来调用的这个方法。 因为是两两插入,所以可能存在有末端的元素只剩一个,因此在最后用普通插入排序对a[right]这个元素进行插入排序,保证每个元素都是正确的。


以下是整个sort算法的重点部分,是快速排序及其改进的实现

int seventh = (length >> 3) + (length >> 6) + 1;

利用位运算估算length的1/7长度

        int e3 = (left + right) >>> 1; // The midpoint
        int e2 = e3 - seventh;
        int e1 = e2 - seventh;
        int e4 = e3 + seventh;
        int e5 = e4 + seventh;

将数组长度划分为7等份之后,确定5个分割点为3/14分位点、5/14分位点、中点、9/14分位点以及11/14分位点,程序中分别记为:e1、e2、e3、e4、e5。 接下去JDK源码用了很长的一串代码来进行五个元素的从小到大排序,虽然代码多但是逻辑比较简单,就是比较交换,这里就不在赘述。


下面这部分代码有点多,在这里我切分一下,下面这部分执行的条件是当e1、e2、e3、e4、e5五个元素的值都不一样的时候用的双轴快排 **

			//两个轴
            float pivot1 = a[e2];
            float pivot2 = a[e4];

把第二和第四个元素设置为中轴


            while (a[++less] < pivot1);
            while (a[--great] > pivot2);

同样一个小优化,把指针跳过比pivot1小的元素和跳过比pivot2大的元素



             * Partitioning:
             *
             *   left part           center part                   right part
             * +--------------------------------------------------------------+
             * |  < pivot1  |  pivot1 <= && <= pivot2  |    ?    |  > pivot2  |
             * +--------------------------------------------------------------+
             *               ^                          ^       ^
             *               |                          |       |
             *              less                        k     great

这是JDK源码给的示意图,数组被分为以上几部分

下面是核心的代码,吐槽一下编写JDK sort的大佬的循环逻辑非常喜欢用++或者–,跟我平时的编程习惯不太一样,看多了感觉人有点懵。 这段我直接复制的整段源码加以注释,否则太乱了


outer:
            for (int k = less - 1; ++k <= great; ) {//其实就是从a[less]开始遍历到a[great]
			//ak就是当前要进行快排的数
                int ak = a[k];
				//如果ak小于pivot1,那就要插入到less的左边部分的数组里,相当于扩展right part的数组,具体逻辑就是把a[less]和a[k]交换,然后把less的坐标+1.
                if (ak < pivot1) { // Move a[k] to left part
                    a[k] = a[less];
                    /*
                     * Here and below we use "a[i] = b; i++;" instead
                     * of "a[i++] = b;" due to performance issue.
                     */
                    a[less] = ak;
                    ++less;
                } else if (ak > pivot2) { // 如果ak>pivot2,那就要插入到right的右边的数组,相当于扩展right part的数组	
                    while (a[great] > pivot2) {//再次通过while尝试把大于pivot2的数都添加到right part部分
                        if (great-- == k) {
						//当a[great]和a[k]相遇,说明整个数组已经执行一轮快速排序完毕,注意不是整个排序好了,要进行递归继续排序的
                            break outer;
                        }
                    }
					//这段逻辑有点绕:如果是a[great]<=pivot1<pivot2这个情况,就a[k]赋值为a[less],a[less]=a[great],把less++因为left part新增加了一个数,后面还要把a[great]=ak;就是一个呈三角形赋值的一个关系。
                    if (a[great] < pivot1) { 
                        a[k] = a[less];
                        a[less] = a[great];
                        ++less;
                    }//如果是pivot1 <= a[great] <= pivot2这个情况,就直接把a[k]和a[great]交换就行了。
					else { // pivot1 <= a[great] <= pivot2
                        a[k] = a[great];
                    }
                    /*这里是JDK开发者给的提示,由于这边分析的是Int类型的数组,所以忽略了
                     * Here and below we use "a[i] = b; i--;" instead
                     * of "a[i--] = b;" due to performance issue.
                     */
                    a[great] = ak;
					//great右边的数组left part 加了1,把great指针往左移动一位
                    --great;
                }
            }

            // 最后记得把两个pivot的位置给赋值到正确的位置,就跟普通的快排的最后是一样的,只不过这有2个数。
            a[left]  = a[less  - 1]; a[less  - 1] = pivot1;
            a[right] = a[great + 1]; a[great + 1] = pivot2;

            // 递归继续进行排序,把左右两边递归进行排序,右边部分很显然leftmost是false。
            sort(a, left, less - 2, leftmost);
            sort(a, great + 2, right, false);

			//一个数组被切了2次共有3部分,接下去看中间部分是怎样处理的

            /*
             * If center part is too large (comprises > 4/7 of the array),
             * swap internal pivot values to ends.
             */
			 //如果中间的部分太长了,就要继续进行处理,具体的值是是否大于4/7的length
            if (less < e1 && e5 < great) {
                /*
                 * Skip elements, which are equal to pivot values.
                 */
				 //跳过跟两个轴相等的元素,这些都已经确定好位置了,不用管
                while (a[less] == pivot1) {
                    ++less;
                }

                while (a[great] == pivot2) {
                    --great;
                }

                /*
                 * Partitioning:
                 *
                 *   left part         center part                  right part
                 * +----------------------------------------------------------+
                 * | == pivot1 |  pivot1 < && < pivot2  |    ?    | == pivot2 |
                 * +----------------------------------------------------------+
                 *              ^                        ^       ^
                 *              |                        |       |
                 *             less                      k     great
                 *
                 * Invariants:
                 *
                 *              all in (*,  less) == pivot1
                 *     pivot1 < all in [less,  k)  < pivot2
                 *              all in (great, *) == pivot2
                 *
                 * Pointer k is the first index of ?-part.
                 */
				 //看到这里有没有感觉有点熟悉,跟上面初次进行快排有点像
                outer:
				//从less到great进行遍历
                for (int k = less - 1; ++k <= great; ) {
                    int ak = a[k];
					//a[k]跟pivot1相等,直接给他移到那一堆跟pivot1相等的数组片段里,把less下标向右移动一位
                    if (ak == pivot1) { // Move a[k] to left part
                        a[k] = a[less];
                        a[less] = ak;
                        ++less;
                    }//同样,检查是否有值等于pivot2的,把great指针往左边移
					else if (ak == pivot2) { // Move a[k] to right part
                        while (a[great] == pivot2) {
                            if (great-- == k) {
                                break outer;
                            }
                        }
						//跟上边一样的逻辑,a[great]恰好等于pivot1,进行三方互换
                        if (a[great] == pivot1) { // a[great] < pivot2
                            a[k] = a[less];
                            a[less] = pivot1;
                            ++less;
                        } else { // pivot1 < a[great] < pivot2,把a[k]和a[great]交换,刚好两个元素都到合适的位置]
                            a[k] = a[great];
                        }
                        a[great] = ak;
                        --great;
                    }
                }
		    //执行到这里证明已经把中间部分所有等于pivot1和pivot2的元素排到了正确的位置,剩下的在这里就无能为力了,需要递归进行再次的排序,因为等于pivot1和pivot2的元素都不在left part和right part,所以前面能直接调用sort对左右两部分先进行排序
            sort(a, less, great, false);

这是当e1,e2,e3,e4,e5其中有两个相等的时候用的快速排序算法,不同于双轴快排,这里指定e3为单轴

            long pivot = a[e3];

            /*
             * Partitioning degenerates to the traditional 3-way
             * (or "Dutch National Flag") schema:
             *
             *   left part    center part              right part
             * +-------------------------------------------------+
             * |  < pivot  |   == pivot   |     ?    |  > pivot  |
             * +-------------------------------------------------+
             *              ^              ^        ^
             *              |              |        |
             *             less            k      great
             *
             * Invariants:
             *
             *   all in (left, less)   < pivot
             *   all in [less, k)     == pivot
             *   all in (great, right) > pivot
             *
             * Pointer k is the first index of ?-part.
             */
			 //遍历less到great,终于看见我常用的遍历写法了,感动
            for (int k = less; k <= great; ++k) {
                if (a[k] == pivot) {
                    continue;
                }
                long ak = a[k];
				//如果ak<pivot,扩容左边数组,同上一部分分析的逻辑
                if (ak < pivot) { // Move a[k] to left part
                    a[k] = a[less];
                    a[less] = ak;
                    ++less;
                }//如果大于pivot就分配到右半部分,没啥好说的 
				else { // a[k] > pivot - Move a[k] to right part
				//如果a[great]是大于pivot,就把指针往左移
                    while (a[great] > pivot) {
                        --great;
                    }
					//判断一下是不是小于pivot,如果是的话刚好又可以进行三边交换,美滋滋
                    if (a[great] < pivot) { // a[great] <= pivot
                        a[k] = a[less];
                        a[less] = a[great];
                        ++less;
                    }//如果没得三边交换,那就证明a[great] == pivot,将就一下把a[great]和a[k]交换一下吧
			else { // a[great] == pivot
                        a[k] = pivot;
                    }
					//记得把a[great]复制,把great坐标移动一下
                    a[great] = ak;
                    --great;
                }
            }

            /*
             * Sort left and right parts recursively.
             * All elements from center part are equal
             * and, therefore, already sorted.
             */
			 //除去跟pivot相等的部分的数,剩下的left part 和 right part就可以愉快地进一步sort了,同样,右边部分不是leftmost,设置为false。
            sort(a, left, less - 1, leftmost);
            sort(a, great + 1, right, false);

写到这里,这个sort方法(private static void sort(int[] a, int left, int right, boolean leftmost))就理清楚了,这个方法就是JDK1.8排序的核心方法 但是源码真正提供给外部的sort方法(static void sort(int[] a, int left, int right, int[] work, int workBase, int workLen) )还没解析完成呢,之前我们在第三行代码就开始解析这个私有的sort方法,我们还要接着往下看,把后面的部分也分析完才算成功。


我简述一下下面这段代码的用处:这段代码是判断需要排序的数组的结构是不是大概有序的,如果需要排序的数组可以由限定范围内的几段递增或者递减或者相等的数组组成,那么就使用改进版本的归并排序来进行排序,而如果数组结构超出限定的范围,就用改进的快速排序。细节请看下面代码的注释详解。

		//new一个run长度为68的数组,后续做说明
        int[] run = new int[MAX_RUN_COUNT + 1];
        int count = 0; run[0] = left;
		
        // 初始化完相关参数之后,开始检验整体结构
		//从left开始,遍历到right,这个count是把需要排序数组分成递增或者递减,相等的数组片段数量
        for (int k = left; k < right; run[count] = k) {
		//如果前两个元素时递增,就继续往后判断能递增到第几个元素
            if (a[k] < a[k + 1]) { // ascending
                while (++k <= right && a[k - 1] <= a[k]);
            }		//如果前两个元素时递减,就继续往后判断能递减到第几个元素,有一点不同的是,这里把递减的片段改成了递增的了,方便后续进行归并排序(如果符合归并排序条件的话)
			else if (a[k] > a[k + 1]) { // descending
                while (++k <= right && a[k - 1] >= a[k]);
                for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) {
                    int t = a[lo]; a[lo] = a[hi]; a[hi] = t;
                }
            } else { // 相等的片段数量如果超过MAX_RUN_LENGTH=33的话,就调用快排
                for (int m = MAX_RUN_LENGTH; ++k <= right && a[k - 1] == a[k]; ) {
                    if (--m == 0) {
                        sort(a, left, right, true);
                        return;
                    }
                }
            }

      		//如果count数量超过了MAX_RUN_COUNT=57,说明这个需要排序的数组结构很差,需要进行快排
            if (++count == MAX_RUN_COUNT) {
                sort(a, left, right, true);
                return;
            }
        }

下面这段代码紧接在上面的代码之后,功能是做特殊的判断。

		//当run[count]只对应数组的最后一个数时,扩展一位run[++count],作为哨兵,后续会继续解释,这里先跳过。这段代码看的挺难受的,     我给他改成   if (run[count] == right) run[++count] = ++right,就顺眼多了。
        if (run[count] == right++) { 
            run[++count] = right;
        }//如果count只有1,证明要排序的数组已经是有序的了,直接返回就行了 
		else if (count == 1) { 
            return;
        }

下面这段代码是为正式运行改进版本的插入排序做准备,只看这里肯定一脸懵逼,所以就先大概瞄一下,往下继续看吧。

// Determine alternation base for merge
        byte odd = 0;
        for (int n = 1; (n <<= 1) < count; odd ^= 1);

        // Use or create temporary array b for merging
        int[] b;                 // temp array; alternates with a
        int ao, bo;              // array offsets from 'left'
        int blen = right - left; // space needed for b
        if (work == null || workLen < blen || workBase + blen > work.length) {
            work = new int[blen];
            workBase = 0;
        }
        if (odd == 0) {
            System.arraycopy(a, left, work, workBase, blen);
            b = a;
            bo = 0;
            a = work;
            ao = workBase - left;
        } else {
            b = work;
            ao = 0;
            bo = workBase - left;
        }

下面这部分是JKD改进的归并排序的核心,说实话这部分看得我头有点疼,参考了 octopusflying的sort源码分析,给我指明了这部分源码的实现思想。

		//这个last不先赋初始值真滴看着难受啊,这个last是归并完一次之后待归并的个数。count只要不是1,就证明还没有归并排序完成,就继续循环。
        for (int last; count > 1; count = last) {
		//两个为一组进行归并排序,可能会有一组单手狗没得佩对,后续给他直接完整复制过去就行了
            for (int k = (last = 0) + 2; k <= count; k += 2) {
                int hi = run[k], mi = run[k - 1];
                for (int i = run[k - 2], p = i, q = mi; i < hi; ++i) {
                    if (q >= hi || p < mi && a[p + ao] <= a[q + ao]) {
                        b[i + bo] = a[p++ + ao];
                    } else {
                        b[i + bo] = a[q++ + ao];
                    }
                }
                run[++last] = hi;
            }
			//处理单身狗,复制到排序好的数组然后下次再处理他
            if ((count & 1) != 0) {
                for (int i = right, lo = run[count - 1]; --i >= lo;
                    b[i + bo] = a[i + ao]
                );
                run[++last] = right;
            }
            int[] t = a; a = b; b = t;
            int o = ao; ao = bo; bo = o;
        }

JDK的这个归并排序用的思路是:有两个数组空间a,b,其中一个存放排序好一轮的数组,就把待排序的数组在a,b来回倒腾,直到整个归并排序都完成了。有个关键点是要注意到底哪个数组空间存的是已经排序好一轮的,否则不在白忙活了,还有就是这里的归并没有跟我们平时写的一样,用递归,而是从底层开始,一步一步网上归并,这得益于基本有序的结构,这是我们一开始就判断过的。

总结

终于写完了,对本篇博客小小总结一下吧,我只是跟着源码一步一步理解代码,其中有一些参数比如什么时候进行插入排序,什么时候进行快速排序之类的thresh_hold并不理解其中的用意,不知道是根据数学统计得出的还是什么,看到JDK归并的核心以及双轴快排的时候,刚开始一脸懵逼,后来在网上搜了一下双轴快排的思路,才继续往下走。如果有不理解的同学建议多搜。本来想多贴几张图方便理解的,奈何画图软件用的不行,感觉非常难看,后面可能在补上吧。

现在是2019-6-2 北京时间 0:51,断断续续写了好几天,不得不说以我目前的水平完整地读源码(即使已经选了比较简单的排序的源码)还是感觉挺花费时间的,有一些点还是无法靠自己完全理解,这里感谢并且推荐octopusflying大佬的sort源码分析,在最后一阶段分析这个归并排序的时候看了一下他的图豁然开朗。