spark 计算TF-IDF的多种方法
TF-IDF相关概念请移步百度百科
下面记录自己在做关键词抽取的时候,计算TF-IDF时候的经历,使用spark不多。
下面的四种方法我都试过,最后选了第四个方案:在500W篇新闻中计算2-gram的IDF,用时5个小时。虽然时间用得很长,但是最终是可以跑起来的。
1. 基于mllib.HashingTF
这个方法几乎可以直接copy 官网中的example,没啥需要改的,**但是**HashingTF无法根据映射的index获取token。所以,我们只能得到一篇文档的TF-IDF向量,而无法得到向量中的每个TF-IDF值对应哪个词。这种方法适合使用TF-IDF计算文档级别粒度的模型,比如使用TF-IDF计算文本相似度。代码如下:
package Main;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
/**
 * Builds document-level TF-IDF vectors over Chinese character n-grams with
 * {@code HashingTF} followed by {@code IDF}.
 *
 * <p>Because {@code HashingTF} hashes tokens into buckets, the resulting vector
 * indices cannot be mapped back to the tokens that produced them.
 */
public class HashTF implements Serializable {

    /** Matches tokens that consist solely of characters in U+4E00..U+9FA5. */
    private static final Pattern CHINESE_ONLY = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the path of the parquet input that holds
     *             the {@code docId} and {@code text} columns
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("usage: HashTF <parquet-path>");
            return;
        }
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load the data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");

            // 2. Split the text into n-grams.
            sqlContext.udf().register("tmp", new UDF1<String, List<String>>() {
                @Override
                public List<String> call(String s) throws Exception {
                    return tokenizer(s, 2);
                }
            }, DataTypes.createArrayType(DataTypes.StringType));

            // DataFrames are immutable: withColumn returns a NEW frame, so the
            // result must be kept, otherwise "ngramToken" never exists downstream.
            DataFrame tokenized = data.withColumn("ngramToken",
                    functions.callUDF("tmp", functions.col("text")));

            // NOTE(review): 100 hash buckets is tiny for an n-gram vocabulary and
            // will collide heavily; kept as in the original example.
            HashingTF hashingTF = new HashingTF()
                    .setInputCol("ngramToken")
                    .setOutputCol("tfFeature")
                    .setNumFeatures(100);
            DataFrame tfFeature = hashingTF.transform(tokenized);

            IDF idf = new IDF()
                    .setInputCol("tfFeature")
                    .setOutputCol("features");
            IDFModel idfModel = idf.fit(tfFeature);

            // Compute TF-IDF.
            DataFrame rst = idfModel.transform(tfFeature);
            rst.show(false);
        } finally {
            // Release the context even when a stage fails.
            jsc.close();
        }
    }

    /**
     * Extracts every substring of length {@code 1..maxLen} that consists only of
     * Chinese characters.
     *
     * <p>Substrings that end at the last character of the sentence are included;
     * the previous loop bound silently dropped them.
     *
     * @param sentence text to split; {@code null} or empty yields an empty list
     * @param maxLen   maximum n-gram length (inclusive)
     * @return the matching n-grams in order of start position, then length
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int total = sentence.length();
        for (int start = 0; start < total; ++start) {
            // Longest n-gram that still fits between start and the end of the text.
            int longest = Math.min(total - start, maxLen);
            for (int len = 1; len <= longest; ++len) {
                String token = sentence.substring(start, start + len);
                if (CHINESE_ONLY.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }
}
这个方法其实挺好的,可以承受大数据集,但是不能取出每个index 对应的token。于是我就使用Spark中另外一种index的方法CountVectorizer
2. 基于CountVectorizer
基于这种方法可以取到每个index映射后的token,但是存在一个大问题:CountVectorizer会将全部的n-gram 都映射起来,这将非常消耗内存。我试过几次,都会出现OOM的问题。
package Main;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.ml.feature.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Computes TF-IDF over Chinese character n-grams with {@code CountVectorizer}
 * and {@code IDF}, then maps every vector index back to its token through the
 * fitted vocabulary.
 */
public class HashTF implements Serializable {

    /** Matches tokens that consist solely of characters in U+4E00..U+9FA5. */
    private static final Pattern CHINESE_ONLY = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the path of the parquet input that holds
     *             the {@code docId} and {@code text} columns
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("usage: HashTF <parquet-path>");
            return;
        }
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load the data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");

            // 2. Split the text into n-grams.
            sqlContext.udf().register("tmp", new UDF1<String, List<String>>() {
                @Override
                public List<String> call(String s) throws Exception {
                    return tokenizer(s, 2);
                }
            }, DataTypes.createArrayType(DataTypes.StringType));

            // DataFrames are immutable: withColumn returns a NEW frame, so the
            // result must be kept, otherwise "ngramToken" never exists downstream.
            DataFrame tokenized = data.withColumn("ngramToken",
                    functions.callUDF("tmp", functions.col("text")));

            CountVectorizer countVectorizer = new CountVectorizer()
                    .setInputCol("ngramToken")
                    .setOutputCol("tfFeature")
                    .setMinDF(2);
            CountVectorizerModel model = countVectorizer.fit(tokenized);
            DataFrame indexFeature = model.transform(tokenized);

            IDF idf = new IDF()
                    .setInputCol("tfFeature")
                    .setOutputCol("features");
            IDFModel idfModel = idf.fit(indexFeature);

            // Compute TF-IDF.
            DataFrame rst = idfModel.transform(indexFeature);
            rst.show(false);

            // Token for every vector index.
            final String[] vocab = model.vocabulary();

            // Emit one (token, tf-idf) row per non-zero entry of each document.
            JavaRDD<Row> tFIdfDF = rst.select("tfFeature", "features")
                    .toJavaRDD()
                    .flatMap(new FlatMapFunction<Row, Row>() {
                        @Override
                        public Iterable<Row> call(Row row) throws Exception {
                            Vector tf = row.getAs(0);
                            Vector tfidf = row.getAs(1);
                            Map<Integer, Double> tfMap = parserVector(tf.toJson());
                            Map<Integer, Double> tfIdfMap = parserVector(tfidf.toJson());
                            List<Row> rows = new ArrayList<>(tfMap.size());
                            // Walk every index present in the TF vector and look up its token.
                            for (Integer index : tfMap.keySet()) {
                                rows.add(RowFactory.create(vocab[index], tfIdfMap.get(index)));
                            }
                            return rows;
                        }
                    });

            // NOTE(review): the second column is named "idf" but carries the
            // TF-IDF value read from "features"; the name is kept for compatibility.
            StructType schema = new StructType(new StructField[]{
                    new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                    new StructField("idf", DataTypes.DoubleType, false, Metadata.empty())
            });
            DataFrame df = sqlContext.createDataFrame(tFIdfDF, schema);
            df.show();
        } finally {
            // Release the context even when a stage fails.
            jsc.close();
        }
    }

    /**
     * Extracts every substring of length {@code 1..maxLen} that consists only of
     * Chinese characters.
     *
     * <p>Substrings that end at the last character of the sentence are included;
     * the previous loop bound silently dropped them.
     *
     * @param sentence text to split; {@code null} or empty yields an empty list
     * @param maxLen   maximum n-gram length (inclusive)
     * @return the matching n-grams in order of start position, then length
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int total = sentence.length();
        for (int start = 0; start < total; ++start) {
            // Longest n-gram that still fits between start and the end of the text.
            int longest = Math.min(total - start, maxLen);
            for (int len = 1; len <= longest; ++len) {
                String token = sentence.substring(start, start + len);
                if (CHINESE_ONLY.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }

    /**
     * Parses the JSON form of a vector into an index-to-value map by pairing the
     * {@code "indices"} and {@code "values"} arrays position by position.
     *
     * @param json JSON text containing {@code "indices"} and {@code "values"} arrays
     * @return map from vector index to the value stored at that index
     * @throws JSONException if the text is not valid JSON or either array is missing
     */
    public static Map<Integer, Double> parserVector(String json) throws JSONException {
        Map<Integer, Double> value = new HashMap<>();
        JSONObject object = new JSONObject(json);
        JSONArray valueObj = object.getJSONArray("values");
        JSONArray indexObj = object.getJSONArray("indices");
        for (int i = 0; i < valueObj.length(); ++i) {
            value.put(indexObj.getInt(i), valueObj.getDouble(i));
        }
        return value;
    }
}
修改点不多,就是将HashingTF 换成了CountVectorizer, 优化的时候可以将
3. 使用registerTempTable() 注册临时表,再用group by查询
这个方法其实挺好的,但是我这里老是OOM,因为注册(register)到内存中的临时表太大了。
核心代码:
package Main;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.ml.feature.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import scala.Tuple2;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Counts n-gram occurrences by exploding every document into
 * {@code (token, docId)} rows, registering them as a temporary table and
 * aggregating with a SQL {@code GROUP BY}.
 */
public class HashTF implements Serializable {

    /** Matches tokens that consist solely of characters in U+4E00..U+9FA5. */
    private static final Pattern CHINESE_ONLY = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the path of the parquet input that holds
     *             the {@code docId} and {@code text} columns
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("usage: HashTF <parquet-path>");
            return;
        }
        SparkConf conf = new SparkConf().setAppName("LoadData")//
                .setMaster("local[*]");
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // 1. Load the data.
            DataFrame data = sqlContext.read().parquet(args[0])
                    .select("docId", "text");

            // 2. One row per (token, docId) occurrence.
            JavaRDD<Row> docTokenPair = data.toJavaRDD()
                    .flatMapToPair(new PairFlatMapFunction<Row, String, String>() {
                        @Override
                        public Iterable<Tuple2<String, String>> call(Row row) throws Exception {
                            String docId = row.getString(0);
                            String text = row.getString(1);
                            List<Tuple2<String, String>> pairs = new ArrayList<>();
                            for (String token : tokenizer(text, 2)) {
                                pairs.add(new Tuple2<>(token, docId));
                            }
                            return pairs;
                        }
                    })
                    .map(new Function<Tuple2<String, String>, Row>() {
                        @Override
                        public Row call(Tuple2<String, String> pair) throws Exception {
                            return RowFactory.create(pair._1, pair._2);
                        }
                    });

            StructType schema = new StructType(new StructField[]{
                    new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                    new StructField("docId", DataTypes.StringType, false, Metadata.empty())
            });
            DataFrame df = sqlContext.createDataFrame(docTokenPair, schema);

            // Register through the public DataFrame API, and avoid naming the
            // table "table": that is a SQL keyword and not a safe identifier.
            df.registerTempTable("doc_token");

            // NOTE(review): count(*) counts every occurrence, including repeats
            // inside one document; a document frequency would need
            // count(distinct docId).
            df = sqlContext.sql("select token,count(*) from doc_token group by token");
            df.show();
        } finally {
            // Release the context even when a stage fails.
            jsc.close();
        }
    }

    /**
     * Extracts every substring of length {@code 1..maxLen} that consists only of
     * Chinese characters.
     *
     * <p>Substrings that end at the last character of the sentence are included;
     * the previous loop bound silently dropped them.
     *
     * @param sentence text to split; {@code null} or empty yields an empty list
     * @param maxLen   maximum n-gram length (inclusive)
     * @return the matching n-grams in order of start position, then length
     */
    public static List<String> tokenizer(String sentence, int maxLen) {
        List<String> rst = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return rst;
        }
        int total = sentence.length();
        for (int start = 0; start < total; ++start) {
            // Longest n-gram that still fits between start and the end of the text.
            int longest = Math.min(total - start, maxLen);
            for (int len = 1; len <= longest; ++len) {
                String token = sentence.substring(start, start + len);
                if (CHINESE_ONLY.matcher(token).matches()) {
                    rst.add(token);
                }
            }
        }
        return rst;
    }
}
觉得这个方法挺好的,优化一下内存存储,应该是很简单的。下面使用combineByKey的方法其实想法差不多。
4. combineByKey()
直接出代码:
package Main;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import utils.IOUtils;
import java.io.Serializable;
import java.util.*;
import java.util.regex.Pattern;
/**
* IDF计算,不使用MLLIB包
*/
/**
 * Computes per-token document frequencies for fixed-length Chinese character
 * n-grams without the MLlib feature transformers, using {@code combineByKey}
 * to collect the distinct document ids of every token.
 */
public class TFIDF implements Serializable {

    /** Matches tokens that consist solely of characters in U+4E00..U+9FA5. */
    private static final Pattern CHINESE_ONLY = Pattern.compile("[\\u4e00-\\u9fa5]+");

    /** Tokens that appear in fewer documents than this are dropped. */
    private static final int MIN_DOC_FREQ = 2;

    /** Number of partitions used for the combineByKey shuffle. */
    private static final int NUM_PARTITIONS = 600;

    /**
     * Stop-word broadcast, assigned by {@link #main}. A static field is only
     * populated in the JVM that assigns it, so code that runs on executors must
     * receive the broadcast through its closure instead of reading this field.
     */
    private static volatile Broadcast<Set<String>> stopWords = null;

    /**
     * Adds the n-grams of {@code sentence} that survive {@link #isRemove} to
     * {@code ndata}, using the stop words visible in the current JVM.
     *
     * @param sentence text to split; {@code null} is ignored
     * @param n        n-gram length
     * @param ndata    output set that receives the tokens
     */
    public static void ngram(String sentence, int n, Set<String> ndata) {
        ngram(sentence, n, ndata, currentStopWords());
    }

    /**
     * Tells whether a token must be discarded, using the stop words visible in
     * the current JVM.
     *
     * @param token candidate token
     * @return {@code true} when the token is null, empty, not purely Chinese, or
     *         contains a stop-word character
     */
    public static boolean isRemove(String token) {
        return isRemove(token, currentStopWords());
    }

    /** Returns the stop-word set of this JVM, or {@code null} when none was broadcast. */
    private static Set<String> currentStopWords() {
        Broadcast<Set<String>> handle = stopWords;
        return handle == null ? null : handle.value();
    }

    /** Sliding-window n-gram extraction against an explicit stop-word set ({@code null} = none). */
    private static void ngram(String sentence, int n, Set<String> ndata, Set<String> stops) {
        if (sentence == null) {
            return;
        }
        int length = sentence.length();
        if (length <= n) {
            // A sentence no longer than n is kept as a single token, but it has
            // to pass the same filter as every other token.
            if (!isRemove(sentence, stops)) {
                ndata.add(sentence);
            }
            return;
        }
        if (n <= 0) {
            return;
        }
        for (int start = 0; start + n <= length; start++) {
            String token = sentence.substring(start, start + n);
            if (!isRemove(token, stops)) {
                ndata.add(token);
            }
        }
    }

    /** Token filter against an explicit stop-word set ({@code null} = none). */
    private static boolean isRemove(String token, Set<String> stops) {
        if (token == null || token.isEmpty()) {
            return true;
        }
        if (!CHINESE_ONLY.matcher(token).matches()) {
            return true;
        }
        if (stops == null) {
            return false;
        }
        // A token is rejected as soon as any single character is a stop word.
        for (int i = 0; i < token.length(); i++) {
            if (stops.contains(String.valueOf(token.charAt(i)))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds a frame of {@code (token, fequence)} rows where {@code fequence} is
     * the number of distinct documents whose title or text contains the token.
     * Tokens found in fewer than {@value #MIN_DOC_FREQ} documents are dropped.
     *
     * @param sqlContext context used to create the resulting frame
     * @param dataDF     input frame with {@code docId}, {@code title} and {@code text} columns
     * @param n          n-gram length
     * @return frame with columns {@code token} (string) and {@code fequence} (int)
     */
    public static DataFrame getTokenFeq(SQLContext sqlContext, DataFrame dataDF, final int n) {
        // Capture the broadcast handle on the driver so it is shipped inside the
        // closure; executors would otherwise see the static field as null and
        // silently skip stop-word filtering.
        final Broadcast<Set<String>> stopWordsBc = stopWords;

        Function<String, Set<String>> createCombiner = new Function<String, Set<String>>() {
            @Override
            public Set<String> call(String docId) throws Exception {
                Set<String> docIds = new HashSet<String>();
                docIds.add(docId);
                return docIds;
            }
        };
        // The combiner is created by createCombiner above and owned by the
        // aggregation, so it is updated in place instead of being copied for
        // every incoming value.
        Function2<Set<String>, String, Set<String>> mergeValue =
                new Function2<Set<String>, String, Set<String>>() {
            @Override
            public Set<String> call(Set<String> docIds, String docId) throws Exception {
                docIds.add(docId);
                return docIds;
            }
        };
        Function2<Set<String>, Set<String>, Set<String>> mergeCombiner =
                new Function2<Set<String>, Set<String>, Set<String>>() {
            @Override
            public Set<String> call(Set<String> left, Set<String> right) throws Exception {
                left.addAll(right);
                return left;
            }
        };

        JavaRDD<Row> tokens = dataDF.select("docId", "title", "text").toJavaRDD()
                .flatMapToPair(new PairFlatMapFunction<Row, String, String>() {
                    @Override
                    public Iterable<Tuple2<String, String>> call(Row row) throws Exception {
                        String id = row.getString(0);
                        Set<String> stops = stopWordsBc == null ? null : stopWordsBc.value();
                        // A set, so a token counts once per document.
                        Set<String> grams = new HashSet<>();
                        ngram(row.getString(1), n, grams, stops); // title
                        ngram(row.getString(2), n, grams, stops); // text
                        // Documents without usable tokens contribute no pairs.
                        List<Tuple2<String, String>> pairs = new ArrayList<>(grams.size());
                        for (String token : grams) {
                            pairs.add(new Tuple2<>(token, id));
                        }
                        return pairs;
                    }
                })
                // Collect the distinct document ids per token.
                .combineByKey(createCombiner, mergeValue, mergeCombiner, NUM_PARTITIONS)
                // Drop tokens whose document frequency is below the minimum.
                .filter(new Function<Tuple2<String, Set<String>>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<String, Set<String>> entry) throws Exception {
                        return entry._2.size() >= MIN_DOC_FREQ;
                    }
                })
                .map(new Function<Tuple2<String, Set<String>>, Row>() {
                    @Override
                    public Row call(Tuple2<String, Set<String>> entry) throws Exception {
                        return RowFactory.create(entry._1, entry._2.size());
                    }
                });

        StructType structType = new StructType(new StructField[]{
                new StructField("token", DataTypes.StringType, false, Metadata.empty()),
                new StructField("fequence", DataTypes.IntegerType, false, Metadata.empty())
        });
        return sqlContext.createDataFrame(tokens, structType);
    }

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the HDFS path of the parquet input,
     *             {@code args[1]} is the n-gram length {@code n}
     */
    public static void main(String[] args) {
        Logger logger = Logger.getLogger("org.apache.spark");
        logger.setLevel(Level.WARN);

        // Validate the arguments before any cluster resource is allocated.
        if (args.length != 2) {
            logger.error("usage: TFIDF <hdfs-parquet-path> <n>");
            return;
        }
        String path = args[0];
        int n;
        try {
            n = Integer.parseInt(args[1]);
        } catch (NumberFormatException e) {
            logger.error("{n}-gram input format is error" + args[1], e);
            return;
        }
        if (n <= 0) {
            logger.error("{n}-gram input format is error" + args[1]);
            return;
        }

        SparkConf conf = new SparkConf().setAppName("LoadData");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            final SQLContext sqlContext = new SQLContext(jsc);

            // Load the stop-word dictionary and broadcast it.
            Set<String> stopWordsDF = new HashSet<String>(sqlContext.read()
                    .text(IOUtils.stopWordPath).toJavaRDD()
                    .map(new Function<Row, String>() {
                        @Override
                        public String call(Row row) throws Exception {
                            return row.getString(0);
                        }
                    }).collect());
            stopWords = jsc.broadcast(stopWordsDF);

            logger.warn("load hdfs path is: " + path);

            // Read the documents.
            DataFrame dataDF = sqlContext
                    .read()
                    .parquet(path)
                    .select("docId", "title", "text");
            DataFrame idfDF = getTokenFeq(sqlContext, dataDF, n);
            IOUtils.saveDataToPSQL(idfDF, "idf_feq");
            sqlContext.clearCache();
        } finally {
            // Release the context even when the job fails.
            jsc.close();
        }
    }
}
不用groupByKey是因为n-gram的数量太大、太分散,性能上不如combineByKey。
总结
我觉得优化方案应该还有很多,求大神指导。