思路:

类加载器可以按资源路径定位包所在的位置(ClassLoader.getResources),所以可以先通过类加载器拿到包对应的目录或 jar 包,再遍历其中的 .class 文件并逐个加载,得到包下的所有类。

主要区分两种情况:

1) .class 文件位于 jar 包中;

2) .class 文件位于文件系统的目录中。

完整代码:

import java.io.File;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Discovers the classes that belong to a package by asking the class loader for the package's
 * resource locations and scanning them.
 *
 * <p>Two kinds of locations are supported: directories on the file system ({@code file:} URLs)
 * and entries inside jar files ({@code jar:} URLs); other protocols are ignored. Sub-packages are
 * included in both cases.
 *
 * <p>Classes are loaded with initialization enabled, so the static initializers of every
 * discovered class run. Classes that cannot be loaded are reported to stderr and skipped.
 */
public class ClassUtil {

    /** File-name suffix of compiled class files. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Returns all classes in the given package and its sub-packages.
     *
     * @param packageName dot-separated package name, e.g. {@code "org.junit.internal.runners"};
     *        an empty string denotes the default package
     * @return a new mutable list of the classes found; empty if nothing was found or the
     *         resources could not be enumerated (the {@code IOException} is printed)
     */
    public static List<Class<?>> getClasses(String packageName) {
        List<Class<?>> classes = new ArrayList<Class<?>>();
        // Resource paths use '/' as separator, e.g. "org/junit/internal/runners".
        String packageDirName = packageName.replace('.', '/');
        try {
            // A package may be split across several classpath entries (directories and jars).
            Enumeration<URL> dirs = classLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                URL url = dirs.nextElement();
                String protocol = url.getProtocol();
                if ("file".equals(protocol)) {
                    // URLDecoder does form decoding and would turn a literal '+' in the path
                    // into a space; escape '+' first so such paths survive decoding.
                    String filePath = URLDecoder.decode(url.getFile().replace("+", "%2B"), "UTF-8");
                    classes.addAll(findClassByDirectory(packageName, filePath));
                } else if ("jar".equals(protocol)) {
                    classes.addAll(findClassInJar(packageName, url));
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return classes;
    }

    /**
     * Collects the classes found under a package directory on the file system, recursing into
     * sub-directories as sub-packages.
     *
     * @param packageName dot-separated name of the package that {@code packagePath} represents
     *        ({@code ""} for the default package)
     * @param packagePath file-system path of the package directory
     * @return a new mutable list of the loaded classes; empty if the path is not a readable
     *         directory
     */
    public static List<Class<?>> findClassByDirectory(String packageName, String packagePath) {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return new ArrayList<>(0);
        }

        List<Class<?>> classes = new ArrayList<Class<?>>();
        File[] files = dir.listFiles();
        if (files == null) {
            // listFiles() returns null on an I/O error (e.g. missing read permission).
            return classes;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                // A sub-directory is a sub-package: scan it recursively.
                classes.addAll(findClassByDirectory(qualify(packageName, fileName),
                        file.getAbsolutePath()));
            } else if (fileName.endsWith(CLASS_SUFFIX)) {
                String simpleName =
                        fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
                addClass(classes, qualify(packageName, simpleName));
            }
        }
        return classes;
    }

    /**
     * Collects the classes of a package (and its sub-packages) contained in the jar that
     * {@code url} points into.
     *
     * @param packageName dot-separated package name ({@code ""} matches every entry)
     * @param url a {@code jar:} URL; only its jar file is used, the entry part is ignored
     * @return a new mutable list of the loaded classes; empty if the jar could not be read
     *         (the {@code IOException} is printed)
     */
    public static List<Class<?>> findClassInJar(String packageName, URL url) {
        List<Class<?>> classes = new ArrayList<Class<?>>();

        String packageDirName = packageName.replace('.', '/');
        // Require a trailing '/' so that "com/foo" does not also match "com/foobar/...".
        String prefix = packageDirName.isEmpty() ? "" : packageDirName + "/";
        try {
            JarURLConnection connection = (JarURLConnection) url.openConnection();
            // A cached JarFile is shared JVM-wide and must not be closed by us; with caching
            // disabled we get our own instance, which we close when done.
            connection.setUseCaches(false);
            try (JarFile jar = connection.getJarFile()) {
                Enumeration<JarEntry> entries = jar.entries();
                while (entries.hasMoreElements()) {
                    // Entries include directories and non-class files such as META-INF/*.
                    JarEntry entry = entries.nextElement();
                    if (entry.isDirectory()) {
                        continue;
                    }

                    String name = entry.getName();
                    if (name.startsWith("/")) {
                        name = name.substring(1);
                    }

                    if (name.startsWith(prefix) && name.endsWith(CLASS_SUFFIX)) {
                        // "a/b/C.class" -> "a.b.C"
                        String className = name
                                .substring(0, name.length() - CLASS_SUFFIX.length())
                                .replace('/', '.');
                        addClass(classes, className);
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return classes;
    }

    /**
     * Loads {@code className} and appends it to {@code classes}. A class that cannot be found or
     * linked (e.g. a missing dependency) is reported and skipped, so it does not abort the scan.
     */
    private static void addClass(List<Class<?>> classes, String className) {
        try {
            classes.add(Class.forName(className, true, classLoader()));
        } catch (ClassNotFoundException | LinkageError e) {
            e.printStackTrace();
        }
    }

    /** Joins a package name and a simple name, handling the default (empty) package. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + '.' + simpleName;
    }

    /**
     * Returns the loader used both to locate package resources and to load the classes, so the
     * two steps agree; falls back to this class's loader when no context loader is set.
     */
    private static ClassLoader classLoader() {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        return loader != null ? loader : ClassUtil.class.getClassLoader();
    }

    public static void main(String[] args) {
        getClasses("org.junit.internal.runners").forEach(System.out::println);
    }

}

 

运行结果:

class org.junit.internal.runners.JUnit38ClassRunner$OldTestClassAdaptingListener
class org.junit.internal.runners.JUnit4ClassRunner$1
class org.junit.internal.runners.MethodRoadie$2
class org.junit.internal.runners.TestMethod
class org.junit.internal.runners.rules.RuleMemberValidator
class org.junit.internal.runners.rules.RuleMemberValidator$MethodMustBeARule
class org.junit.internal.runners.rules.RuleMemberValidator$MemberMustBePublic
class org.junit.internal.runners.rules.ValidationError
class org.junit.internal.runners.rules.RuleMemberValidator$MemberMustBeStatic
class org.junit.internal.runners.rules.RuleMemberValidator$Builder
class org.junit.internal.runners.rules.RuleMemberValidator$MemberMustBeNonStaticOrAlsoClassRule
class org.junit.internal.runners.rules.RuleMemberValidator$FieldMustBeARule
class org.junit.internal.runners.rules.RuleMemberValidator$MethodMustBeATestRule
interface org.junit.internal.runners.rules.RuleMemberValidator$RuleValidator
class org.junit.internal.runners.rules.RuleMemberValidator$1
class org.junit.internal.runners.rules.RuleMemberValidator$DeclaringClassMustBePublic
class org.junit.internal.runners.rules.RuleMemberValidator$FieldMustBeATestRule
class org.junit.internal.runners.MethodRoadie$1
class org.junit.internal.runners.TestClass
class org.junit.internal.runners.ErrorReportingRunner
class org.junit.internal.runners.JUnit4ClassRunner
class org.junit.internal.runners.FailedBefore
class org.junit.internal.runners.statements.FailOnTimeout$1
class org.junit.internal.runners.statements.Fail
class org.junit.internal.runners.statements.FailOnTimeout
class org.junit.internal.runners.statements.RunAfters
class org.junit.internal.runners.statements.RunBefores
class org.junit.internal.runners.statements.ExpectException
class org.junit.internal.runners.statements.InvokeMethod
class org.junit.internal.runners.statements.FailOnTimeout$Builder
class org.junit.internal.runners.statements.FailOnTimeout$CallableStatement
class org.junit.internal.runners.JUnit4ClassRunner$2
class org.junit.internal.runners.MethodValidator
class org.junit.internal.runners.JUnit38ClassRunner
class org.junit.internal.runners.SuiteMethod
class org.junit.internal.runners.MethodRoadie
class org.junit.internal.runners.InitializationError
class org.junit.internal.runners.ClassRoadie
class org.junit.internal.runners.JUnit38ClassRunner$1
class org.junit.internal.runners.MethodRoadie$1$1
class org.junit.internal.runners.model.ReflectiveCallable
class org.junit.internal.runners.model.EachTestNotifier
class org.junit.internal.runners.model.MultipleFailureException