题目
假设将数据(没有重复的数据)存放在一个较大的数组里,利用多线程并发的方式再数组中查找数据(例如数组的大小为 100,如果两个线程的话,一个线程就从 1 到 50 之间找,另外一个线程就在 51 到 100 之间找),如果某一个线程找到该数据,其他线程提前终止(思考一下,如何实现)。
思路
我现在的内心是崩溃的,我花了大部分的时间完成了一件看起来相当愚蠢的事,不过过程还好,就当是学习多线程程序的优化了。
首先看到这个题目。。不知道怎么描述, 非 IO 密集型应用使用多线程的情况非常少,一般是为了防止程序的阻塞, 然而像这种没有 IO 也不需要考虑阻塞的情况就忍不住让人考虑到底有没有必要使用多线程写程序。
考虑的结果是没必要, 但是还是要这么做,毕竟题目这么来的,假设有必要吧,唉
双线程版本
开始之前我们自行模拟出一个没有重复数据的数组, 定义为
1 2 static final int bound = 1000000 ; static int [] nums = new int [bound];
对nums
进行初始化
1 2 3 4 5 static void initNums () { for (int i = 0 ; i < bound; i++) { nums[i] = i; } }
看起来还不错,数组的下标值和数组值完全相同,整个数组完全有序,不过并不影响我们的这个程序,因为
查找使用暴力的遍历方式,而不进行任何算法优化,因为我们并不能对这个数组做任何假设
要查找的数值target
将随机生成
很容易完成两个线程的版本
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 package assignment5;import java.util.Random;public class Main { final static int bound = 1000000 ; static int [] nums = new int [bound]; public static void main (String[] args) throws InterruptedException { initNums(); Random random = new Random (12344533 ); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < 10000 ; i++) { result = null ; int target = random.nextInt(bound); { Thread t1 = new Thread (() -> getIndexOfNums(0 , bound / 2 , target)); t1.start(); getIndexOfNums(bound / 2 + 1 , bound, target); t1.join(); } if (target != result) System.out.printf("expect: '%d', get result '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("Total execution time: " + (endTime - startTime) + "ms" ); } static void initNums () { for (int i = 0 ; i < bound; i++) { nums[i] = i; } } static Integer result; static void getIndexOfNums (int start, int end, int target) { for (; start < end; start++) { if (result != null ) return ; if (nums[start] == target) { result = start; return ; } } } }
getIndexOfNums
接收数组的两个边界值和一个target
值, 这个程序会在每次遍历的同时进行结果result
的检查,如果result
不是null
, 那么说明结果已经被其它线程计算出来了可以立即终止计算
random
的生成使用了一个常量作为seed
, 每次的运行产生的随机数序列 相同
我们判断了想要的结果与得到的结果不一致的情况,目的是确保我们的程序至少不会出现结果的错误(如果出现其中可能暗含了多线程访问的内存安全问题)
注意被注释掉的代码, 和上面的代码,我们可以分别注释上面的和下面的用来对比单线程和双线程执行所用的时间
结果很草, 使用单线程所花费的时间要更少,无论bound
的数值多大或者多小, 创建线程的开销在此的占比依旧很大。
如果重复使用单个线程的情况会怎样?
优化版本
为了方便测试, 我把各个benchmark
放在了不同的静态方法中
其中单线程的测试方法代码为
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 static final int testTimes = 10000 ; static final long seed = 345423533 ; static void benchmarkSingleThread () { Random random = new Random (seed); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); getIndexOfNums(0 , bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[SINGLE] Total execution time: " + (endTime - startTime) + "ms" ); }
而刚刚实现的 2 个线程的benchmark
如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 static void benchmarkTwoThread () throws InterruptedException, BrokenBarrierException { Random random = new Random (seed); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); Thread t1 = new Thread (() -> getIndexOfNums(0 , bound / 2 , target)); t1.start(); getIndexOfNums(bound / 2 + 1 , bound, target); t1.join(); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[DOUBLE] Total execution time: " + (endTime - startTime) + "ms" ); }
为了实现重复使用线程,我们需要设置一个消息传递的方式,在 Java 中,适合这种场景的一个接口(在这里也可以成为数据结构)是BlockingQueue
, BlockingQueue
与Queue
的区别在于BlockingQueue
需要实现阻塞式的数据检索(take
方法),当然不止如此,我们这里用到的只有这个
执行take
方法时, BlockingQueue
会阻塞直到一个可用的值出现,这作为消息传递的工具非常适合, 功能类似于golang
里的channel
(用于多个协程的消息传递)
我们创建一个包含一个单位buffer
的 BlockingQueue
1 BlockingQueue<Integer> bQueue = new ArrayBlockingQueue <>(1 );
我们的consumer
1 2 3 4 5 6 7 8 Thread t1 = new Thread (() -> { try { while (true ) getIndexOfNumsWithBarrier(0 , bound / 2 , bQueue.take()); } catch (InterruptedException | BrokenBarrierException e) { } }); t1.start();
这里的 consumer 线程执行时将不断地从bQueue
里读取想要的target
,如果得到这个值将会执行getIndexOfNumsWithBarrier
方法
我们的producer
1 2 3 4 5 6 7 8 for (int j = 1 ; j < threadNum; j++) { int target = random.nextInt(bound); bQueue.put(target); }
getIndexOfNumsWithBarrier
利用CyclicBarrier
在方法执行完毕前阻塞以实现结果同步, 在上面我们使用join
方法实现结果同步,这里t1
将会永远的执行下去(除非我们手动终结它),所以只能使用其它方式进行同步,barrier
在这里相当合适
1 2 3 4 5 6 7 8 9 10 11 static CyclicBarrier barrier = new CyclicBarrier (2 );static void getIndexOfNumsWithBarrier (int start, int end, int target) throws InterruptedException, BrokenBarrierException { getIndexOfNums(start, end, target); barrier.await(); return ; } }
后面继续另外一个线程(当前线程)的计算
1 2 3 4 5 getIndexOfNumsWithBarrier(bound / 2 + 1 , bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result);
我们的benchmark
的全部代码如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 static void benchmarkTwoThreadWithMessagePassing () throws InterruptedException, BrokenBarrierException { BlockingQueue<Integer> bQueue = new ArrayBlockingQueue <>(1 ); Random random = new Random (seed); Thread t1 = new Thread (() -> { try { while (true ) { getIndexOfNumsWithBarrier(0 , bound / 2 , bQueue.take()); } } catch (InterruptedException | BrokenBarrierException e) { } }); t1.start(); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); bQueue.put(target); getIndexOfNumsWithBarrier(bound / 2 + 1 , bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[DOUBLEP] Total execution time: " + (endTime - startTime) + "ms" ); t1.interrupt(); t1.join(); }
看起来还不错,执行结果示例如下
1 2 3 [SINGLE] Total execution time: 1709ms [DOUBLE] Total execution time: 3615ms [DOUBLEP] Total execution time: 1756ms
(脏话。。。。), 多次测试测试 DOUBLEP 表现最好的一次了,当然速度依旧不如单线程
多线程版本
难道是线程数不够吗? 考虑到这种可能性我将代码改成了可以自定义线程数的形式,这是最终形式, 其中你可以通过修改threadNum
和testTimes
或bound
进行不同数据的benchmark
,只是结果…至少我测试的结果并不是很理想, 似乎在我的机器上当threadNum
为 4 时,第三种方式(benchmarkMultipleThreadWithMessagePassing)的表现是最好的
结果示例:
1 2 3 [SINGLE] Total execution time: 1454ms [MULTIT] Total execution time: 4974ms [MULTITP] Total execution time: 1587ms
嘛,偶尔还是会超过单线程的,但是平均下来也就这样了。
另外多次重复创建线程的版本明显更慢了
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 package assignment5;import java.util.ArrayList;import java.util.List;import java.util.Random;import java.util.concurrent.ArrayBlockingQueue;import java.util.concurrent.BlockingQueue;import java.util.concurrent.BrokenBarrierException;import java.util.concurrent.CyclicBarrier;public class Main { static final int threadNum = 4 ; static final int testTimes = 10000 ; static final long seed = 345423533 ; static final int bound = 1000000 ; static int [] nums = new int [bound]; static final int step = bound / threadNum; static Integer result; static void initNums () { for (int i = 0 ; i < bound; i++) { nums[i] = i; } } public static void main (String[] args) throws InterruptedException, BrokenBarrierException { initNums(); benchmarkSingleThread(); benchmarkMultipleThread(); benchmarkMultipleThreadWithMessagePassing(); } static void benchmarkSingleThread () { Random random = new Random (seed); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); getIndexOfNums(0 , bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[SINGLE] Total execution time: " + (endTime - startTime) + "ms" ); } static void benchmarkMultipleThread () throws InterruptedException, BrokenBarrierException { Random random = new Random (seed); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); int tmp = 0 ; for (int j = 1 ; j < threadNum; j++) { int startPoint = tmp; new Thread (() -> { try { getIndexOfNumsWithBarrier(startPoint, startPoint + step, target); } catch (InterruptedException | BrokenBarrierException e) { e.printStackTrace(); } }).start(); tmp += step; } getIndexOfNumsWithBarrier(tmp, bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[MULTIT] Total execution time: " + (endTime - startTime) + "ms" ); } static void benchmarkMultipleThreadWithMessagePassing () throws InterruptedException, BrokenBarrierException { Random random = new Random (seed); List<Thread> threads = new ArrayList <>(); List<BlockingQueue<Integer>> queues = new ArrayList <>(); int tmp = 0 ; for (int j = 1 ; j < threadNum; j++) { queues.add(new ArrayBlockingQueue <>(1 )); int startPoint = tmp; int qIndex = j - 1 ; threads.add(new Thread (() -> { try { while (true ) { getIndexOfNumsWithBarrier(startPoint, startPoint + step, queues.get(qIndex).take()); } } catch (InterruptedException | BrokenBarrierException e) { } })); tmp += step; } for (Thread thread : threads) thread.start(); long startTime = System.currentTimeMillis(); for (int i = 0 ; i < testTimes; i++) { result = null ; int target = random.nextInt(bound); for (BlockingQueue<Integer> bQueue : queues) { bQueue.put(target); } getIndexOfNumsWithBarrier(tmp, bound, target); if (target != result) System.out.printf("expect: '%d', but got val '%d'" , target, result); } long endTime = System.currentTimeMillis(); System.out.println("[MULTITP] Total execution time: " + (endTime - startTime) + "ms" ); for (Thread thread : threads) thread.interrupt(); } static void getIndexOfNums (int start, int end, int target) { for (; start < end; start++) { if (result != null ) return ; if (nums[start] == target) { result = start; return ; } } } static CyclicBarrier barrier = new CyclicBarrier (threadNum); static void getIndexOfNumsWithBarrier (int start, int end, int target) throws InterruptedException, BrokenBarrierException { getIndexOfNums(start, end, target); barrier.await(); return ; } }
结论
慎用多线程
使用多线程时要考虑线程创建和计算机多线程间context
切换时消耗的资源。 Sometimes, less is more.