import java.lang.Iterable

import org.apache.flink.api.common.functions.GroupReduceFunction
import org.apache.flink.api.java.aggregation.Aggregations.SUM
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.api.scala._
import org.apache.flink.examples.java.graph.util.PageRankData
import org.apache.flink.util.Collector

import scala.collection.JavaConverters._

/**
 *  使用方法:
 * {{{
 *   PageRankBasic --pages <path> --links <path> --output <path> --numPages <n> --iterations <n>
 * }}}

 * 输入数据格式如下:
 *        pages指定路径的数据格式:
 *                  通过换行分割的数据集,一行是一条记录,每一行有一个字段ID,我们暂时叫做pageId
 *                  例如:1\n2\n12\n42\n63代表了5个page, IDs是1, 2, 12, 42,  63
 * 
 *        links指定路径的数据格式:
 *                 由pageId组成的key value键值对,key value通过空格分开,一行是一条记录,一行记录包含两列
 *                 第一列:我们叫做soureId
 *                 第二列:我们叫做targetId
 *                 含义:sourceId -> targetId 第一列的page可以到达第二列的page
 *       
 * 其他参数     
 *  - numPages:  指定一共有多少个page ,即有多少个pageId
 *  -  Iterations:  迭代次数
 */
object PageRankBasic {

  //阻尼系数用于校正rank值
  private final val DAMPENING_FACTOR: Double = 0.85
  //用于递归停止条件,当某个page的上一轮与下一轮的rank值的差的绝对值小于EPSILON时候,停止更新
  private final val EPSILON: Double = 0.0001

  def main(args: Array[String]) {

    val params: ParameterTool = ParameterTool.fromArgs(args)

    // 设置flink的执行环境
    val env = ExecutionEnvironment.getExecutionEnvironment

    // 使的这些参数对于web接口可用
    env.getConfig.setGlobalJobParameters(params)

    // 读取输入数据集
    // pages:DataSet[Long],   numPages:Long,   links:DataSet[Link],   maxIterations:Int
    val (pages, numPages) = getPagesDataSet(env, params)
    val links = getLinksDataSet(env, params)
    val maxIterations = params.getInt("iterations", 10)

    // 初始化每个page的rank值,每个pageId对应的rank初始值是  1/numPages,封装到Page(pageId: Long, rank: Double)对象
    //pages:DataSet[Long] ,此处的withForwardedFields("*->pageId") 针对map操作而言,含义是将输入数据所有字段(这里只有一个字段)与 输出数据的类型Page类的pageId字段对应
    val pagesWithRanks = pages.map(p => Page(p, 1.0 / numPages)).withForwardedFields("*->pageId")

    // 构建邻接列表
    //groupBy根据输入数据类型Link的sourceId字段分组,输出数据类型(Long,Iterable[Link]),
    //然后通过使用reduceGroup算子,将Iterable[Link]变为AdjacencyList类型,
    //AdjacencyList类型第一个字段代表pageId ,第二个字段代表该pageId所有可到的pageId组成的集合,
    //AdjacencyList(sourceId: Long, targetIds: Array[Long])
    //最终结果集adjacencyLists的类型是DataSet[sourceid:Long,AdjacencyList]
    val adjacencyLists = links
      .groupBy("sourceId").reduceGroup( new GroupReduceFunction[Link, AdjacencyList] {
        override def reduce(values: Iterable[Link], out: Collector[AdjacencyList]): Unit = {
          var outputId = -1L
          //outputList存放的是某个outputId的所有可达pageIds,outputId是sourceId
          val outputList = values.asScala map { t => outputId = t.sourceId; t.targetId }
          //将sourceId targetIds封装到AdjacencyList对象
          out.collect(new AdjacencyList(outputId, outputList.toArray))
        }
      })

    // 开始迭代
    val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
      currentRanks =>
        val newRanks = currentRanks
          // distribute ranks to target pages
          .join(adjacencyLists).where("pageId").equalTo("sourceId") {
            //page是 pagesWithRanks:DataSet[Page]中数据, adjacent是adjacencyLists:DataSet[Long,AdjacencyList]中数据
            (page, adjacent, out: Collector[Page]) =>
              //获取该sourceId所有的可达pageIds
              val targets = adjacent.targetIds
              //获取sourceId可达pageId的个数
              val len = targets.length
              //分别求该sourceId下可达的每个pageId的rank,封装到Page
              adjacent.targetIds foreach { t => out.collect(Page(t, page.rank /len )) }
          }
          //上面代码得到数据DataSet[Page],然后按照pageId分组,对rank求和
          // collect ranks and sum them up
          .groupBy("pageId").aggregate(SUM, "rank")
          // apply dampening factor
          .map { p =>
            //使用阻尼系数,进行修正
            Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages))
          }.withForwardedFields("pageId")

        // terminate if no rank update was significant
        //递归停止条件:检查当前的rank 与 下一轮迭代的rank 的差值 ,如果小于EPSILON那么停止
        val termination = currentRanks.join(newRanks).where("pageId").equalTo("pageId") {
          (current, next, out: Collector[Int]) =>
            // check for significant update 只保留需要更新的数据
            if (math.abs(current.rank - next.rank) > EPSILON) out.collect(1)
        }
        //newRanks是下一轮迭代的数据集
        (newRanks, termination)
    }

    val result = finalRanks

    // 输出数据
    if (params.has("output")) {
      result.writeAsCsv(params.get("output"), "\n", " ")
      // 执行程序
      env.execute("Basic PageRank Example")
    } else {
      println("Printing result to stdout. Use --output to specify output path.")
      result.print()
    }
  }

  // *************************************************************************
  //     自定义的类型
  // *************************************************************************

  case class Link(sourceId: Long, targetId: Long)

  case class Page(pageId: Long, rank: Double)

  case class AdjacencyList(sourceId: Long, targetIds: Array[Long])

  // *************************************************************************
  //     辅助方法
  // *************************************************************************

  private def getPagesDataSet(env: ExecutionEnvironment, params: ParameterTool):
                     (DataSet[Long], Long) = {
    if (params.has("pages") && params.has("numPages")) {
      val pages = env
        .readCsvFile[Tuple1[Long]](params.get("pages"), fieldDelimiter = " ", lineDelimiter = "\n")
        .map(x => x._1)
      (pages, params.getLong("numPages"))
    } else {
      println("Executing PageRank example with default pages data set.")
      println("Use --pages and --numPages to specify file input.")
      (env.generateSequence(1, 15), PageRankData.getNumberOfPages)
    }
  }

  private def getLinksDataSet(env: ExecutionEnvironment, params: ParameterTool):
                      DataSet[Link] = {
    if (params.has("links")) {
      env.readCsvFile[Link](params.get("links"), fieldDelimiter = " ",
        includedFields = Array(0, 1))
    } else {
      println("Executing PageRank example with default links data set.")
      println("Use --links to specify file input.")
      val edges = PageRankData.EDGES.map { case Array(v1, v2) => Link(v1.asInstanceOf[Long],
        v2.asInstanceOf[Long])}
      env.fromCollection(edges)
    }
  }
}