模拟实现Java线程
先看整体的设计思路,由于Java被定义为一种跨平台语言,而且跨平台是通过JVM层实现的,所以很多概念都通过JVM层进行抽象,包括Java语言的线程,它需要JVM来提供具体实现的。整体的设计思路如下图,在Java层我们用Java语言定义一个Thread类,该类表示Java层的线程。JVM层则需要定义JavaThread类和OSThread类,这两个类都通过C++进行定义,其中JavaThread类用于表示Java层的线程,而OSThread类则用于对不同操作系统底层的抽象。这里我们将基于linux操作系统来模拟实现,linux系统为我们提供了pthread库来操作线程。
最开始我们要在Java层定义一个com.seaboat.Thread类,这个类就是模拟的线程类,对应着Java官方为我们提供的java.lang.Thread类。JVM与Java进行了约定,Thread类中的run方法定义线程要执行的任务,而start方法则是启动线程。run方法由用户在自定义线程类中进行重写,当用户调用start方法时则会调用本地的start0方法,该方法的实现在com_seaboat_Thread.so库中。此外我们还创建了一个AtomicInteger对象用于生成线程ID,在构造函数中会让线程ID不断加一。
1. package com.seaboat;
2.
3. public class Thread {
4. static {
5. ai = new AtomicInteger();
6. System.load("/root/seaboat/native_test/com_seaboat_Thread.so");
7. }
8. static AtomicInteger ai;
9. public int threadId;
10.
11. public Thread() {
12. this.threadId = ai.incrementAndGet();
13. }
14.
15. public void run() {
16. }
17.
18. public void start() {
19. start0();
20. }
21.
22. private native void start0();
23.
24. }
为了实现在com_seaboat_Thread.so库,我们要先定义一个头文件,可以通过 javac -h ./ com/seaboat/Thread.java 命令让工具帮我们生成。它会帮我们生成符合JNI调用的本地方法名,方法的命名也是JVM已经规定好了的,这里的start0方法将对应Java_com_seaboat_Thread_start0方法。其中第一个参数JNIEnv* 指针表示Java运行环境,通过它我们就能够调用Java语言定义的方法,而第二个参数jobject表示调用该本地方法的Java对象。
1. #include
2. /* Header for class com_seaboat_Thread */
3.
4. #ifndef _Included_com_seaboat_Thread
5. #define _Included_com_seaboat_Thread
6. #ifdef __cplusplus
7. extern "C" {
8. #endif
9.
10. /*
11. * Class: com_seaboat_Thread
12. * Method: start0
13. * Signature: ()V
14. */
15. JNIEXPORT void JNICALL Java_com_seaboat_Thread_start0(JNIEnv*, jobject);
16.
17. #ifdef __cplusplus
18. }
19. #endif
20. #endif
接着我们会在JVM层定义一个JavaThread类,该类用来封装Java层的线程对象。在该类中主要要实现三个函数:构造函数、析构函数和执行run方法函数。构造函数主要通过GetJavaVM函数获取JavaVM指针,并且将Java层Thread对象封装成全局引用。析构函数则是将当前线程从JVM中分离。execRunMethod函数会去调用Java层Thread对象的run方法,先通过AttachCurrentThread函数将当前线程附加到JVM中从而获得Java运行环境指针,然后通过GetObjectClass函数获取到Java层Thread对象,再通过GetMethodID函数获取该对象的run方法对象的ID,进而通过CallVoidMethod函数执行该方法,执行完后删除Thread对象的全局引用。该类中的JavaVM*和jobject都属于JNI的内容,这里不展开讲,我们需要关注的是Parker和interruptState两个属性,每个线程都有自己的Parker对象,通过它可以对线程进行阻塞和唤醒操作。而interruptState则表示该线程的中断状态,它是一个布尔变量。
1. class JavaThread {
2. public:
3. JavaVM *jvm;
4. jobject jThreadObjectRef;
5. Parker parker;
6. bool interruptState = false;
7. ~JavaThread();
8. JavaThread(JNIEnv *env, jobject jThreadObject);
9. void execRunMethod();
10. };
11.
12. JavaThread::JavaThread(JNIEnv *env, jobject jThreadObject) {
13. env->GetJavaVM(&(this->jvm));
14. this->jThreadObjectRef = env->NewGlobalRef(jThreadObject);
15. }
16.
17. JavaThread::~JavaThread() {
18. jvm->DetachCurrentThread();
19. }
20.
21. void JavaThread::execRunMethod() {
22. JNIEnv *env;
23. if (jvm->AttachCurrentThread((void**) &env, NULL) != 0) {
24. std::cout << "Failed to attach" << std::endl;
25. }
26. jclass cls = env->GetObjectClass(jThreadObjectRef);
27. jmethodID runId = env->GetMethodID(cls, "run", "()V");
28. if (runId != nullptr) {
29. env->CallVoidMethod(jThreadObjectRef, runId);
30. } else {
31. cout << "No run method found in the Thread object!!" << endl;
32. }
33. env->DeleteGlobalRef(jThreadObjectRef);
34. }
Parker为我们提供的函数为park和unpark,分别表示阻塞和唤醒操作。该类其实就是封装了pthread提供的线程阻塞唤醒相关函数,park函数间接调用pthread_cond_timedwait函数,通过它可以让线程阻塞指定的时间,如果超时了则会自动唤醒,而且操作前必须通过pthread_mutex_lock获取锁,操作完成则通过pthread_mutex_unlock释放锁。Unpark函数间接调用了pthread_cond_signal函数,同样也是需要先获取锁,操作完后释放锁。
1. class Parker {
2. private:
3. pthread_mutex_t _mutex;
4. pthread_cond_t _cond;
5. public:
6. Parker();
7. void park(long millis);
8. void unpark();
9. };
1. Parker::Parker() {
2. pthread_mutex_init(&_mutex, NULL);
3. pthread_cond_init(&_cond, NULL);
4. }
5.
6. void Parker::park(long millis) {
7. struct timespec ts;
8. struct timeval now;
9. int status = pthread_mutex_lock(&_mutex);
10. gettimeofday(&now, NULL);
11. ts.tv_sec = time(NULL) + millis / 1000;
12. ts.tv_nsec = now.tv_usec * 1000 + 1000 * 1000 * (millis % 1000);
13. ts.tv_sec += ts.tv_nsec / (1000 * 1000 * 1000);
14. ts.tv_nsec %= (1000 * 1000 * 1000);
15. status = pthread_cond_timedwait(&_cond, &_mutex, &ts);
16. if (status == 0) {
17.
18. } else if (status == ETIMEDOUT) {
19. // TODO: Time out.
20. }
21. status = pthread_mutex_unlock(&_mutex);
22. }
23.
24. void Parker::unpark() {
25. int status = pthread_mutex_lock(&_mutex);
26. status = pthread_cond_signal(&_cond);
27. status = pthread_mutex_unlock(&_mutex);
28. }
我们还要在JVM层定义一个OSThread类,该类负责封装linux系统pthread库提供的线程。可以看到它关联了JavaThread对象,构造函数中会传入JavaThread指针。call_os_thread函数用于通过pthread库来创建操作系统线程,核心就是调用pthread_create函数,该函数主要关注第三和第四个参数,分别表示线程执行的函数和对应的参数,也就是任务的定义。线程执行的具体任务由OSThread::thread_entry_function函数定义,该函数负责调用JavaThread对象的execRunMethod函数,也就是Java层的Thread对象的run方法。
1. class OSThread {
2. private:
3. JavaThread *javaThread;
4. public:
5. OSThread(JavaThread *javaThread);
6. void call_os_thread();
7. static void* thread_entry_function(void *args);
8. };
9.
10. OSThread::OSThread(JavaThread *javaThread) {
11. this->javaThread = javaThread;
12. }
13.
14. void OSThread::call_os_thread() {
15. pthread_t tid;
16. pthread_attr_t Attr;
17. pthread_attr_init(&Attr);
18. pthread_attr_setdetachstate(&Attr, PTHREAD_CREATE_DETACHED);
19. std::cout << "creating linux thread!" << endl;
20. if (pthread_create(&tid, &Attr, &OSThread::thread_entry_function,
21. this->javaThread) != 0) {
22. std::cout << "Error creating thread" << endl;
23. return;
24. }
25. std::cout << "Started a linux thread! tid=" << tid << endl;
26. pthread_attr_destroy(&Attr);
27. }
28.
29. void* OSThread::thread_entry_function(void *args) {
30. JavaThread *javaThread = (JavaThread*) args;
31. javaThread->execRunMethod();
32. delete javaThread;
33. return NULL;
34. }
以上已经定义好了Java层的Thread线程类以及JVM层的相关类,最后我们来定义start0本地方法。在方法外先创建一个map变量,它用于存放JVM层的所有线程对象,从而方便在不同方法内部根据线程ID来获取JavaThread*。这里简单使用map结构管理线程对象,而且我们不考虑线程安全问题。在方法内先创建JavaThread对象,然后获取Java层的线程ID值,对应Thread类的threadId属性,成功获取后以threadId为键而javaThread为值增加到map结构中。接着继续创建OSThread对象,最后调用OSThread对象的call_os_thread函数。
1. map threads;
2.
3. JNIEXPORT void JNICALL Java_com_seaboat_Thread_start0(JNIEnv *env,
4. jobject jThreadObject) {
5. std::cout << "creating a JavaThread object!" << endl;
6. JavaThread *javaThread = new JavaThread(env, jThreadObject);
7.
8. //将新线程保存到map中,方便后面根据threadId来获取javaThread
9. jclass cls = env->GetObjectClass(javaThread->jThreadObjectRef);
10. jfieldID fID = env->GetFieldID(cls, "threadId", "I");
11. jint threadId = env->GetIntField(javaThread->jThreadObjectRef, fID);
12. threads.insert(map::value_type(threadId, javaThread));
13.
14. std::cout << "threadId = " << threadId << endl;
15. std::cout << "creating a OSThread object!" << endl;
16. OSThread osThread(javaThread);
17. osThread.call_os_thread();
18. return;
19. }
所有代码编写好后要对C++代码进行编译,我们使用g++进行编译,具体命令为g++ -fPIC -c -std=c++0x com_seaboat_Thread.cpp -I /usr/java/jdk1.8.0_111/include/ -I /usr/java/jdk1.8.0_111/include/linux/,它会编译生成com_seaboat_Thread.o目标文件。接着继续通过g++ -shared com_seaboat_Thread.o -o com_seaboat_Thread.so命令生成com_seaboat_Thread.so动态库文件。
最后在Java层创建一个线程测试类,测试代码如下,然后通过java com.seaboat.Thread命令执行该类。
1. public class ThreadTest {
2.
3. public static void main(String[] args) {
4. new MyThread().start();
5. Thread.sleep(100, 0);
6. }
7.
8. static class MyThread extends Thread {
9. public void run() {
10. System.out.println("simulates Java thread!");
11. System.out.println("thread id is " + this.threadId);
12. }
13. }
14. }
最终输出如下。
1. creating a JavaThread object!
2. threadId = 1
3. creating a OSThread object!
4. creating linux thread!
5. Started a linux thread! tid=140586779625216
6. calling sleep operation!
7. simulates Java thread!
8. thread id is 1