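1. The function takes two combine functions and an initial ("zero") value.
2. The function can return a result type U that differs from the type V of the values in the RDD.
3. We need an intra-partition merge operation (seqOp) that merges each value (V) in a partition into that partition's aggregated result (U) for the same key.
4. We need an inter-partition merge operation (combOp) that merges the results (U) aggregated by the individual partitions into a single result (U).
5. To avoid memory allocation, both of these functions are allowed to modify and return their first argument (the aggregated U) instead of creating a new U.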
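The points above can be made concrete without a cluster. Below is a minimal pure-Scala sketch, not Spark's implementation:

- Each inner `Seq` stands in for one partition.
- The names `AggregateByKeySketch` and `aggregateByKeyLocal` are hypothetical.
- Inside each partition, `seqOp` folds values into a per-key result that starts from `zeroValue`.
- `combOp` then merges the per-partition results without using the zero value again.

```scala
// Hypothetical pure-Scala model of aggregateByKey; each inner Seq plays the role of one partition.
object AggregateByKeySketch {
  def aggregateByKeyLocal[K, V, U](partitions: Seq[Seq[(K, V)]], zeroValue: U)(
      seqOp: (U, V) => U,
      combOp: (U, U) => U): Map[K, U] = {
    // Step 1 (within a partition): fold every V into a per-key U that starts at zeroValue.
    val perPartition: Seq[Map[K, U]] = partitions.map { part =>
      part.foldLeft(Map.empty[K, U]) { case (acc, (k, v)) =>
        acc.updated(k, seqOp(acc.getOrElse(k, zeroValue), v))
      }
    }
    // Step 2 (between partitions): merge the per-key U's with combOp; zeroValue is not used here.
    perPartition.foldLeft(Map.empty[K, U]) { (acc, partMap) =>
      partMap.foldLeft(acc) { case (a, (k, u)) =>
        a.updated(k, a.get(k).map(combOp(_, u)).getOrElse(u))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val parts = Seq(Seq("a" -> 1, "b" -> 2, "a" -> 3), Seq("a" -> 5, "b" -> 1))
    // The result type U (String) differs from the value type V (Int).
    val result = aggregateByKeyLocal(parts, "")((u, v) => u + v, (u1, u2) => u1 + "|" + u2)
    println(result) // Map(a -> 13|5, b -> 2|1)
  }
}
```

The `|` separator in the output makes the partition boundary visible. It shows that `combOp` only ever sees one already-aggregated `U` per partition.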
/**
 * Aggregate the values of each key, using given combine functions and a neutral "zero value".
 * This function can return a different result type, U, than the type of the values in this RDD,
 * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
 * as in scala.TraversableOnce. The former operation is used for merging values within a
 * partition, and the latter is used for merging values between partitions. To avoid memory
 * allocation, both of these functions are allowed to modify and return their first argument
 * instead of creating a new U.
 */
def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
    combOp: (U, U) => U): RDD[(K, U)] = self.withScope {
  // Serialize the zero value to a byte array so that we can get a new clone of it on each key
  val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
  val zeroArray = new Array[Byte](zeroBuffer.limit)
  zeroBuffer.get(zeroArray)

  lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
  val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))

  // We will clean the combiner closure later in `combineByKey`
  val cleanedSeqOp = self.context.clean(seqOp)
  combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v),
    cleanedSeqOp, combOp, partitioner)
}

/**
 * Comment omitted
 */
def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
    combOp: (U, U) => U): RDD[(K, U)] = self.withScope {
  aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
}

/**
 * Comment omitted
 */
def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
    combOp: (U, U) => U): RDD[(K, U)] = self.withScope {
  aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
}
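The source serializes `zeroValue` once and deserializes a fresh copy for every key. This matters because `seqOp` and `combOp` may mutate their first argument. The sketch below imitates that trick in plain Scala under these assumptions:

- Java serialization (`ObjectOutputStream`) stands in for `SparkEnv.get.serializer`.
- A one-slot `Array[Int]` serves as a mutable accumulator.
- `ZeroCloneSketch` and its methods are hypothetical names.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical illustration of cloning a mutable zero value per key via serialization.
object ZeroCloneSketch {
  // Serialize once, analogous to zeroArray in aggregateByKey (Java serialization here, not Spark's).
  def serialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    bytes.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  // seqOp updates its first argument in place and returns it, as the aggregateByKey contract allows.
  val seqOp: (Array[Int], Int) => Array[Int] = (u, v) => { u(0) += v; u }

  def sumPerKey(records: Seq[(String, Int)], cloneZero: Boolean): Map[String, Int] = {
    val zeroValue = Array(0)
    val zeroBytes = serialize(zeroValue)
    val createZero: () => Array[Int] =
      if (cloneZero) () => deserialize[Array[Int]](zeroBytes) // fresh copy for each new key
      else () => zeroValue                                     // one object shared by every key
    val acc = records.foldLeft(Map.empty[String, Array[Int]]) { case (m, (k, v)) =>
      m.updated(k, seqOp(m.getOrElse(k, createZero()), v))
    }
    acc.map { case (k, box) => k -> box(0) }
  }

  def main(args: Array[String]): Unit = {
    val records = Seq("a" -> 1, "b" -> 5, "a" -> 3)
    println(sumPerKey(records, cloneZero = true))  // Map(a -> 4, b -> 5)
    println(sumPerKey(records, cloneZero = false)) // Map(a -> 9, b -> 9)
  }
}
```

With `cloneZero = false`, every key shares one mutable accumulator, so each key ends up with the sum of all values. With cloning, each key gets its own copy, which is what `createZero` provides in the Spark source.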
package com.yd.spark.job.batch.analysis

import com.yd.spark.common.config.SparkEnvInit
import org.apache.spark.TaskContext

/**
 * @Author Guozy
 * @Description
 * @Date 2021/12/16 22:14
 **/
object testAggregateBykey extends App {
  // Initialize the environment; SparkEnvInit is a project utility class that sets up Spark
  SparkEnvInit.init()
  // Get the SparkContext
  val sc = SparkEnvInit.getSparkContext
  val testData = Array(
    ("a", 1), ("a", 3), ("b", 4), ("c", 4), ("b", 5), ("d", 3),
    ("a", 1), ("e", 3), ("a", 4), ("f", 4), ("c", 5), ("c", 3),
    ("c", 1), ("c", 3), ("b", 4), ("a", 4), ("e", 5), ("e", 3),
    ("e", 1), ("f", 3), ("c", 4), ("c", 4), ("c", 5), ("c", 3)
  )
  // 4 partitions; the zero value 10 is applied once per key in every partition where the key occurs
  val testRDD = sc.parallelize(testData, 4)
  val resultRDD = testRDD.aggregateByKey(10)(
    (u: Int, v: Int) => u + v,    // seqOp: merge a value into the partition's result
    (u1: Int, u2: Int) => u1 + u2 // combOp: merge results from different partitions
  )
  resultRDD.foreach(record => {
    val partIndex = TaskContext.getPartitionId()
    println("Partition: " + partIndex + ", " + record._1 + "=" + record._2)
  })
}
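The zero value is used by `seqOp` within each partition. It is therefore added once for every partition in which a key occurs, not once per key. The sketch below recomputes the totals for `testData` without Spark, under these assumptions:

- `sc.parallelize(testData, 4)` splits the array into four contiguous slices of six records, which is the positional slicing used by Spark's `ParallelCollectionRDD`.
- Map-side combining happens as usual.
- `ExpectedTotalsSketch` and its methods are hypothetical names.

```scala
// Hypothetical recomputation of the aggregateByKey(10) totals for the test data above.
object ExpectedTotalsSketch {
  val testData: Seq[(String, Int)] = Seq(
    ("a", 1), ("a", 3), ("b", 4), ("c", 4), ("b", 5), ("d", 3),
    ("a", 1), ("e", 3), ("a", 4), ("f", 4), ("c", 5), ("c", 3),
    ("c", 1), ("c", 3), ("b", 4), ("a", 4), ("e", 5), ("e", 3),
    ("e", 1), ("f", 3), ("c", 4), ("c", 4), ("c", 5), ("c", 3)
  )

  // Contiguous slicing: slice i covers [i * n / numSlices, (i + 1) * n / numSlices).
  def slice[T](data: Seq[T], numSlices: Int): Seq[Seq[T]] =
    (0 until numSlices).map { i =>
      data.slice(i * data.length / numSlices, (i + 1) * data.length / numSlices)
    }

  def expectedTotals(zeroValue: Int, numSlices: Int): Map[String, Int] = {
    // seqOp inside each slice: each key's fold starts from zeroValue once per slice.
    val perSlice = slice(testData, numSlices).map { part =>
      part.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).foldLeft(zeroValue)(_ + _) }
    }
    // combOp across slices: plain addition of the per-slice results.
    perSlice.flatten.groupBy(_._1).map { case (k, us) => k -> us.map(_._2).sum }
  }

  def main(args: Array[String]): Unit =
    expectedTotals(10, 4).toSeq.sortBy(_._1).foreach { case (k, u) => println(s"$k=$u") }
}
```

Under these assumptions the totals are a=43, b=33, c=72, d=13, e=42 and f=27. For example, key `a` has a plain sum of 13 and appears in three slices, giving 13 + 3 × 10 = 43. Only the totals are predicted here. Which partition prints each record is decided by the partitioner after the shuffle.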