Spark (18): SparkSQL User-Defined Functions

This article covers developing SparkSQL in IDEA, including adding the dependency and creating DataFrames and Datasets. It then explains how to implement user-defined functions (UDF) and user-defined aggregate functions (UDAF), using an average computation as the running example, and covers conversions among RDD, DataFrame, and Dataset as well as how UDAF usage differs across Spark versions.


Table of Contents

0. Related Articles

1. Developing SparkSQL in IDEA

1.1. Adding the Dependency

1.2. Code Implementation

2. User-Defined Functions (UDF)

3. User-Defined Aggregate Functions (UDAF)


0. Related Articles

 Spark Article Index 

1. Developing SparkSQL in IDEA

1.1. Adding the Dependency

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.12</artifactId>
    <version>3.0.0</version>
</dependency>
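
Note: the "_2.12" suffix in the artifactId is the Scala binary version; it must match the Scala version your project compiles against.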

1.2. Code Implementation

package com.yishou.bigdata.common.basic

import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
 * @date: 2023/7/4
 * @author: yangshibiao
 * @desc: project description
 */
object BasicModel {
    def main(args: Array[String]): Unit = {

        // Create the Spark configuration object
        val conf: SparkConf = new SparkConf()
            .setMaster("local[*]")
            .setAppName("BasicModel")

        // Create the SparkSession object
        val spark: SparkSession = SparkSession
            .builder()
            .config(conf)
            .getOrCreate()

        // RDD => DataFrame => Dataset conversions require importing the implicit conversions, otherwise they will not compile
        // "spark" here is not a package name but the name of the SparkSession object created above
        import spark.implicits._

        // Read a JSON file to create a DataFrame; each line holds one JSON object, e.g. {"username": "lisi","age": 18}
        val df: DataFrame = spark.read.json("input/test.json")

        // SQL-style syntax
        df.createOrReplaceTempView("user")
        spark.sql("select avg(age) from user").show

        // DSL-style syntax
        df.select("username", "age").show()

        //*****RDD=>DataFrame=>DataSet*****
        //RDD
        val rdd1: RDD[(Int, String, Int)] = spark
            .sparkContext
            .makeRDD(List((1, "zhangsan", 30), (2, "lisi", 28), (3, "wangwu", 20)))
        //DataFrame
        val df1: DataFrame = rdd1.toDF("id", "name", "age")
        df1.show()
        //Dataset
        val ds1: Dataset[User] = df1.as[User]
        ds1.show()

        //*****DataSet=>DataFrame=>RDD*****
        //DataFrame
        val df2: DataFrame = ds1.toDF()
        // df.rdd returns RDD[Row]; Row provides getXXX methods to read field values, similar to a JDBC result set, but indices start at 0
        val rdd2: RDD[Row] = df2.rdd
        rdd2.foreach((row: Row) => println(row.getString(1)))
        //*****RDD=>DataSet*****
        rdd1.map {
            case (id, name, age) => User(id, name, age)
        }.toDS()
        //*****DataSet=>RDD*****
        val rdd3: RDD[User] = ds1.rdd
        
        // Release resources
        spark.stop()
    }
}

case class User(id: Int, name: String, age: Int)

2. User-Defined Functions (UDF)

  • Step 1: Create a DataFrame
scala> val df = spark.read.json("data/user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]
  • Step 2: Register the UDF
scala> spark.udf.register("addName",(x:String)=> "Name:"+x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
  • Step 3: Create a temporary view
scala> df.createOrReplaceTempView("people")
  • Step 4: Apply the UDF (note the column is named username, not name)
scala> spark.sql("select addName(username), age from people").show()
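
The same steps work outside the shell. Below is a minimal compiled sketch (the object name UdfDemo is hypothetical); it assumes the same data/user.json file and also shows the DSL alternative via org.apache.spark.sql.functions.udf:

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, udf}

object UdfDemo {
    def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession.builder()
            .master("local[*]")
            .appName("UdfDemo")
            .getOrCreate()

        // Same sample file as in the shell session above
        val df: DataFrame = spark.read.json("data/user.json")
        df.createOrReplaceTempView("people")

        // SQL style: register the UDF under a name, then call it in SQL
        spark.udf.register("addName", (x: String) => "Name:" + x)
        spark.sql("select addName(username), age from people").show()

        // DSL style: wrap the same function with functions.udf
        val addName = udf((x: String) => "Name:" + x)
        df.select(addName(col("username")), col("age")).show()

        spark.stop()
    }
}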

3. User-Defined Aggregate Functions (UDAF)

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions: a weakly typed UDAF is implemented by extending UserDefinedAggregateFunction. Since Spark 3.0, UserDefinedAggregateFunction is deprecated, and the strongly typed Aggregator is the recommended, unified approach.
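
For reference, the built-in aggregates can be called directly through the DSL. A minimal sketch, assuming the df read from input/test.json in section 1.2:

import org.apache.spark.sql.functions.{avg, count, countDistinct, max, min}

df.agg(count("*"), countDistinct("username"), avg("age"), max("age"), min("age")).show()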

Requirement: compute the average salary. (The sample data below uses an age field; the logic is identical.)

  • Implementation 1: RDD
import org.apache.spark.{SparkConf, SparkContext}

val conf: SparkConf = new SparkConf()
    .setAppName("app")
    .setMaster("local[*]")
val sc: SparkContext = new SparkContext(conf)

val res: (Int, Int) = sc
    .makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
    .map {
        case (_, age) => (age, 1)
    }
    .reduce {
        (t1, t2) => {
            (t1._1 + t2._1, t1._2 + t2._2)
        }
    }

// Use Double division so the average is not truncated
println(res._1.toDouble / res._2)

// Close the connection
sc.stop()
  • Implementation 2: Accumulator
import org.apache.spark.util.AccumulatorV2

// An AccumulatorV2 that takes ages (Int) and yields the average age (Int)
class MyAC extends AccumulatorV2[Int, Int] {
    
    var sum: Int = 0
    var count: Int = 0

    override def isZero: Boolean = sum == 0 && count == 0

    override def copy(): AccumulatorV2[Int, Int] = {
        val newMyAc: MyAC = new MyAC
        newMyAc.sum = this.sum
        newMyAc.count = this.count
        newMyAc
    }

    override def reset(): Unit = {
        sum = 0
        count = 0
    }

    override def add(v: Int): Unit = {
        sum += v
        count += 1
    }

    override def merge(other: AccumulatorV2[Int, Int]): Unit = {
        other match {
            case o: MyAC => {
                sum += o.sum
                count += o.count
            }
            case _ =>
        }
    }

    // Note: integer division, so the average is truncated to an Int
    override def value: Int = sum / count
}
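
The class above only defines the accumulator; it still has to be registered and fed from a driver. A minimal usage sketch (not from the original post), assuming the same SparkContext sc as in implementation 1:

val acc = new MyAC
sc.register(acc, "avgAge")

sc.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
    .foreach { case (_, age) => acc.add(age) }   // add() runs on the executors

println(acc.value)   // read on the driver: (20 + 30 + 40) / 3 = 30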
  • Implementation 3: Weakly typed UDAF
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
 * @date: 2023/7/4
 * @author: yangshibiao
 * @desc: project description
 */
object BasicModel {
    def main(args: Array[String]): Unit = {

        // Runtime environment
        val conf: SparkConf = new SparkConf()
            .setAppName("BasicModel")
            .setMaster("local[*]")
        val spark: SparkSession = SparkSession
            .builder()
            .config(conf)
            .getOrCreate()
        val sc: SparkContext = spark.sparkContext
        import spark.implicits._

        // Create sample data
        val userDF: DataFrame = sc
            .makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
            .toDF("user_id", "age")
        userDF.createTempView("user")

        // Instantiate the aggregate function and register it with Spark
        val myAverage: MyAveragUDAF = new MyAveragUDAF
        spark.udf.register("avg_age", myAverage)

        // Run the query and show the result
        spark.sql("select avg_age(age) as avg_age from user").show()

        // Release resources
        spark.stop()

    }
}


/**
 * Extend UserDefinedAggregateFunction and override its methods
 * (deprecated since Spark 3.0; see the Aggregator-based versions below)
 */
class MyAveragUDAF extends UserDefinedAggregateFunction {

    // Data type of the aggregate function's input parameters
    def inputSchema: StructType = StructType(Array(StructField("age", IntegerType)))

    // Data types of the values in the aggregation buffer: (sum, count)
    def bufferSchema: StructType = StructType(Array(StructField("sum", LongType), StructField("count", LongType)))

    // Data type of the function's return value
    def dataType: DataType = DoubleType

    // Determinism: whether the same input always produces the same output
    def deterministic: Boolean = true

    // Initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        // Running total of the ages
        buffer(0) = 0L
        // Number of ages seen
        buffer(1) = 0L
    }

    // Update the buffer with an input row
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
            buffer(0) = buffer.getLong(0) + input.getInt(0)
            buffer(1) = buffer.getLong(1) + 1
        }
    }

    // Merge two buffers
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Compute the final result
    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
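
Besides SQL registration, a UserDefinedAggregateFunction instance exposes an apply(Column*) method, so it can also be applied directly in the DSL. A one-line sketch (not from the original post), assuming the userDF and myAverage values from the driver above:

userDF.agg(myAverage($"age").as("avg_age")).show()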
  • Implementation 4: Strongly typed Aggregator

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}

/**
 * @date: 2023/7/4
 * @author: yangshibiao
 * @desc: project description
 */
object BasicModel {
    def main(args: Array[String]): Unit = {

        // Runtime environment
        val conf: SparkConf = new SparkConf()
            .setAppName("BasicModel")
            .setMaster("local[*]")
        val spark: SparkSession = SparkSession
            .builder()
            .config(conf)
            .getOrCreate()
        val sc: SparkContext = spark.sparkContext
        import spark.implicits._

        // Create the data and wrap it in a Dataset
        // (the column names must match the User case class fields for .as[User] to work)
        val userDF: DataFrame = sc
            .makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
            .toDF("username", "age")
        val ds: Dataset[User] = userDF.as[User]
        
        // Instantiate the aggregate function
        val myAgeUdaf1: MyAveragUDAF1 = new MyAveragUDAF1
        
        // Convert the aggregate function into a typed column for the query
        val col: TypedColumn[User, Double] = myAgeUdaf1.toColumn
        
        // Query
        ds.select(col).show()

        // Release resources
        spark.stop()

    }
}


// Input type
case class User(username: String, age: Long)

// Buffer type
case class AgeBuffer(var sum: Long, var count: Long)

/**
 * Extend org.apache.spark.sql.expressions.Aggregator
 * and override its methods
 */
class MyAveragUDAF1 extends Aggregator[User, AgeBuffer, Double] {

    override def zero: AgeBuffer = {
        AgeBuffer(0L, 0L)
    }

    override def reduce(b: AgeBuffer, a: User): AgeBuffer = {
        b.sum = b.sum + a.age
        b.count = b.count + 1
        b
    }

    override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
        b1
    }

    override def finish(buff: AgeBuffer): Double = {
        buff.sum.toDouble / buff.count
    }

    // Encoders used by the Dataset for serialization; this part is boilerplate:
    // use Encoders.product for custom case classes, and the matching built-in encoder for primitive types
    override def bufferEncoder: Encoder[AgeBuffer] = {
        Encoders.product
    }

    override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
    }

}
  • Since Spark 3.0, a strongly typed Aggregator can replace UserDefinedAggregateFunction even in SQL, by registering it through functions.udaf
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}

/**
 * @date: 2023/7/4
 * @author: yangshibiao
 * @desc: project description
 */
object BasicModel {
    def main(args: Array[String]): Unit = {

        // Runtime environment
        val conf: SparkConf = new SparkConf()
            .setAppName("BasicModel")
            .setMaster("local[*]")
        val spark: SparkSession = SparkSession
            .builder()
            .config(conf)
            .getOrCreate()
        val sc: SparkContext = spark.sparkContext
        import spark.implicits._

        // Create sample data
        val userDF: DataFrame = sc
            .makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangw", 40)))
            .toDF("user_id", "age")
        userDF.createTempView("user")

        // Create the UDAF and register it with SparkSQL via functions.udaf
        val udaf: MyAvgAgeUDAF = new MyAvgAgeUDAF
        spark.udf.register("avgAge", functions.udaf(udaf))

        // Use the aggregate function in SQL
        spark.sql("select avgAge(age) from user").show

        // Release resources
        spark.stop()

    }
}

case class Buff(var sum: Long, var cnt: Long)

// Aggregator[IN, BUF, OUT]: input age (Long), buffer Buff(sum, cnt), output average (Double)
class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double] {

    override def zero: Buff = Buff(0, 0)

    override def reduce(b: Buff, a: Long): Buff = {
        b.sum += a
        b.cnt += 1
        b
    }

    override def merge(b1: Buff, b2: Buff): Buff = {
        b1.sum += b2.sum
        b1.cnt += b2.cnt
        b1
    }

    override def finish(reduction: Buff): Double = {
        reduction.sum.toDouble / reduction.cnt
    }

    override def bufferEncoder: Encoder[Buff] = Encoders.product

    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}
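
For the sample data, all four implementations agree on the result: (20 + 30 + 40) / 3 = 30.0 (the accumulator variant divides Ints, so it prints 30).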

Note: for the other articles in this Spark series, see the Spark Article Index.

