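Problem

Suppose the data (with no duplicates) is stored in a fairly large array, and we search the array concurrently with multiple threads (for example, if the array has 100 elements and there are two threads, one thread searches positions 1 to 50 and the other searches positions 51 to 100). If one thread finds the value, the other threads should terminate early (think about how to implement this).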

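Approach

I am falling apart inside right now: I spent most of my time on something that looks rather silly. Still, the process was fine, so I will treat it as practice in optimizing multithreaded programs.

My first reaction to this problem is hard to put into words. Multithreading is rarely used in applications that are not IO-bound; it is usually there to keep the program from blocking. With no IO and no blocking to worry about, as in this case, one cannot help wondering whether multithreading is needed at all.

My conclusion is that it is not needed, but I have to do it anyway because that is what the problem asks. Let's just assume it is needed. Sigh.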

Two-thread version

Before starting, we build our own array without duplicate values, defined as

static final int bound = 1000000; // length of given nums
static int[] nums = new int[bound];

Initialize nums

static void initNums() {
    for (int i = 0; i < bound; i++) {
        nums[i] = i;
    }
}

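This looks fine. Each value equals its index, so the array is fully sorted, but that does not affect our program, because:

  • The search is a brute-force scan with no algorithmic optimization, since we cannot assume anything about the array
  • The value to search for, target, is generated randomly

The two-thread version is easy to write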

package assignment5;

import java.util.Random;

public class Main {
    final static int bound = 1000000;
    static int[] nums = new int[bound];

    public static void main(String[] args) throws InterruptedException {
        initNums();

        // seed the generator with a constant so every run sees the same targets (for testing)
        Random random = new Random(12344533);

        long startTime = System.currentTimeMillis();
        for (int i = 0; i < 10000; i++) {
            result = null; // reset result val

            int target = random.nextInt(bound);

            {
                // t1 searches [0, bound / 2), the current thread searches [bound / 2, bound)
                Thread t1 = new Thread(() -> getIndexOfNums(0, bound / 2, target));
                t1.start();

                getIndexOfNums(bound / 2, bound, target);
                t1.join(); // wait until t1 is done
            }

            /* the commented-out line below is the single-thread version */
            // getIndexOfNums(0, bound, target);

            if (target != result)
                System.out.printf("expect: '%d', get result '%d'", target, result);
        }
        long endTime = System.currentTimeMillis();
        System.out.println("Total execution time: " + (endTime - startTime) + "ms");
    }

    static void initNums() {
        for (int i = 0; i < bound; i++) {
            nums[i] = i;
        }
    }

    // volatile so that a result written by one thread is visible to the other
    static volatile Integer result;

    static void getIndexOfNums(int start, int end, int target) {
        for (; start < end; start++) {
            if (result != null)
                return; // another thread has already found the target

            if (nums[start] == target) {
                result = start;
                return;
            }
        }
    }
}
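  • getIndexOfNums takes the two bounds of the range to search and a target value. On every iteration it also checks result; if result is not null, another thread has already found the answer and the search can stop immediately
  • random is created with a constant seed, so every run produces the same sequence of random numbers
  • We check whether the result we got differs from the one we expected (since nums[i] == i, the expected index equals target), to make sure the program at least does not produce wrong results (a mismatch could point to a memory-safety problem in the multithreaded access)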
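The early-exit check described in the list above only works if a write made by one thread becomes visible to the other. The following is a minimal sketch of the same idea under different assumptions from the article's code: it uses an AtomicInteger with a sentinel value instead of an Integer field, and the names EarlyExitSearch, NOT_FOUND, scan and search are hypothetical.

```java
import java.util.concurrent.atomic.AtomicInteger;

class EarlyExitSearch {
    static final int NOT_FOUND = -1; // sentinel chosen for this sketch

    /** Scans nums[start, end) and stops as soon as any thread has published a hit. */
    static void scan(int[] nums, int start, int end, int target, AtomicInteger found) {
        for (int i = start; i < end; i++) {
            if (found.get() != NOT_FOUND) {
                return; // another thread already found the target
            }
            if (nums[i] == target) {
                found.set(i); // publish the index to the other thread
                return;
            }
        }
    }

    /** Searches both halves concurrently; returns the index of target or NOT_FOUND. */
    static int search(int[] nums, int target) {
        AtomicInteger found = new AtomicInteger(NOT_FOUND);
        int mid = nums.length / 2;
        Thread helper = new Thread(() -> scan(nums, 0, mid, target, found));
        helper.start();
        // half-open ranges [0, mid) and [mid, length) share no index and skip none
        scan(nums, mid, nums.length, target, found);
        try {
            helper.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return found.get();
    }

    public static void main(String[] args) {
        int[] nums = new int[1000];
        for (int i = 0; i < nums.length; i++) {
            nums[i] = i;
        }
        System.out.println(search(nums, 500)); // prints 500
    }
}
```

Both AtomicInteger and a volatile field give the visibility guarantee the check depends on; the sentinel simply avoids boxing and null checks.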

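Note the commented-out line and the block above it: by commenting out one or the other, we can compare the running time of the single-thread and two-thread versions.

The result is ridiculous: the single-thread version takes less time. However large or small bound is, thread creation still accounts for a large share of the cost.

What happens if we reuse a single thread instead?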
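One way to explore that question using only the standard library is a fixed thread pool. The sketch below is an alternative technique, not the BlockingQueue approach this post builds next: it assumes ExecutorService.invokeAny, which returns the result of the first task that completes successfully and cancels the rest. The names PooledSearch, search and chunks are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PooledSearch {
    /** Returns the index of target in nums, or -1 if no chunk contains it. */
    static int search(ExecutorService pool, int[] nums, int target, int chunks) {
        List<Callable<Integer>> tasks = new ArrayList<>();
        int step = Math.max(1, (nums.length + chunks - 1) / chunks); // ceiling division
        for (int start = 0; start < nums.length; start += step) {
            final int from = start;
            final int to = Math.min(start + step, nums.length);
            tasks.add(() -> {
                for (int i = from; i < to; i++) {
                    if (Thread.currentThread().isInterrupted()) {
                        throw new InterruptedException(); // cancelled: another chunk already won
                    }
                    if (nums[i] == target) {
                        return i;
                    }
                }
                throw new NoSuchElementException(); // a miss must count as a failure for invokeAny
            });
        }
        if (tasks.isEmpty()) {
            return -1; // nothing to search
        }
        try {
            return pool.invokeAny(tasks); // first successful task wins, the others are cancelled
        } catch (ExecutionException e) {
            return -1; // every chunk missed
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return -1;
        }
    }

    public static void main(String[] args) {
        int[] nums = new int[1_000_000];
        for (int i = 0; i < nums.length; i++) {
            nums[i] = i;
        }
        ExecutorService pool = Executors.newFixedThreadPool(4); // threads are created once and reused
        try {
            for (int target : new int[] {0, 499_999, 999_999}) {
                System.out.println(search(pool, nums, target, 4));
            }
        } finally {
            pool.shutdown();
        }
    }
}
```

The pool removes the per-search thread creation cost, but each search still pays for task submission and cancellation, so it is not obvious that it would beat a plain loop for this workload.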

Optimized version

To make testing easier, I put each benchmark in its own static method

The single-thread benchmark method is

static final int testTimes = 10000; // how many times we will do for benchmark
static final long seed = 345423533; // constant seed, for test only

static void benchmarkSingleThread() {
    Random random = new Random(seed);

    /* start benchmark */
    long startTime = System.currentTimeMillis();
    for (int i = 0; i < testTimes; i++) {
        result = null; // reset result val

        int target = random.nextInt(bound);

        getIndexOfNums(0, bound, target);

        if (target != result)
            System.out.printf("expect: '%d', but got val '%d'", target, result);
    }
    long endTime = System.currentTimeMillis();
    /* end benchmark */

    System.out.println("[SINGLE] Total execution time: " + (endTime - startTime) + "ms");
}

The two-thread benchmark we just implemented looks like this

static void benchmarkTwoThread() throws InterruptedException, BrokenBarrierException {
    Random random = new Random(seed);

    long startTime = System.currentTimeMillis();
    for (int i = 0; i < testTimes; i++) {
        result = null; // reset result val

        int target = random.nextInt(bound);

        // t1 searches [0, bound / 2), the current thread searches [bound / 2, bound)
        Thread t1 = new Thread(() -> getIndexOfNums(0, bound / 2, target));
        t1.start();

        getIndexOfNums(bound / 2, bound, target);
        t1.join(); // wait until t1 is done

        if (target != result)
            System.out.printf("expect: '%d', but got val '%d'", target, result);
    }
    long endTime = System.currentTimeMillis();

    System.out.println("[DOUBLE] Total execution time: " + (endTime - startTime) + "ms");
}

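To reuse a thread, we need a way to pass messages to it. In Java, an interface (which here can also be regarded as a data structure) that suits this scenario is BlockingQueue. BlockingQueue differs from Queue in that it must provide blocking retrieval (the take method). It offers more than that, but take is all we use here.

When take is called, the BlockingQueue blocks until a value becomes available, which makes it a good fit for message passing. It works much like a channel in Go (used to pass messages between goroutines).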
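To make the blocking handoff concrete, here is a minimal sketch separate from the benchmark code: a producer passes values to one worker thread through a one-slot queue. The names QueueHandoff and sumViaWorker are hypothetical, and summing is only a stand-in for real work.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

class QueueHandoff {
    /** Sends each value to a worker thread through a one-slot queue and returns the sum the worker saw. */
    static int sumViaWorker(int[] values) {
        BlockingQueue<Integer> queue = new ArrayBlockingQueue<>(1);
        int[] sum = new int[1]; // written only by the worker, read after join()

        Thread worker = new Thread(() -> {
            try {
                for (int i = 0; i < values.length; i++) {
                    sum[0] += queue.take(); // blocks until the producer has put a value
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // stop early if interrupted
            }
        });
        worker.start();

        try {
            for (int value : values) {
                queue.put(value); // blocks while the single slot is still occupied
            }
            worker.join(); // join() makes the worker's writes visible to this thread
        } catch (InterruptedException e) {
            worker.interrupt();
            Thread.currentThread().interrupt();
        }
        return sum[0];
    }

    public static void main(String[] args) {
        System.out.println(sumViaWorker(new int[] {1, 2, 3, 4})); // prints 10
    }
}
```

With a capacity of one, put and take alternate, so the producer can never run more than one message ahead of the worker.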

We create a BlockingQueue with a buffer of one element

BlockingQueue<Integer> bQueue = new ArrayBlockingQueue<>(1);

Our consumer

Thread t1 = new Thread(() -> {
    try {
        while (true)
            getIndexOfNumsWithBarrier(0, bound / 2, bQueue.take());
    } catch (InterruptedException | BrokenBarrierException e) {
        // interrupted or barrier broken: let the thread exit
    }
});
t1.start();

The consumer thread keeps reading the next target from bQueue; whenever it gets one, it runs getIndexOfNumsWithBarrier

Our producer

for (int i = 0; i < testTimes; i++) {
    // do something

    int target = random.nextInt(bound);
    bQueue.put(target); // hand the target to the consumer thread

    // do something
}

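getIndexOfNumsWithBarrier uses a CyclicBarrier to block before the method returns, which synchronizes the result. Earlier we used join for that, but here t1 runs forever (unless we stop it manually), so we need another way to synchronize, and a barrier fits well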

static CyclicBarrier barrier = new CyclicBarrier(2);

static void getIndexOfNumsWithBarrier(int start, int end, int target)
        throws InterruptedException, BrokenBarrierException {

    getIndexOfNums(start, end, target);

    barrier.await(); // released only when both threads have arrived here
    return;
}
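Because a CyclicBarrier resets itself after each release, the same instance can synchronize every round of the benchmark. A minimal sketch of that reuse, with hypothetical names BarrierRounds and run and a counter standing in for the worker's share of the search:

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicInteger;

class BarrierRounds {
    /** Runs the given number of rounds with one worker; returns how many rounds the worker completed. */
    static int run(int rounds) {
        CyclicBarrier barrier = new CyclicBarrier(2); // two parties: the caller and the worker
        AtomicInteger workerRounds = new AtomicInteger();

        Thread worker = new Thread(() -> {
            try {
                for (int i = 0; i < rounds; i++) {
                    workerRounds.incrementAndGet(); // stands in for the worker's part of the search
                    barrier.await(); // wait until the caller has finished this round too
                }
            } catch (InterruptedException | BrokenBarrierException e) {
                // interrupted or barrier broken: stop the worker
            }
        });
        worker.start();

        try {
            for (int i = 0; i < rounds; i++) {
                barrier.await(); // released only once both parties have arrived
                // from here on, the worker's work for round i is complete
            }
            worker.join();
        } catch (InterruptedException | BrokenBarrierException e) {
            worker.interrupt();
        }
        return workerRounds.get();
    }

    public static void main(String[] args) {
        System.out.println(run(3)); // prints 3
    }
}
```

Unlike join, the barrier does not require the worker to terminate, which is what allows a long-lived thread to be synchronized once per round.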

Then the other thread (the current thread) continues with its share of the search

getIndexOfNumsWithBarrier(bound / 2, bound, target);

// at this point the result is guaranteed to be available, so we can check it directly
if (target != result)
    System.out.printf("expect: '%d', but got val '%d'", target, result);

The complete benchmark code is as follows

static void benchmarkTwoThreadWithMessagePassing() throws InterruptedException, BrokenBarrierException {
    BlockingQueue<Integer> bQueue = new ArrayBlockingQueue<>(1);
    Random random = new Random(seed);

    // start thread t1, which searches [0, bound / 2) for every target it receives
    Thread t1 = new Thread(() -> {
        try {
            while (true) {
                getIndexOfNumsWithBarrier(0, bound / 2, bQueue.take());
            }
        } catch (InterruptedException | BrokenBarrierException e) {
            // interrupted or barrier broken: let the thread exit
        }
    });
    t1.start();

    long startTime = System.currentTimeMillis();
    for (int i = 0; i < testTimes; i++) {
        result = null; // reset result val

        int target = random.nextInt(bound);
        bQueue.put(target);

        // the current thread searches the other half, [bound / 2, bound)
        getIndexOfNumsWithBarrier(bound / 2, bound, target);

        if (target != result)
            System.out.printf("expect: '%d', but got val '%d'", target, result);
    }
    long endTime = System.currentTimeMillis();

    System.out.println("[DOUBLEP] Total execution time: " + (endTime - startTime) + "ms");

    t1.interrupt();
    t1.join();
}

This looks fine. A sample run produces

[SINGLE] Total execution time: 1709ms
[DOUBLE] Total execution time: 3615ms
[DOUBLEP] Total execution time: 1756ms

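(Expletive...) And this was DOUBLEP's best run across many tests; it is still slower than the single-thread version.

Multi-thread version

Could it be that there were not enough threads? With that possibility in mind, I changed the code so that the number of threads is configurable. This is the final form: you can change threadNum, testTimes, and bound to benchmark with different settings. The results, though... at least in my tests, were not great. On my machine the third approach (benchmarkMultipleThreadWithMessagePassing) seems to perform best when threadNum is 4.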
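With a configurable thread count, the array has to be divided so that no index is skipped or searched twice. The sketch below shows one way to do that, assuming half-open ranges where the last range absorbs the remainder; the names RangePartition and split are hypothetical.

```java
import java.util.Arrays;

class RangePartition {
    /** Splits [0, length) into the given number of half-open ranges; the last one absorbs the remainder. */
    static int[][] split(int length, int parts) {
        int step = length / parts;
        int[][] ranges = new int[parts][2];
        for (int p = 0; p < parts; p++) {
            ranges[p][0] = p * step; // inclusive start
            ranges[p][1] = (p == parts - 1) ? length : (p + 1) * step; // exclusive end
        }
        return ranges;
    }

    public static void main(String[] args) {
        // prints [[0, 2], [2, 4], [4, 6], [6, 10]]
        System.out.println(Arrays.deepToString(split(10, 4)));
    }
}
```

Each range ends exactly where the next one begins, so a loop of the form `for (i = start; i < end; i++)` over every range visits each index once.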

Sample results:

[SINGLE] Total execution time: 1454ms
[MULTIT] Total execution time: 4974ms
[MULTITP] Total execution time: 1587ms

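Well, it does occasionally beat the single-thread version, but on average this is about what it gets.
Also, the version that creates new threads on every iteration is clearly even slower.

The code is as follows: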

package assignment5;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class Main {
    static final int threadNum = 4; // number of threads

    static final int testTimes = 10000; // how many times we will do for benchmark
    static final long seed = 345423533; // constant seed, for test only

    static final int bound = 1000000; // length of given nums
    static int[] nums = new int[bound];

    static final int step = bound / threadNum;

    // volatile so that a result written by one thread is visible to the others
    static volatile Integer result; // to store result val

    static void initNums() {
        for (int i = 0; i < bound; i++) {
            nums[i] = i;
        }
    }

    public static void main(String[] args) throws InterruptedException, BrokenBarrierException {
        initNums();

        benchmarkSingleThread();
        benchmarkMultipleThread();
        benchmarkMultipleThreadWithMessagePassing();
    }

    static void benchmarkSingleThread() {
        Random random = new Random(seed);

        /* start benchmark */
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < testTimes; i++) {
            result = null; // reset result val

            int target = random.nextInt(bound);

            getIndexOfNums(0, bound, target);

            if (target != result)
                System.out.printf("expect: '%d', but got val '%d'", target, result);
        }
        long endTime = System.currentTimeMillis();
        /* end benchmark */

        System.out.println("[SINGLE] Total execution time: " + (endTime - startTime) + "ms");
    }

    static void benchmarkMultipleThread() throws InterruptedException, BrokenBarrierException {
        Random random = new Random(seed);

        /* start benchmark */
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < testTimes; i++) {
            result = null; // reset result val

            int target = random.nextInt(bound);

            int tmp = 0;
            for (int j = 1; j < threadNum; j++) {
                int startPoint = tmp;

                new Thread(() -> {
                    try {
                        getIndexOfNumsWithBarrier(startPoint, startPoint + step, target);
                    } catch (InterruptedException | BrokenBarrierException e) {
                        e.printStackTrace();
                    }
                }).start();

                tmp += step;
            }

            getIndexOfNumsWithBarrier(tmp, bound, target);

            if (target != result)
                System.out.printf("expect: '%d', but got val '%d'", target, result);
        }
        long endTime = System.currentTimeMillis();
        /* end benchmark */

        System.out.println("[MULTIT] Total execution time: " + (endTime - startTime) + "ms");
    }

    static void benchmarkMultipleThreadWithMessagePassing() throws InterruptedException, BrokenBarrierException {
        Random random = new Random(seed);

        List<Thread> threads = new ArrayList<>();
        List<BlockingQueue<Integer>> queues = new ArrayList<>();

        // create consumer threads, each with its own one-slot queue and its own range
        int tmp = 0;
        for (int j = 1; j < threadNum; j++) {
            queues.add(new ArrayBlockingQueue<>(1));

            int startPoint = tmp;
            int qIndex = j - 1;
            threads.add(new Thread(() -> {
                try {
                    while (true) {
                        getIndexOfNumsWithBarrier(startPoint, startPoint + step, queues.get(qIndex).take());
                    }
                } catch (InterruptedException | BrokenBarrierException e) {
                    // interrupted or barrier broken: let the thread exit
                }
            }));

            tmp += step;
        }

        // start all threads
        for (Thread thread : threads)
            thread.start();

        /* start benchmark */
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < testTimes; i++) {
            result = null; // reset result val

            int target = random.nextInt(bound);
            for (BlockingQueue<Integer> bQueue : queues) {
                bQueue.put(target);
            }

            // the current thread searches the remaining range [tmp, bound)
            getIndexOfNumsWithBarrier(tmp, bound, target);

            if (target != result)
                System.out.printf("expect: '%d', but got val '%d'", target, result);
        }
        long endTime = System.currentTimeMillis();
        /* end benchmark */

        System.out.println("[MULTITP] Total execution time: " + (endTime - startTime) + "ms");

        for (Thread thread : threads)
            thread.interrupt();
    }

    /**
     * search nums[start, end) for target, giving up as soon as result is non-null
     *
     * @param start begin bound (inclusive)
     * @param end end bound (exclusive)
     * @param target target val
     */
    static void getIndexOfNums(int start, int end, int target) {
        for (; start < end; start++) {
            if (result != null)
                return;

            if (nums[start] == target) {
                result = start;
                return;
            }
        }
    }

    static CyclicBarrier barrier = new CyclicBarrier(threadNum);

    /**
     * like getIndexOfNums, but all threads return together: each one waits at
     * the barrier until every thread has finished its search
     *
     * @param start begin bound (inclusive)
     * @param end end bound (exclusive)
     * @param target target val
     * @throws InterruptedException if the thread is interrupted while waiting
     * @throws BrokenBarrierException if the barrier is broken while waiting
     */
    static void getIndexOfNumsWithBarrier(int start, int end, int target)
            throws InterruptedException, BrokenBarrierException {

        getIndexOfNums(start, end, target);

        barrier.await();
        return;
    }
}

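Conclusion

Use multithreading with caution

When using multiple threads, take into account the resources consumed by creating threads and by context switches between them. Sometimes, less is more.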