其中,exchFlag用来记录当前迭代是否发生了数据交换,而start变量用来表示是奇交换还是偶交换。初始时,start为0,表示进行偶交换,每次迭代结束后,切换start的状态。如果上一次比较交换发生了数据交换,或者当前正在进行的是奇交换,循环就不会停止,直到程序不再发生交换,并且当前进行的是偶交换为止(表示奇偶交换已经成对出现)。
上述代码虽然是串行代码,但是已经可以很方便地改造成并行模式:
01 static int exchFlag=1; 02 static synchronized void setExchFlag(int v){ 03 exchFlag=v; 04 } 05 static synchronized int getExchFlag(){ 06 return exchFlag; 07 } 08 09 public static class OddEvenSortTask implements Runnable{ 10 int i; 11 CountDownLatch latch; 12 public OddEvenSortTask(int i,CountDownLatch latch){ 13 this.i=i; 14 this.latch=latch; 15 } 16 @Override 17 public void run() { 18 if (arr[i] > arr[i + 1]) { 19 int temp = arr[i]; 20 arr[i] = arr[i + 1]; 21 arr[i + 1] = temp; 22 setExchFlag(1); 23 } 24 latch.countDown(); 25 } 26 } 27 public static void pOddEvenSort(int[] arr) throws InterruptedException { 28 int start = 0; 29 while (getExchFlag() == 1 || start == 1) { 30 setExchFlag(0); 31 //偶数的数组长度,当start为1时,只有len/2-1个线程 32 CountDownLatch latch = new CountDownLatch(arr.length/2-(arr.length%2==0?start:0)); 33 for (int i = start; i < arr.length - 1; i += 2) { 34 pool.submit(new OddEvenSortTask(i,latch)); 35 } 36 //等待所有线程结束 37 latch.await(); 38 if (start == 0) 39 start = 1; 40 else 41 start = 0; 42 } 43 }
上述代码第9行,定义了奇偶排序的任务类。该任务的主要工作是进行数据比较和必要的交换(第18~23行)。并行排序的主体是pOddEvenSort()方法,它使用CountDownLatch记录线程数量,对于每一次迭代,使用单独的线程对每一次元素比较和交换进行操作。在下一次迭代开始前,必须等待上一次迭代所有线程的完成。
5.8.2 改进的插入排序:希尔排序
插入排序也是一种很常用的排序算法。它的基本思想是:一个未排序的数组(当然也可以是链表)可以分为两个部分,前半部分是已经排序的,后半部分是未排序的。在进行排序时,只需要在未排序的部分中选择一个元素,将其插入到前面有序的数组中即可。最终,未排序的部分会越来越少,直到为0,那么排序就完成了。初始时,可以假设已排序部分就是第一个元素。
插入排序的几次迭代示意如图5.14所示。
图5.14 插入排序示意图
插入排序的实现如下所示:
01 public static void insertSort(int[] arr) { 02 int length = arr.length; 03 int j, i, key; 04 for (i = 1; i < length; i++) { 05 //key为要准备插入的元素 06 key = arr[i]; 07 j = i - 1; 08 while (j >= 0 && arr[j] > key) { 09 arr[j + 1] = arr[j]; 10 j--; 11 } 12 //找到合适的位置 插入key 13 arr[j + 1] = key; 14 } 15 }
上述代码第6行,提取要准备插入的元素(也就是未排序序列中的第一个元素)。接着,在已排序队列中找到这个元素的插入位置(第8~10行),并进行插入(第13行)即可。
简单的插入排序是很难并行化的。因为这一次的数据插入依赖于上一次得到的有序序列,因此多个步骤之间无法并行。为此,我们可以对插入排序进行扩展,这就是希尔排序。
希尔排序将整个数组根据间隔h分割为若干个子数组。子数组相互穿插在一起,每一次排序时,分别对每一个子数组进行排序。如图5.15所示,当h为3时,希尔排序将整个数组分为交织在一起的三个子数组。其中,所有的方块为一个子数组,所有的圆形、三角形分别组成另外两个子数组。每次排序时,总是交换间隔为h的两个元素。
图5.15 h=3时的数组分割
在每一组排序完成后,可以递减h的值,进行下轮更加精细的排序。直到h为1,此时等价于一次插入排序。
希尔排序的一个主要优点是,即使一个较小的元素在数组的末尾,由于每次元素移动都以h为间隔进行,因此数组末尾的小元素可以在很少的交换次数下,就被置换到最接近元素最终位置的地方。
下面是希尔排序的串行实现:
01 public static void shellSort(int[] arr) { 02 // 计算出最大的h值 03 int h = 1; 04 while (h <= arr.length / 3) { 05 h = h * 3 + 1; 06 } 07 while (h > 0) { 08 for (int i = h; i < arr.length; i++) { 09 if (arr[i] < arr[i - h]) { 10 int tmp = arr[i]; 11 int j = i - h; 12 while (j >= 0 && arr[j] > tmp) { 13 arr[j + h] = arr[j]; 14 j -= h; 15 } 16 arr[j + h] = tmp; 17 } 18 } 19 // 计算出下一个h值 20 h = (h - 1) / 3; 21 } 22 }
上述代码第4~6行,计算一个合适的h值,接着正式进行希尔排序。第8行的for循环进行间隔为h的插入排序,每次排序结束后,递减h的值(第20行)。直到h为1,退化为插入排序。
很显然,希尔排序每次都针对不同的子数组进行排序,各个子数组之间是完全独立的。因此,很容易改写成并行程序:
01 public static class ShellSortTask implements Runnable { 02 int i = 0; 03 int h = 0; 04 CountDownLatch l; 05 06 public ShellSortTask(int i, int h, CountDownLatch latch) { 07 this.i = i; 08 this.h = h; 09 this.l = latch; 10 } 11 12 @Override 13 public void run() { 14 if (arr[i] < arr[i - h]) { 15 int tmp = arr[i]; 16 int j = i - h; 17 while (j >= 0 && arr[j] > tmp) { 18 arr[j + h] = arr[j]; 19 j -= h; 20 } 21 arr[j + h] = tmp; 22 } 23 l.countDown(); 24 } 25 } 26 27 public static void pShellSort(int[] arr) throws InterruptedException { 28 // 计算出最大的h值 29 int h = 1; 30 CountDownLatch latch = null; 31 while (h <= arr.length / 3) { 32 h = h * 3 + 1; 33 } 34 while (h > 0) { 35 System.out.println("h=" + h); 36 if (h >= 4) 37 latch = new CountDownLatch(arr.length - h); 38 for (int i = h; i < arr.length; i++) { 39 // 控制线程数量 40 if (h >= 4) { 41 pool.execute(new ShellSortTask(i, h, latch)); 42 } else { 43 if (arr[i] < arr[i - h]) { 44 int tmp = arr[i]; 45 int j = i - h; 46 while (j >= 0 && arr[j] > tmp) { 47 arr[j + h] = arr[j]; 48 j -= h; 49 } 50 arr[j + h] = tmp; 51 } 52 // System.out.println(Arrays.toString(arr)); 53 } 54 } 55 // 等待线程排序完成,进入下一次排序 56 latch.await(); 57 // 计算出下一个h值 58 h = (h - 1) / 3; 59 } 60 }
上述代码中定义ShellSortTask作为并行任务。一个ShellSortTask的作用是根据给定的起始位置和h,对子数组进行排序,因此可以完全并行化。
为控制线程数量,这里定义并行主函数pShellSort()在h大于或等于4时使用并行线程(第40行),否则则退化为传统的插入排序。
每次计算后,递减h的值(第58行)。
5.9 并行算法:矩阵乘法
我在第一章中已经提到,Linus认为并行程序目前只有在服务端程序和图像处理领域有发展的空间。且不论这种说法是否正确,但从中也可以看出并发对于这两个应用领域的重要性。而对于图像处理来说,矩阵运行是其中必不可少的重要数学方法。当然,除了图像处理,矩阵运算在神经网络、模式识别等领域也有着广泛的用途。在这里,我将向大家介绍矩阵运算的典型代表——矩阵乘法的并行化实现。
在矩阵乘法中,第一个矩阵的列数和第二个矩阵的行数必须是相同的。如图5.16所示,矩阵A和矩阵B相乘,其中矩阵A为4行2列,矩阵B为2行4列,它们相乘后,得到的是4行4列的矩阵,并且新矩阵中每一个元素为矩阵A和B对应行列的乘积求和。
图5.16 矩阵相乘示意图
如果需要进行并行计算,一种简单的策略是可以将A矩阵进行水平分割,得到子矩阵A1和A2,B矩阵进行垂直分割,得到子矩阵B1和B2。此时,我们只要分别计算这些子矩阵的乘积,将结果进行拼接,就能得到原始矩阵A和B的乘积。如图5.17所示,展示了这种并行计算的策略。
图5.17 矩阵拆分进行并行计算
当然,这个过程是可以反复进行的。为了计算A1*B1,我们还可以进一步将A1和B1进行分解,直到我们认为子矩阵的大小已经在可接受范围内。
这里,我们使用ForkJoin框架来实现这个并行矩阵相乘的想法。为了方便矩阵计算,我们使用jMatrices开源软件,作为矩阵计算的工具。其中,使用的主要API如下:
Matrix:代表一个矩阵
MatrixOperator.multiply(Matrix, Matrix):矩阵相乘
Matrix.row():获得矩阵的行数
Matrix.getSubMatrix():获得矩阵的子矩阵
MatrixOperator.horizontalConcatenation(Matrix,Matrix):将两个矩阵进行水平连接
MatrixOperator.verticalConcatenation(Matrix,Matrix):将两个矩阵进行垂直连接
为了计算矩阵乘法,定义一个任务类MatrixMulTask。它会进行矩阵相乘的计算,如果输入矩阵的粒度比较大,则会再次进行任务分解:
01 public class MatrixMulTask extends RecursiveTask<Matrix> { 02 Matrix m1; 03 Matrix m2; 04 String pos; 05 06 public MatrixMulTask(Matrix m1, Matrix m2, String pos) { 07 this.m1 = m1; 08 this.m2 = m2; 09 this.pos = pos; 10 } 11 12 @Override 13 protected Matrix compute() { 14 //System.out.println(Thread.currentThread().getId()+":"+Thread.currentThread(). getName() + " is start"); 15 if (m1.rows() <= PMatrixMul.granularity || m2.cols() <= PMatrixMul.granularity) { 16 Matrix mRe = MatrixOperator.multiply(m1, m2); 17 return mRe; 18 } else { 19 // 如果不是,那么继续分割矩阵 20 int rows; 21 rows = m1.rows(); 22 // 左乘的矩阵横向分割 23 Matrix m11 = m1.getSubMatrix(1, 1, rows / 2, m1.cols()); 24 Matrix m12 = m1.getSubMatrix(rows / 2 + 1, 1, m1.rows(), m1.cols()); 25 // 右乘矩阵纵向分割 26 Matrix m21 = m2.getSubMatrix(1, 1, m2.rows(), m2.cols() / 2); 27 Matrix m22 = m2.getSubMatrix(1, m2.cols() / 2 + 1, m2.rows(), m2.cols()); 28 29 ArrayList<MatrixMulTask> subTasks = new ArrayList<MatrixMulTask>(); 30 MatrixMulTask tmp = null; 31 tmp = new MatrixMulTask(m11, m21, "m1"); 32 subTasks.add(tmp); 33 tmp = new MatrixMulTask(m11, m22, "m2"); 34 subTasks.add(tmp); 35 tmp = new MatrixMulTask(m12, m21, "m3"); 36 subTasks.add(tmp); 37 tmp = new MatrixMulTask(m12, m22, "m4"); 38 subTasks.add(tmp); 39 for (MatrixMulTask t : subTasks) { 40 t.fork(); 41 } 42 Map<String, Matrix> matrixMap = new HashMap<String, Matrix>(); 43 for (MatrixMulTask t : subTasks) { 44 matrixMap.put(t.pos, t.join()); 45 } 46 Matrix tmp1 = MatrixOperator.horizontalConcatenation(matrixMap.get("m1"), matrixMap.get("m2")); 47 Matrix tmp2 = MatrixOperator.horizontalConcatenation(matrixMap.get("m3"), matrixMap.get("m4")); 48 Matrix reM = MatrixOperator.verticalConcatenation(tmp1, tmp2); 49 return reM; 50 } 51 } 52 }
MatrixMulTask类由三个参数构成,分别是需要计算的矩阵双方,以及计算结果位于父矩阵相乘结果中的位置,如图5.18所示。
图5.18 矩阵分解方式
MatrixMulTask中的成员变量m1和m2表示要相乘的两个矩阵,pos表示这个乘积结果在父矩阵相乘结果中所处的位置,有m1、m2、m3和m4等四种。代码第23~27行先对矩阵进行分割,分割后得到m11、m12、m21和m22等四个矩阵,并将它们按照如图5.18所示的规则进行子任务的创建。在第39~41行,计算这些子任务。在子任务返回后,在第42~48行将返回的四个矩阵m1、m2、m3和m4拼接成新的矩阵作为最终结果。
如果矩阵的粒度足够小就直接进行运算而不进行分解(第16行)。
使用这个任务类可以很容易地进行矩阵并行运算,下面是使用方法:
01 public static final int granularity=3; 02 public static void main(String[] args) throws InterruptedException, ExecutionException { 03 ForkJoinPool forkJoinPool = new ForkJoinPool(); 04 Matrix m1=MatrixFactory.getRandomIntMatrix(300, 300, null); 05 Matrix m2=MatrixFactory.getRandomIntMatrix(300, 300, null); 06 MatrixMulTask task=new MatrixMulTask(m1,m2,null); 07 ForkJoinTask<Matrix> result = forkJoinPool.submit(task); 08 Matrix pr=result.get(); 09 System.out.println(pr); 10 }
上述代码中第4~5行创建两个300*300的随机矩阵。构造矩阵计算任务MatrixMulTask并将其提交给ForkJoinPool线程池。第8行执行ForkJoinTask.get()方法等待并获得最终结果。
5.10 准备好了再通知我:网络NIO
Java NIO是New IO的简称,它是一种可以替代Java IO的一套新的IO机制。它提供了一套不同于Java标准IO的操作机制。严格来说,NIO与并发并无直接的关系。但是,使用NIO技术可以大大提高线程的使用效率。
Java NIO中涉及的基础内容有通道(Channel)和缓冲区(Buffer)、文件IO和网络IO。有关通道、缓冲区以及文件IO在这里不打算进行详细的介绍,大家可以参考本章的参考文献。在这里,我想多花一点时间详细介绍一下有关网络IO的内容。
对于标准的网络IO来说,我们会使用Socket进行网络的读写。为了让服务器可以支持更多的客户端连接,通常的做法是为每一个客户端连接开启一个线程。让我们先回顾一下这方面的内容。
5.10.1 基于Socket的服务端的多线程模式
这里,我以一个简单的Echo服务器为例。对于Echo服务器,它会读取客户端的一个输入,并将这个输入原封不动地返回给客户端。这看起来很简单,但是麻雀虽小五脏俱全。为了完成这个功能,服务器还是需要有一套完整的Socket处理机制。因此,这个Echo服务器非常适合来进行学习。实际上,我认为任何业务逻辑简单的系统都很适合学习,大家不用为了去理解业务上复杂的功能而忽略了系统的重点。
服务端使用多线程进行处理时的结构示意图,如图5.19所示。
图5.19 多线程的服务端
服务器会为每一个客户端连接启用一个线程,这个新的线程将全心全意为这个客户端服务。同时,为了接受客户端连接,服务器还会额外使用一个派发线程。
下面的代码实现了这个服务器:
01 public class MultiThreadEchoServer { 02 private static ExecutorService tp=Executors.newCachedThreadPool(); 03 static class HandleMsg implements Runnable{ 04 Socket clientSocket; 05 public HandleMsg(Socket clientSocket){ 06 this.clientSocket=clientSocket; 07 } 08 09 public void run(){ 10 BufferedReader is =null; 11 PrintWriter os = null; 12 try { 13 14 is = new BufferedReader(new InputStreamReader(clientSocket.getInputStream())); 15 os = new PrintWriter(clientSocket.getOutputStream(), true); 16 // 从InputStream当中读取客户端所发送的数据 17 String inputLine = null; 18 long b=System.currentTimeMillis(); 19 while ((inputLine = is.readLine()) != null) { 20 os.println(inputLine); 21 } 22 long e=System.currentTimeMillis(); 23 System.out.println("spend:"+(e-b)+"ms"); 24 } catch (IOException e) { 25 e.printStackTrace(); 26 }finally{ 27 try { 28 if(is!=null)is.close(); 29 if(os!=null)os.close(); 30 clientSocket.close(); 31 } catch (IOException e) { 32 e.printStackTrace(); 33 } 34 } 35 } 36 } 37 public static void main(String args[]) { 38 ServerSocket echoServer = null; 39 Socket clientSocket = null; 40 try { 41 echoServer = new ServerSocket(8000); 42 } catch (IOException e) { 43 System.out.println(e); 44 } 45 while (true) { 46 try { 47 clientSocket = echoServer.accept(); 48 System.out.println(clientSocket.getRemoteSocketAddress() + " connect!"); 49 tp.execute(new HandleMsg(clientSocket)); 50 } catch (IOException e) { 51 System.out.println(e); 52 } 53 } 54 } 55 }
第2行,我们使用了一个线程池来处理每一个客户端连接。第3~33行,定义了HandleMsg线程,它由一个客户端Socket构造而成,它的任务是读取这个Socket的内容并将其进行返回,返回成功后,任务完成,客户端Soceket就被正常关闭。其中第23行,统计并输出了服务端线程处理一次客户端请求所花费的时间(包括读取数据和回写数据的时间)。主线程main的主要作用是在8000端口上进行等待。一旦有新的客户端连接,它就根据这个连接创建HandleMsg线程进行处理(第47~49行)。
这就是一个支持多线程的服务端的核心内容。它的特点是,在相同可支持的线程范围内,可以尽量多地支持客户端的数量,同时和单线程服务器相比,它也可以更好地使用多核CPU。
为了方便大家学习,这里再给出一个客户端的参考实现:
01 public static void main(String[] args) throws IOException { 02 Socket client = null; 03 PrintWriter writer = null; 04 BufferedReader reader = null; 05 try { 06 client = new Socket(); 07 client.connect(new InetSocketAddress("localhost", 8000)); 08 writer = new PrintWriter(client.getOutputStream(), true); 09 writer.println("Hello!"); 10 writer.flush(); 11 12 reader = new BufferedReader(new InputStreamReader(client.getInputStream())); 13 System.out.println("from server: " + reader.readLine()); 14 } catch (UnknownHostException e) { 15 e.printStackTrace(); 16 } catch (IOException e) { 17 e.printStackTrace(); 18 } finally { 19 if (writer != null) 20 writer.close(); 21 if (reader != null) 22 reader.close(); 23 if (client != null) 24 client.close(); 25 } 26 }
上述代码在第7行,连接了服务器的8000端口,并发送字符串。接着在第12行,读取服务器的返回信息并进行输出。
可以说,这种多线程的服务器开发模式是极其常用的。对于绝大多数应用来说,这种模式可以很好地工作。但是,如果你想让你的程序工作得更加有效,就必须知道这种模式的一个重大弱点——那就是它倾向于让CPU进行IO等待。为了理解这一点,让我们看一下下面这个比较极端的例子:
01 public class HeavySocketClient { 02 private static ExecutorService tp=Executors.newCachedThreadPool(); 03 private static final int sleep_time=1000*1000*1000; 04 public static class EchoClient implements Runnable{ 05 public void run(){ 06 Socket client = null; 07 PrintWriter writer = null; 08 BufferedReader reader = null; 09 try { 10 client = new Socket(); 11 client.connect(new InetSocketAddress("localhost", 8000)); 12 writer = new PrintWriter(client.getOutputStream(), true); 13 writer.print("H"); 14 LockSupport.parkNanos(sleep_time); 15 writer.print("e"); 16 LockSupport.parkNanos(sleep_time); 17 writer.print("l"); 18 LockSupport.parkNanos(sleep_time); 19 writer.print("l"); 20 LockSupport.parkNanos(sleep_time); 21 writer.print("o"); 22 LockSupport.parkNanos(sleep_time); 23 writer.print("!"); 24 LockSupport.parkNanos(sleep_time); 25 writer.println(); 26 writer.flush(); 27 28 reader = new BufferedReader(new InputStreamReader(client.getInputStream())); 29 System.out.println("from server: " + reader.readLine()); 30 } catch (UnknownHostException e) { 31 e.printStackTrace(); 32 } catch (IOException e) { 33 e.printStackTrace(); 34 } finally { 35 try { 36 if (writer != null) 37 writer.close(); 38 if (reader != null) 39 reader.close(); 40 if (client != null) 41 client.close(); 42 } catch (IOException e) { 43 } 44 } 45 } 46 } 47 public static void main(String[] args) throws IOException { 48 EchoClient ec=new EchoClient(); 49 for(int i=0;i<10;i++) 50 tp.execute(ec); 51 } 52 }
上述代码定义了一个新的客户端,它会进行10次请求(第49~50行开启10个线程)。每一次请求都会访问8000端口。连接成功后,会向服务器输出“Hello!”字符串(第13~26行),但是在这一次交互中,客户端会慢慢地进行输出,每次只输出一个字符,之后进行1秒的等待。因此,整个过程会持续6秒。
开启多线程池的服务器和上述客户端。服务器端的部分输出如下:
spend:6000ms spend:6000ms spend:6000ms spend:6001ms spend:6002ms spend:6002ms spend:6002ms spend:6002ms spend:6003ms spend:6003ms
可以看到,对于服务端来说,每一个请求的处理时间都在6秒左右。这很容易理解,因为服务器要先读入客户端的输入,而客户端缓慢的处理速度(当然也可能是一个拥塞的网络环境)使得服务器花费了不少等待时间。
我们可以试想一下,如果服务器要处理大量的请求连接,每个请求如果都像这样拖慢了服务器的处理速度,那么服务端能够处理的并发数量就会大幅度减少。反之,如果服务器每次都能很快地处理一次请求,那么相对的,它的并发能力就能上升。
在这个案例中,服务器处理请求之所以慢,并不是因为在服务端有多少繁重的任务,而仅仅是因为服务线程在等待IO而已。让高速运转的CPU去等待极其低效的网络IO是非常不合算的行为。那么,我们是不是可以想一个方法,将网络IO的等待时间从线程中分离出来呢?
5.10.2 使用NIO进行网络编程
使用Java的NIO就可以将上面的网络IO等待时间从业务处理线程中抽取出来。那么NIO是什么,它又是如何工作的呢?
要了解NIO,我们首先需要知道在NIO中的一个关键组件Channel(通道)。Channel有点类似于流,一个Channel可以和文件或者网络Socket对应。如果Channel对应着一个Soceket,那么往这个Channel中写数据,就等同于向Socket中写入数据。
和Channel一起使用的另外一个重要组件就是Buffer。大家可以简单地把Buffer理解成一个内存区域或者byte数组。数据需要包装成Buffer的形式才能和Channel交互(写入或者读取)。
另外一个与Channel密切相关的是Selector(选择器)。在Channel的众多实现中,有一个SelectableChannel实现,表示可被选择的通道。任何一个SelectableChannel都可以将自己注册到一个Selector中。这样,这个Channel就能被Selector所管理。而一个Selector可以管理多个SelectableChannel。当SelectableChannel的数据准备好时,Selector就会接到通知,得到那些已经准备好的数据。而SocketChannel就是SelectableChannel的一种。因此,它们构成了如图5.20所示的结构。
图5.20 Selector和Channel
大家可以看到,一个Selector可以由一个线程进行管理,而一个SocketChannel则可以表示一个客户端连接,因此这就构成由一个或者极少数线程,来处理大量客户端连接的结构。当与客户端连接的数据没有准备好时,Selector会处于等待状态(不过幸好,用于管理Selector的线程数是极少量的),而一旦有任何一个SocketChannel准备好了数据,Selector就能立即得到通知,获取数据进行处理。
下面就让我们用NIO来重新构造这个多线程的Echo服务器吧!
首先,我们需要定义一个Selector和线程池:
private Selector selector; private ExecutorService tp=Executors.newCachedThreadPool();
其中,selector用于处理所有的网络连接。线程池tp用于对每一个客户端进行相应的处理,每一个请求都会委托给线程池中的线程进行实际的处理。