用 Spark 计算 TF-IDF 的几种方法

TF-IDF相关概念请移步百度百科

下面记录一下我在做关键词抽取时计算 TF-IDF 的经历（我使用 Spark 的经验不多）。
下面四种方法我都试过，最后选了第四种方案：在 500 万篇新闻上计算 2-gram 的 IDF，耗时约 5 个小时。虽然用时很长，但最终能够跑通。

1. 基于 ml.feature.HashingTF

这个方法几乎可以直接照搬官网的 example，没什么需要改的。**但是** HashingTF 是通过哈希函数把 token 映射到 index 的，这个映射不可逆，无法根据 index 反查 token。所以我们只能得到一篇文档的 TF-IDF 向量，而无法知道向量中每个 TF-IDF 值对应哪个词。这种方法适合用 TF-IDF 构建文档粒度的模型，比如用 TF-IDF 计算文本相似度。代码如下：

package Main;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import java.util.regex.Pattern;

public class HashTF implements Serializable {

    /** Matches strings made up only of CJK ideographs in the range U+4E00..U+9FA5. */
    private static final Pattern CHINESE = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Loads documents from parquet, splits each text into Chinese 1..2-grams and prints the
     * document-level TF-IDF vectors produced by HashingTF + IDF.
     *
     * @param args args[0] is the parquet path; the data must contain "docId" and "text" columns
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");

            // 2. Split every text into n-grams via a UDF.
            sqlContext.udf().register("tmp", new UDF1<String, List<String>>() {
                @Override
                public List<String> call(String s) throws Exception {
                    return tokenizer(s, 2);
                }
            }, DataTypes.createArrayType(DataTypes.StringType));

            // DataFrames are immutable: withColumn returns a NEW frame. The result must be kept,
            // otherwise HashingTF below cannot find the "ngramToken" column.
            data = data.withColumn("ngramToken",
                    functions.callUDF("tmp", functions.col("text")));

            // With only 100 buckets almost every 2-gram would collide; use Spark's default (2^18).
            HashingTF hashingTF = new HashingTF()
                    .setInputCol("ngramToken")
                    .setOutputCol("tfFeature")
                    .setNumFeatures(1 << 18);

            DataFrame tfFeature = hashingTF.transform(data);

            IDF idf = new IDF()
                    .setInputCol("tfFeature")
                    .setOutputCol("features");

            IDFModel idfModel = idf.fit(tfFeature);
            // Compute TF-IDF.
            DataFrame rst = idfModel.transform(tfFeature);

            rst.show(false);
        } finally {
            // Release the context even if reading or fitting fails.
            jsc.close();
        }
    }

    /**
     * Returns every substring of {@code sentence} of length 1..{@code maxLen} that consists only of
     * Chinese characters, ordered by start position and then by length. Duplicates are kept so the
     * list carries term frequency.
     *
     * @param sentence input text; null or empty yields an empty list
     * @param maxLen   maximum n-gram length
     * @return the extracted n-grams (never null)
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int length = sentence.length();
        for (int i = 0; i < length; ++i) {
            // Inclusive upper bound: the previous `j < min(length - i, maxLen + 1)` silently dropped
            // every n-gram ending at the last character of the sentence.
            int longest = Math.min(length - i, maxLen);
            for (int j = 1; j <= longest; ++j) {
                String token = sentence.substring(i, i + j);
                if (CHINESE.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }
}

这个方法其实挺好的，能承受大数据集，但是取不出每个 index 对应的 token。于是我改用 Spark 中另一种建立 index 的方法：CountVectorizer。

2. 基于CountVectorizer

基于这种方法可以取到每个 index 对应的 token，但存在一个大问题：CountVectorizer 会为全部 n-gram 建立词表，非常消耗内存。我试过几次，都出现了 OOM。

package Main;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.ml.feature.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

public class HashTF implements Serializable {

    /** Matches strings made up only of CJK ideographs in the range U+4E00..U+9FA5. */
    private static final Pattern CHINESE = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Loads documents from parquet, builds a CountVectorizer vocabulary over Chinese 1..2-grams,
     * computes TF-IDF and prints one (token, TF-IDF) row per token occurrence in each document.
     *
     * @param args args[0] is the parquet path; the data must contain "docId" and "text" columns
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");

            // 2. Split every text into n-grams via a UDF.
            sqlContext.udf().register("tmp", new UDF1<String, List<String>>() {
                @Override
                public List<String> call(String s) throws Exception {
                    return tokenizer(s, 2);
                }
            }, DataTypes.createArrayType(DataTypes.StringType));

            // withColumn returns a NEW DataFrame; keep it so "ngramToken" actually exists.
            data = data.withColumn("ngramToken",
                    functions.callUDF("tmp", functions.col("text")));

            // minDF(2): tokens occurring in fewer than 2 documents are left out of the vocabulary.
            CountVectorizer countVectorizer = new CountVectorizer()
                    .setInputCol("ngramToken")
                    .setOutputCol("tfFeature")
                    .setMinDF(2);

            CountVectorizerModel model = countVectorizer.fit(data);
            DataFrame indexFeature = model.transform(data);

            IDF idf = new IDF()
                    .setInputCol("tfFeature")
                    .setOutputCol("features");

            IDFModel idfModel = idf.fit(indexFeature);
            // Compute TF-IDF.
            DataFrame rst = idfModel.transform(indexFeature);

            rst.show(false);

            // vocab[i] is the token mapped to vector index i.
            final String[] vocab = model.vocabulary();
            // Expand each document into one row per token it contains, carrying the TF-IDF value.
            JavaRDD<Row> tFIdfDF = rst.select("tfFeature", "features")
                    .toJavaRDD()
                    .flatMap(new FlatMapFunction<Row, Row>() {
                        @Override
                        public Iterable<Row> call(Row row) throws Exception {
                            Vector tf = row.getAs(0);
                            Vector tfidf = row.getAs(1);
                            Map<Integer, Double> tfMap = parserVector(tf.toJson());
                            Map<Integer, Double> tfIdfMap = parserVector(tfidf.toJson());

                            List<Row> tokenRows = new ArrayList<>(tfMap.size());
                            // Walk every index present in the TF vector and look up its token.
                            for (Integer idx : tfMap.keySet()) {
                                Double weight = tfIdfMap.get(idx);
                                // An index absent from a sparse vector means 0; the schema
                                // below is non-nullable, so never emit null.
                                if (weight == null) {
                                    weight = 0.0;
                                }
                                tokenRows.add(RowFactory.create(vocab[idx], weight));
                            }
                            return tokenRows;
                        }
                    });

            // NOTE(review): the "idf" column actually holds the TF-IDF value; name kept for
            // compatibility with existing consumers.
            StructType schema = new StructType(new StructField[]{
                    new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                    new StructField("idf", DataTypes.DoubleType, false, Metadata.empty())
            });
            DataFrame df = sqlContext.createDataFrame(tFIdfDF, schema);
            df.show();
        } finally {
            jsc.close();
        }
    }

    /**
     * Returns every substring of {@code sentence} of length 1..{@code maxLen} that consists only of
     * Chinese characters, ordered by start position and then by length. Duplicates are kept so the
     * list carries term frequency.
     *
     * @param sentence input text; null or empty yields an empty list
     * @param maxLen   maximum n-gram length
     * @return the extracted n-grams (never null)
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int length = sentence.length();
        for (int i = 0; i < length; ++i) {
            // Inclusive upper bound so n-grams ending at the last character are not dropped.
            int longest = Math.min(length - i, maxLen);
            for (int j = 1; j <= longest; ++j) {
                String token = sentence.substring(i, i + j);
                if (CHINESE.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }

    /**
     * Parses the JSON produced by {@code Vector.toJson()} into an index -> value map.
     *
     * @param json JSON text of a sparse ("indices" + "values") or dense ("values" only) vector
     * @return map from vector index to stored value
     * @throws JSONException if the text is not valid vector JSON
     */
    public static Map<Integer, Double> parserVector(String json) throws JSONException {
        Map<Integer, Double> value = new HashMap<>();
        JSONObject object = new JSONObject(json);
        JSONArray valueObj = object.getJSONArray("values");
        // Sparse vectors carry explicit "indices"; dense vectors do not, so fall back to positions.
        JSONArray indexObj = object.optJSONArray("indices");
        for (int i = 0; i < valueObj.length(); ++i) {
            int idx = indexObj == null ? i : indexObj.getInt(i);
            value.put(idx, valueObj.getDouble(i));
        }
        return value;
    }
}

修改点不多，就是把 HashingTF 换成了 CountVectorizer。优化时可以用 setVocabSize() 限制词表大小，或调大 setMinDF() 过滤掉低频 n-gram，以减少内存占用。

3. 用 SQLContext 注册临时表（代码中用的是 registerDataFrameAsTable()），再用 group by 查询

这个方法其实挺好的，但是……我这里老是 OOM，因为注册的表（每个 token–docId 对占一行）数据量太大了。
核心代码:

package Main;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.ml.feature.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import scala.Tuple2;

import java.io.Serializable;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

public class HashTF implements Serializable {

    /** Matches strings made up only of CJK ideographs in the range U+4E00..U+9FA5. */
    private static final Pattern CHINESE = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Loads documents from parquet, explodes them into (token, docId) rows, registers the rows as
     * a temporary table and prints the document frequency of every token via SQL.
     *
     * @param args args[0] is the parquet path; the data must contain "docId" and "text" columns
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");
            // 2. One (token, docId) row per n-gram occurrence.
            JavaRDD<Row> docTokenPair = data.toJavaRDD()
                    .flatMapToPair(new PairFlatMapFunction<Row, String, String>() {
                        @Override
                        public Iterable<Tuple2<String, String>> call(Row row) throws Exception {
                            String docId = row.getString(0);
                            String text = row.getString(1);
                            List<Tuple2<String, String>> all = new ArrayList<>();
                            for (String token : tokenizer(text, 2)) {
                                all.add(new Tuple2<>(token, docId));
                            }
                            return all;
                        }
                    })
                    .map(new Function<Tuple2<String, String>, Row>() {
                        @Override
                        public Row call(Tuple2<String, String> pair) throws Exception {
                            return RowFactory.create(pair._1, pair._2);
                        }
                    });

            StructType schema = new StructType(new StructField[]{
                    new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                    new StructField("docId", DataTypes.StringType, false, Metadata.empty())
            });
            DataFrame df = sqlContext.createDataFrame(docTokenPair, schema);
            // "table" is an SQL keyword (TABLE); use a plain identifier for the temp table.
            sqlContext.registerDataFrameAsTable(df, "token_doc");
            // tokenizer() keeps repeated tokens, so count(*) would be the corpus term frequency.
            // IDF needs the number of DISTINCT documents containing each token.
            df = sqlContext.sql(
                    "select token, count(distinct docId) as df from token_doc group by token");
            df.show();
        } finally {
            jsc.close();
        }
    }

    /**
     * Returns every substring of {@code sentence} of length 1..{@code maxLen} that consists only of
     * Chinese characters, ordered by start position and then by length. Duplicates are kept.
     *
     * @param sentence input text; null or empty yields an empty list
     * @param maxLen   maximum n-gram length
     * @return the extracted n-grams (never null)
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int length = sentence.length();
        for (int i = 0; i < length; ++i) {
            // Inclusive upper bound so n-grams ending at the last character are not dropped.
            int longest = Math.min(length - i, maxLen);
            for (int j = 1; j <= longest; ++j) {
                String token = sentence.substring(i, i + j);
                if (CHINESE.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }
}

我觉得这个方法挺好的，优化一下内存存储应该很简单。下面使用 combineByKey 的方法，思路其实差不多。

4. combineByKey()

直接出代码:

package Main;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import utils.IOUtils;

import java.io.Serializable;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Computes the document frequency of character n-grams (the input to IDF) without MLlib,
 * using combineByKey over (token, docId) pairs, and stores the result in PostgreSQL.
 */
public class TFIDF implements Serializable{

    /** Tokens must occur in at least this many distinct documents to be kept. */
    private static final int MIN_DOC_FREQ = 2;

    /** Matches strings made up only of CJK ideographs in the range U+4E00..U+9FA5. */
    private static final Pattern CHINESE = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Stop-word broadcast, set by {@link #main}. Only reliable on the driver: executor JVMs do not
     * see driver-side static assignments, so Spark closures must capture the handle explicitly.
     */
    private static volatile Broadcast<Set<String>> stopWords = null;

    /**
     * Adds every length-{@code n} substring of {@code sentence} that passes {@link #isRemove}
     * to {@code ndata}; a sentence no longer than {@code n} is considered as a whole.
     * Uses the static stop-word broadcast if one has been set.
     *
     * @param sentence text to split; null is ignored
     * @param n        n-gram length; values {@code <= 0} produce nothing
     * @param ndata    output set receiving the n-grams
     */
    public static void ngram(String sentence, int n, Set<String> ndata) {
        ngram(sentence, n, ndata, stopWords == null ? null : stopWords.value());
    }

    /** Same as {@link #ngram(String, int, Set)} but with an explicit stop-word set (may be null). */
    private static void ngram(String sentence, int n, Set<String> ndata, Set<String> stopWordSet) {
        if (sentence == null || n <= 0) {
            return;
        }
        int length = sentence.length();
        if (length <= n) {
            // Short text is kept whole, but must pass the same filter as regular n-grams;
            // previously it was added unfiltered (punctuation, stop words, empty strings).
            if (!isRemove(sentence, stopWordSet)) {
                ndata.add(sentence);
            }
            return;
        }
        for (int start = 0; start + n <= length; start++) {
            String token = sentence.substring(start, start + n);
            if (!isRemove(token, stopWordSet)) {
                ndata.add(token);
            }
        }
    }

    /**
     * Decides whether a token should be discarded, using the static stop-word broadcast if set.
     *
     * @param token candidate token
     * @return true if the token is null/empty, contains a non-Chinese character, or contains a
     *         character that is itself a stop word
     */
    public static boolean isRemove(String token) {
        return isRemove(token, stopWords == null ? null : stopWords.value());
    }

    /** Same as {@link #isRemove(String)} but with an explicit stop-word set (may be null). */
    private static boolean isRemove(String token, Set<String> stopWordSet) {
        if (token == null || token.isEmpty()) {
            return true;
        }
        if (!CHINESE.matcher(token).matches()) {
            return true;
        }
        if (stopWordSet != null) {
            // Stop words are matched character by character.
            for (char c : token.toCharArray()) {
                if (stopWordSet.contains(String.valueOf(c))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Computes, for every n-gram in the "title" and "text" columns, the number of distinct
     * documents ("docId") containing it, keeping only tokens with at least
     * {@link #MIN_DOC_FREQ} documents.
     *
     * @param sqlContext context used to build the result DataFrame
     * @param dataDF     input with string columns "docId", "title", "text"
     * @param n          n-gram length
     * @return DataFrame with columns "token" (string) and "fequence" (int document frequency)
     */
    public static DataFrame getTokenFeq(SQLContext sqlContext, DataFrame dataDF, final int n) {
        // Capture the broadcast handle in a final local so it is serialised with the closure;
        // the static field itself is null inside executor JVMs.
        final Broadcast<Set<String>> stopWordsBc = stopWords;

        Function<String, Set<String>> createCombiner = new Function<String, Set<String>>() {
            @Override
            public Set<String> call(String s) throws Exception {
                Set<String> t = new HashSet<String>();
                t.add(s);
                return t;
            }
        };

        // combineByKey lets mergeValue/mergeCombiners modify and return their first argument.
        // Copying the set on every record (as before) made merging O(k^2) per key.
        Function2<Set<String>, String, Set<String>> mergeValue = new Function2<Set<String>, String, Set<String>>() {
            @Override
            public Set<String> call(Set<String> strings, String s) throws Exception {
                strings.add(s);
                return strings;
            }
        };

        Function2<Set<String>, Set<String>, Set<String>> mergeCombiner = new Function2<Set<String>, Set<String>, Set<String>>() {
            @Override
            public Set<String> call(Set<String> strings, Set<String> strings2) throws Exception {
                strings.addAll(strings2);
                return strings;
            }
        };

        JavaRDD<Row> tokens = dataDF.select("docId", "title", "text").toJavaRDD()
                .flatMapToPair(new PairFlatMapFunction<Row, String, String>() {
                    @Override
                    public Iterable<Tuple2<String, String>> call(Row row) throws Exception {
                        String id = row.getString(0);
                        String title = row.getString(1);
                        String text = row.getString(2);
                        Set<String> stopWordSet = stopWordsBc == null ? null : stopWordsBc.value();
                        // A Set, so each token is emitted at most once per document.
                        Set<String> docTokens = new HashSet<>();
                        if (title != null && !title.isEmpty()) {
                            ngram(title, n, docTokens, stopWordSet);
                        }
                        if (text != null && !text.isEmpty()) {
                            ngram(text, n, docTokens, stopWordSet);
                        }
                        // Empty documents simply contribute no pairs.
                        List<Tuple2<String, String>> pairs = new ArrayList<>(docTokens.size());
                        for (String token : docTokens) {
                            pairs.add(new Tuple2<>(token, id));
                        }
                        return pairs;
                    }
                })
                // Group the distinct docIds of every token.
                .combineByKey(createCombiner, mergeValue, mergeCombiner, 600)
                // Drop tokens whose document frequency is below MIN_DOC_FREQ (the old `> 2`
                // contradicted the stated intent of removing frequencies below 2).
                .filter(new Function<Tuple2<String, Set<String>>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<String, Set<String>> tokenDocs) throws Exception {
                        return tokenDocs._2.size() >= MIN_DOC_FREQ;
                    }
                })
                .map(new Function<Tuple2<String, Set<String>>, Row>() {
                    @Override
                    public Row call(Tuple2<String, Set<String>> tokenDocs) throws Exception {
                        return RowFactory.create(tokenDocs._1, tokenDocs._2.size());
                    }
                });

        StructType structType = new StructType(new StructField[]{
                new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                new StructField("fequence", DataTypes.IntegerType, false, Metadata.empty())
        });
        return sqlContext.createDataFrame(tokens, structType);
    }

    /**
     * Entry point: args[0] is the HDFS parquet path, args[1] is n of the n-gram. Computes token
     * document frequencies and saves them to the "idf_feq" table via IOUtils.
     */
    public static void main(String[] args) {
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);

        // Validate arguments before starting Spark so a bad invocation does not leak a context.
        if (args.length != 2) {
            logger.error("usage: TFIDF <hdfs path> <n>");
            return;
        }
        String path = args[0];
        int n;
        try {
            n = Integer.parseInt(args[1]);
        } catch (NumberFormatException e) {
            logger.error("{n}-gram input format is error" + args[1], e);
            return;
        }
        if (n <= 0) {
            logger.error("{n}-gram input format is error" + args[1]);
            return;
        }

        // For local debugging add .setMaster("local[*]").
        SparkConf conf = new SparkConf().setAppName("LoadData");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);
            // Load the stop-word dictionary (one word per line) and broadcast it.
            Set<String> stopWordsDF = new HashSet<String>(sqlContext.read()
                    .text(IOUtils.stopWordPath).toJavaRDD()
                    .map(new Function<Row, String>() {
                        @Override
                        public String call(Row row) throws Exception {
                            return row.getString(0);
                        }
                    }).collect());

            stopWords = jsc.broadcast(stopWordsDF);

            logger.warn("load hdfs path is: " + path);
            DataFrame dataDF = sqlContext
                    .read()
                    .parquet(path)
                    .select("docId", "title", "text");
            DataFrame idfDF = getTokenFeq(sqlContext, dataDF, n);
            IOUtils.saveDataToPSQL(idfDF, "idf_feq");

            sqlContext.clearCache();
        } finally {
            jsc.close();
        }
    }
}

不用 groupByKey，是因为 n-gram 的数量太大、太分散；combineByKey 会先在 map 端做局部合并，性能比 groupByKey 好。

总结

我觉得优化方案应该还有很多,求大神指导。