First, one point of clarification up front: UDAFs are not limited to the agg() operator.
Although the official Spark 3.0.0 documentation [1] already describes Spark Java UDAFs and includes example code, this article focuses on two kinds of problems encountered in actual development, described one by one below.
First, a word on the two ways to implement a Spark Java UDAF [2]. The first is to extend the UserDefinedAggregateFunction class and implement its 8 methods; this approach is marked as deprecated in Spark 3.0.0. The second is to extend the Aggregator class and implement its 6 methods.
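For reference, the Aggregator contract that the second approach implements looks like this (paraphrased from the Spark 3.x API; IN is the input type, BUF the intermediate buffer, OUT the final result):

// org.apache.spark.sql.expressions.Aggregator in Spark 3.x (paraphrased)
public abstract class Aggregator<IN, BUF, OUT> implements java.io.Serializable {
    public abstract BUF zero();                    // the empty aggregation buffer
    public abstract BUF reduce(BUF b, IN a);       // fold one input value into the buffer
    public abstract BUF merge(BUF b1, BUF b2);     // combine buffers from different partitions
    public abstract OUT finish(BUF reduction);     // turn the final buffer into the result
    public abstract org.apache.spark.sql.Encoder<BUF> bufferEncoder();  // how Spark serializes the buffer
    public abstract org.apache.spark.sql.Encoder<OUT> outputEncoder();  // how Spark serializes the result
}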
We implement the following UDAF: it counts how many times each street appears in AddressEntity records and sums the corresponding city values (in the sample data, street holds a subject name and city holds a numeric score as a string).
AddressEntity.java
package com.sogo.getimei.entity;

import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;

// Lombok generates the getters/setters that Encoders.bean and the UDAF below rely on
@Setter
@Getter
public class AddressEntity implements Serializable {
    private String city;   // in the sample data: a numeric score string, e.g. "99"
    private String street; // in the sample data: a subject name, e.g. "Chn"
}
PersonAnalizeEntity.java
(Since the data volume is small, we use two maps to record street frequencies and the per-street running sums of city, respectively.)
package com.sogo.getimei.entity;

import lombok.Getter;
import lombok.Setter;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/8
 * @Time: 21:55
 * @des:
 */
@Setter
@Getter
public class PersonAnalizeEntity implements Serializable {

    // records how many times each street appears
    private Map<String, Integer> streetCountMap;

    // records the running sum of city values per street
    private Map<String, Integer> streetSumMap;

    // Encoders.bean requires a public no-arg constructor
    public PersonAnalizeEntity() {
        this.streetCountMap = new HashMap<>();
        this.streetSumMap = new HashMap<>();
    }
}
AddressAnaliseUdaf.java
The UDAF implementation:
package com.sogo.getimei.udf;

import com.sogo.getimei.entity.AddressEntity;
import com.sogo.getimei.entity.PersonAnalizeEntity;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

import java.util.Map;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/9
 * @Time: 14:45
 * @des:
 */
// Extend Aggregator<IN, BUF, OUT>: input AddressEntity, buffer and output PersonAnalizeEntity
public class AddressAnaliseUdaf extends Aggregator<AddressEntity, PersonAnalizeEntity, PersonAnalizeEntity> {

    // Initialize the aggregation buffer (the "zero" value)
    @Override
    public PersonAnalizeEntity zero() {
        return new PersonAnalizeEntity();
    }

    // Fold one input row into the buffer (within a partition)
    @Override
    public PersonAnalizeEntity reduce(PersonAnalizeEntity b, AddressEntity addressEntity) {
        // If the street is already present, increment its count and add to its sum
        if (b.getStreetCountMap().containsKey(addressEntity.getStreet())) {
            b.getStreetCountMap().put(addressEntity.getStreet(),
                    b.getStreetCountMap().get(addressEntity.getStreet()) + 1);
            b.getStreetSumMap().put(addressEntity.getStreet(),
                    b.getStreetSumMap().get(addressEntity.getStreet()) + Integer.valueOf(addressEntity.getCity()));
        } else {
            b.getStreetCountMap().put(addressEntity.getStreet(), 1);
            b.getStreetSumMap().put(addressEntity.getStreet(), Integer.valueOf(addressEntity.getCity()));
        }
        return b;
    }

    // Merge two buffers (across partitions)
    @Override
    public PersonAnalizeEntity merge(PersonAnalizeEntity b1, PersonAnalizeEntity b2) {
        for (Map.Entry<String, Integer> entry : b2.getStreetCountMap().entrySet()) {
            if (b1.getStreetCountMap().containsKey(entry.getKey())) {
                b1.getStreetCountMap().put(entry.getKey(),
                        entry.getValue() + b1.getStreetCountMap().get(entry.getKey()));
                b1.getStreetSumMap().put(entry.getKey(),
                        b2.getStreetSumMap().get(entry.getKey()) + b1.getStreetSumMap().get(entry.getKey()));
            } else {
                b1.getStreetCountMap().put(entry.getKey(), entry.getValue());
                b1.getStreetSumMap().put(entry.getKey(), b2.getStreetSumMap().get(entry.getKey()));
            }
        }
        return b1;
    }

    // Transform the final buffer into the output value (identity here)
    @Override
    public PersonAnalizeEntity finish(PersonAnalizeEntity reduction) {
        return reduction;
    }

    // Encoder (schema) for the intermediate buffer
    @Override
    public Encoder<PersonAnalizeEntity> bufferEncoder() {
        return Encoders.bean(PersonAnalizeEntity.class);
    }

    // Encoder (schema) for the final output
    @Override
    public Encoder<PersonAnalizeEntity> outputEncoder() {
        return Encoders.bean(PersonAnalizeEntity.class);
    }
}
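Because the class is plain Java, the buffer logic can be sanity-checked without a Spark cluster. A hypothetical smoke test (not from the original article):

AddressAnaliseUdaf udaf = new AddressAnaliseUdaf();
PersonAnalizeEntity buf = udaf.zero();
AddressEntity a = new AddressEntity();
a.setStreet("Chn");
a.setCity("99");
buf = udaf.reduce(buf, a);
buf = udaf.reduce(buf, a);
// expected: streetCountMap = {Chn=2}, streetSumMap = {Chn=198}
System.out.println(buf.getStreetCountMap() + " " + buf.getStreetSumMap());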
The type-safe vs. untyped distinction applies to Aggregator. Simply put, type-safe usage targets Dataset<Entity> and gets compile-time type checking; untyped usage targets Dataset<Row>, or is used from Spark SQL. Once these use cases are clear, you can avoid the errors that come from mixing them up.
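In code, the two styles boil down to the following (a sketch; addressDs stands for a Dataset<AddressEntity>, and both paths are shown in full below):

// Type-safe: use the Aggregator as a TypedColumn on a typed Dataset
Dataset<PersonAnalizeEntity> typed = addressDs.select(new AddressAnaliseUdaf().toColumn());

// Untyped: wrap the Aggregator with functions.udaf() and register it for Dataset<Row> / SQL use
spark.udf().register("AddressAnaliseUdaf", udaf(new AddressAnaliseUdaf(), Encoders.bean(AddressEntity.class)));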
Sample data and schema of the Dataset<Row> studyDs:
+----+---+----------------------------------+
|name|age|address |
+----+---+----------------------------------+
|liu1|90 |[[Chn, 99], [Math, 98], [Eng, 97]]|
|liu2|80 |[[Chn, 89], [Math, 88], [Eng, 87]]|
|liu3|70 |[[Chn, 79], [Math, 78], [Eng, 77]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
+----+---+----------------------------------+
root
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- address: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- street: string (nullable = true)
| | |-- city: string (nullable = true)
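For reproducibility, here is one way such a Dataset could be built for local testing (a hypothetical sketch, not from the original article; it assumes an existing SparkSession named spark):

// Build sample rows whose shape matches the schema above
Dataset<Row> studyDs = spark.sql(
        "SELECT name, age, array(" +
        "  named_struct('street', 'Chn',  'city', chn), " +
        "  named_struct('street', 'Math', 'city', math), " +
        "  named_struct('street', 'Eng',  'city', eng)) AS address " +
        "FROM VALUES ('liu1', 90, '99', '98', '97'), ('liu2', 80, '89', '88', '87') " +
        "AS t(name, age, chn, math, eng)");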
The code is explained in the inline comments:
Dataset<PersonAnalizeEntity> aFinal = studyDs
        // The UDAF's input type is AddressEntity, so first flatten the List<AddressEntity>
        .selectExpr("explode(address) as address")
        // Crucial: the struct must be expanded into AddressEntity's individual fields,
        // otherwise the rows cannot be deserialized into the bean
        .selectExpr("address.city as city", "address.street as street")
        // Type-safe: deserialize each row into an AddressEntity
        .as(Encoders.bean(AddressEntity.class))
        // Calling the UDAF's toColumn() runs the aggregation
        .select(new AddressAnaliseUdaf().toColumn());
aFinal.show(10, 0);
aFinal.printSchema();
The test result is as expected:
+-------------------------------+-------------------------------------+
|streetCountMap |streetSumMap |
+-------------------------------+-------------------------------------+
|[Chn -> 5, Math -> 5, Eng -> 5]|[Chn -> 405, Math -> 400, Eng -> 395]|
+-------------------------------+-------------------------------------+
root
|-- streetCountMap: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
|-- streetSumMap: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
When deserializing into the bean, if the address struct's subfields city and street are not split out, the following error occurs:
org.apache.spark.sql.AnalysisException: cannot resolve '`city`' given input columns: [address]
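For contrast, this is the shape of code that triggers that error: the expansion step is missing, so the bean encoder has only the struct column to bind against (a sketch derived from the working snippet above):

Dataset<PersonAnalizeEntity> broken = studyDs
        .selectExpr("explode(address) as address")
        // missing: .selectExpr("address.city as city", "address.street as street")
        .as(Encoders.bean(AddressEntity.class))  // analysis fails: no top-level 'city'/'street' columns
        .select(new AddressAnaliseUdaf().toColumn());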
Registering the UDAF:
// udaf() takes: 1. the UDAF instance, 2. an Encoder for the input type
spark.udf().register("AddressAnaliseUdaf", udaf(new AddressAnaliseUdaf(), Encoders.bean(AddressEntity.class)));
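The registration and the calls below rely on static helpers from org.apache.spark.sql.functions (udaf was added in Spark 3.0):

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udaf;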
First invocation style: via callUDF (works)
Dataset<Row> agg = studyDs
        .selectExpr("name", "explode(address) as address")
        .selectExpr("name", "address.city as city", "address.street as street")
        // Also crucial: pass the input columns sorted lexicographically by field name
        // (the bean encoder's schema puts the fields in alphabetical order, hence city before street)
        .agg(callUDF("AddressAnaliseUdaf", col("city"), col("street")).alias("cal_result"))
        .selectExpr("cal_result.streetCountMap as streetCountMap", "cal_result.streetSumMap as streetSumMap");
agg.show(10, 0);
agg.printSchema();
The output is as expected:
+-------------------------------+-------------------------------------+
|streetCountMap |streetSumMap |
+-------------------------------+-------------------------------------+
|[Chn -> 5, Math -> 5, Eng -> 5]|[Chn -> 405, Math -> 400, Eng -> 395]|
+-------------------------------+-------------------------------------+
root
|-- streetCountMap: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
|-- streetSumMap: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
Second invocation style: from SQL
The demo in [1] uses a flat structure; here the goal is a UDAF over a complex nested structure, which finally got solved.
Attempt 1 (fails)
// note: registerTempTable is deprecated; attempt 2 below uses createOrReplaceTempView instead
studyDs.selectExpr("explode(address) as address")
        .registerTempTable("study");
Dataset<Row> sqlRow = spark.sql("select AddressAnaliseUdaf(address) from study");
The error message:
Caused by: org.apache.spark.sql.AnalysisException: cannot resolve 'AddressAnaliseUdaf(address)' due to data type mismatch: argument 1 requires string type, however, 'study.`address`' is of struct<city:string,street:string> type.; line 1 pos 7
A confusing error at first sight. The UDAF was registered with Encoders.bean(AddressEntity.class), so it expects the bean's individual string fields as arguments, not the struct column itself.
Attempt 2 (works)
studyDs.createOrReplaceTempView("study");
// Again, the UDAF must be passed AddressEntity's individual fields,
// in the order they are defined in AddressEntity (the column names themselves can be changed freely)
Dataset<Row> sqlRow = spark.sql("SELECT AddressAnaliseUdaf(address.city,address.street) FROM (SELECT explode(address) AS address FROM study)");
sqlRow.show(10, 0);
sqlRow.printSchema();
The output is as expected:
+------------------------------------------------------------------------+
|addressanaliseudaf(address.city AS `city`, address.street AS `street`) |
+------------------------------------------------------------------------+
|[[Chn -> 5, Math -> 5, Eng -> 5], [Chn -> 405, Math -> 400, Eng -> 395]]|
+------------------------------------------------------------------------+
root
|-- addressanaliseudaf(address.city AS `city`, address.street AS `street`): struct (nullable = true)
| |-- streetCountMap: map (nullable = true)
| | |-- key: string
| | |-- value: integer (valueContainsNull = true)
| |-- streetSumMap: map (nullable = true)
| | |-- key: string
| | |-- value: integer (valueContainsNull = true)
Testing renamed columns: the field names can indeed be changed freely:
Dataset<Row> sqlRow = spark.sql("SELECT AddressAnaliseUdaf(city1,street1) FROM (SELECT address.city as city1, address.street as street1 FROM (SELECT explode(address) AS address FROM study))");
The output is as expected:
+------------------------------------------------------------------------+
|addressanaliseudaf(city1, street1) |
+------------------------------------------------------------------------+
|[[Chn -> 5, Math -> 5, Eng -> 5], [Chn -> 405, Math -> 400, Eng -> 395]]|
+------------------------------------------------------------------------+
root
|-- addressanaliseudaf(city1, street1): struct (nullable = true)
| |-- streetCountMap: map (nullable = true)
| | |-- key: string
| | |-- value: integer (valueContainsNull = true)
| |-- streetSumMap: map (nullable = true)
| | |-- key: string
| | |-- value: integer (valueContainsNull = true)
To sum up: implementing a Spark Java UDAF only requires extending the Aggregator class and implementing its methods. In the type-safe style, once the rows are deserialized into the entity type, aggregation is done through the UDAF object's toColumn() method. In the untyped style, both callUDF and SQL invocation require attention to the order of the input columns. In both styles, remember to expand the entity struct into its individual fields before passing them in.
[1] Spark 3.0.0 official documentation, User Defined Aggregate Functions: https://spark.apache.org/docs/3.0.0/sql-ref-functions-udf-aggregate.html
[2] spark中自定义UDAF函数实现的两种方式: https://blog.csdn.net/weixin_43861104/article/details/107358874
Original-work statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission. For infringement concerns, contact cloudcommunity@tencent.com for removal.