spark自定义累加器

在spark2.0后,官方提供了一个新的抽象类AccumulaterV2来提供更加友好的自定义类型累加器的实现方式。

	abstract class AccumulatorV2[IN, OUT] extends Serializable 

实现自定义累加器需要继承AccumulatorV2并重写下面的方法。

class MyAccumulator extends AccumulatorV2[String, java.util.Set[String]] {

  private val _logArray: java.util.Set[String] = new java.util.HashSet[String]()

  /**
    * 判断是否为空
    *
    * @return
    */
  override def isZero: Boolean = {
    _logArray.isEmpty
  }

  override def copy(): AccumulatorV2[String, util.Set[String]] = {
    val newAcc = new MyAccumulator
    _logArray.synchronized {
      newAcc._logArray.addAll(_logArray)
    }
    newAcc
  }

  override def reset(): Unit = {
    _logArray.clear()
  }

  override def add(v: String): Unit = {
    _logArray.add(v)
  }

  override def merge(other: AccumulatorV2[String, util.Set[String]]): Unit = {
    other match {
      case o: MyAccumulator => _logArray.addAll(o.value)
    }

  }

  override def value: util.Set[String] = {
    java.util.Collections.unmodifiableSet(_logArray)
  }
}

测试:

object TestMyAccumulator {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestMyAccumulator").setMaster("local[4]")
    val sc = new SparkContext(conf)
    val myacc = new MyAccumulator
    sc.register(myacc, "myacc")
    val rdd = sc.parallelize(Array("1", "c", "1", "c", "b", "a", "3", "2"))

    rdd.map(x => {
      myacc.add(x)
    }).collect()

    println(rdd.foreach(println))
    import scala.collection.JavaConversions._
    for (i <- myacc.value) {
      print(i + "______")
    }
    sc.stop()
  }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值