import java.lang.Iterable
import org.apache.flink.api.common.functions.GroupReduceFunction
import org.apache.flink.api.java.aggregation.Aggregations.SUM
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.api.scala._
import org.apache.flink.examples.java.graph.util.PageRankData
import org.apache.flink.util.Collector
import scala.collection.JavaConverters._
/**
 * Usage:
 * {{{
 * PageRankBasic --pages <path> --links <path> --output <path> --numPages <n> --iterations <n>
 * }}}
 *
 * The input data is expected in the following formats:
 *
 * Data at the path given by --pages:
 * a newline-separated data set, one record per line; each line holds a single
 * ID field, called the pageId below.
 * For example, "1\n2\n12\n42\n63" describes 5 pages with the IDs 1, 2, 12, 42 and 63.
 *
 * Data at the path given by --links:
 * space-separated key-value pairs of pageIds, one record per line, two columns
 * per record:
 * first column: the sourceId
 * second column: the targetId
 * Meaning: sourceId -> targetId, i.e. the page in the first column links to
 * the page in the second column.
 *
 * Further parameters:
 * - numPages: the total number of pages, i.e. the number of distinct pageIds
 * - iterations: the number of iterations to run
 */
object PageRankBasic {
  // dampening factor used to correct the rank values
  private final val DAMPENING_FACTOR: Double = 0.85
  // threshold for the termination criterion of the iteration: once the absolute
  // difference between a page's rank in two consecutive rounds falls below
  // EPSILON, the update no longer counts as significant
  private final val EPSILON: Double = 0.0001
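  // For reference, the update rule implemented below is the standard damped
  // PageRank formula:
  //   rank(p) = DAMPENING_FACTOR * sum(rank(q) / outDegree(q), for all q -> p)
  //             + (1 - DAMPENING_FACTOR) / numPages
  // e.g. with numPages = 5, a page whose incoming rank mass sums to 0.2 gets
  // 0.85 * 0.2 + 0.15 / 5 = 0.2, i.e. it sits at a fixed point.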
  def main(args: Array[String]): Unit = {
    val params: ParameterTool = ParameterTool.fromArgs(args)
    // set up the Flink execution environment
    val env = ExecutionEnvironment.getExecutionEnvironment
    // make the parameters available in the web interface
    env.getConfig.setGlobalJobParameters(params)
    // read the input data sets:
    // pages: DataSet[Long], numPages: Long, links: DataSet[Link], maxIterations: Int
    val (pages, numPages) = getPagesDataSet(env, params)
    val links = getLinksDataSet(env, params)
    val maxIterations = params.getInt("iterations", 10)
    // initialize each page's rank: every pageId starts with rank 1/numPages,
    // wrapped in a Page(pageId: Long, rank: Double) object.
    // withForwardedFields("*->pageId") annotates this map: the entire input
    // value (pages has a single Long field) is forwarded unchanged into the
    // pageId field of the output type Page.
    val pagesWithRanks = pages.map(p => Page(p, 1.0 / numPages)).withForwardedFields("*->pageId")
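    // Note: withForwardedFields is only an optimizer hint; without it the job
    // computes the same result, but Flink can no longer assume that pageId
    // survives the map and may add an extra shuffle inside the iteration.
    // A minimal unhinted sketch (hypothetical value name):
    //   val pagesWithRanksUnhinted = pages.map(p => Page(p, 1.0 / numPages))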
    // build the adjacency lists:
    // groupBy groups the Link records by their sourceId field; reduceGroup then
    // collapses each group's Iterable[Link] into a single AdjacencyList, whose
    // first field is the pageId (the sourceId) and whose second field is the
    // array of all pageIds reachable from it:
    // AdjacencyList(sourceId: Long, targetIds: Array[Long])
    // the result adjacencyLists has type DataSet[AdjacencyList]
    val adjacencyLists = links
      .groupBy("sourceId").reduceGroup(new GroupReduceFunction[Link, AdjacencyList] {
        override def reduce(values: Iterable[Link], out: Collector[AdjacencyList]): Unit = {
          var outputId = -1L
          // outputList collects every pageId reachable from one sourceId;
          // outputId is that sourceId
          val outputList = values.asScala map { t => outputId = t.sourceId; t.targetId }
          // wrap the sourceId and its targetIds in an AdjacencyList object
          out.collect(AdjacencyList(outputId, outputList.toArray))
        }
      })
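    // For illustration with hypothetical links (1,2), (1,12) and (2,12), the
    // reduction above emits AdjacencyList(1, Array(2, 12)) and
    // AdjacencyList(2, Array(12)). A sketch of the same step written with the
    // Scala API's function-based reduceGroup (assumes one source's targets fit
    // in memory, just like the version above):
    //   val adjacencyListsAlt = links
    //     .groupBy("sourceId")
    //     .reduceGroup { (in: Iterator[Link], out: Collector[AdjacencyList]) =>
    //       val group = in.toVector
    //       out.collect(AdjacencyList(group.head.sourceId, group.map(_.targetId).toArray))
    //     }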
    // run the iteration
    val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
      currentRanks =>
        val newRanks = currentRanks
          // distribute ranks to target pages
          .join(adjacencyLists).where("pageId").equalTo("sourceId") {
            // page is an element of currentRanks: DataSet[Page];
            // adjacent is an element of adjacencyLists: DataSet[AdjacencyList]
            (page, adjacent, out: Collector[Page]) =>
              // all pageIds reachable from this sourceId
              val targets = adjacent.targetIds
              // the number of pages reachable from this sourceId
              val len = targets.length
              // give every reachable pageId an equal share of this page's rank
              targets foreach { t => out.collect(Page(t, page.rank / len)) }
          }
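          // e.g. a page with rank 0.3 and targetIds Array(2, 12) emits
          // Page(2, 0.15) and Page(12, 0.15) (hypothetical values)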
          // the join above yields a DataSet[Page]; group it by pageId and sum
          // the ranks per page
          .groupBy("pageId").aggregate(SUM, "rank")
          // apply the dampening factor
          .map { p =>
            // correct the summed rank using the dampening factor
            Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages))
          }.withForwardedFields("pageId")
        // terminate if no rank update was significant:
        // compare each page's current rank with its rank in the next round and
        // keep only the pages whose rank still changed by more than EPSILON
        val termination = currentRanks.join(newRanks).where("pageId").equalTo("pageId") {
          (current, next, out: Collector[Int]) =>
            // emit a marker only for significant updates
            if (math.abs(current.rank - next.rank) > EPSILON) out.collect(1)
        }
        // newRanks is the data set fed into the next iteration
        (newRanks, termination)
    }
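    // iterateWithTermination stops as soon as an iteration's termination data
    // set is empty (no page moved by more than EPSILON) or maxIterations is
    // reached, whichever comes first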
    val result = finalRanks
    // emit the result
    if (params.has("output")) {
      result.writeAsCsv(params.get("output"), "\n", " ")
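      // with these delimiters every output line reads "<pageId> <rank>",
      // e.g. "1 0.2" (hypothetical values)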
      // execute the program
      env.execute("Basic PageRank Example")
    } else {
      println("Printing result to stdout. Use --output to specify output path.")
      result.print()
    }
  }
  // *************************************************************************
  // CUSTOM TYPES
  // *************************************************************************
  case class Link(sourceId: Long, targetId: Long)
  case class Page(pageId: Long, rank: Double)
  case class AdjacencyList(sourceId: Long, targetIds: Array[Long])
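  // the string keys used above (groupBy("sourceId"), where("pageId"), ...) are
  // resolved against these case class fields; a type-safe sketch of the same
  // grouping via a key-selector function:
  //   links.groupBy(_.sourceId)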
  // *************************************************************************
  // HELPER METHODS
  // *************************************************************************
  private def getPagesDataSet(env: ExecutionEnvironment, params: ParameterTool):
      (DataSet[Long], Long) = {
    if (params.has("pages") && params.has("numPages")) {
      val pages = env
        .readCsvFile[Tuple1[Long]](params.get("pages"), fieldDelimiter = " ", lineDelimiter = "\n")
        .map(x => x._1)
      (pages, params.getLong("numPages"))
    } else {
      println("Executing PageRank example with default pages data set.")
      println("Use --pages and --numPages to specify file input.")
      (env.generateSequence(1, 15), PageRankData.getNumberOfPages)
    }
  }
  private def getLinksDataSet(env: ExecutionEnvironment, params: ParameterTool):
      DataSet[Link] = {
    if (params.has("links")) {
      env.readCsvFile[Link](params.get("links"), fieldDelimiter = " ",
        includedFields = Array(0, 1))
    } else {
      println("Executing PageRank example with default links data set.")
      println("Use --links to specify file input.")
      val edges = PageRankData.EDGES.map {
        case Array(v1, v2) => Link(v1.asInstanceOf[Long], v2.asInstanceOf[Long])
      }
      env.fromCollection(edges)
    }
  }
}