When processing Parquet-format data with Java Spark, you will sooner or later run into struct and nested struct types, yet an out-of-the-box Spark UDF cannot directly take a List or a Java class (struct) as an input parameter. This article presents a way to pass such complex structures into a Java Spark UDF1. The examples use the following two entity classes:
public class PersonEntity {
private String name;
private Integer age;
private List<AddressEntity> address;
}
public class AddressEntity {
private String street;
private String city;
}
Below, PersonEntity serves as the UDF1 input type and Boolean as the output type to demonstrate complex-structure input. Combined with the complex-structure output technique from article 1 (Spark UDF1 returning a complex structure), a modified PersonEntity is then returned, showing that a UDF1 is fully capable of real processing logic.
Using PersonEntity directly as the UDF1 input type, e.g. UDF1&lt;PersonEntity, Boolean&gt;, fails with errors like the following:
// error when the input is a Java class
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to com.sogo.getimei.entity.PersonEntity
// error when the input is a Java List
scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.List
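For reference, here is a minimal sketch (hypothetical, not from the original code) of the direct approach that triggers the first error; the condition inside call() is just a placeholder:
// Hypothetical: using the Java class directly as the UDF1 input type
UDF1<PersonEntity, Boolean> badPersonUdf = new UDF1<PersonEntity, Boolean>() {
    @Override
    public Boolean call(PersonEntity person) throws Exception {
        // never reached: Spark hands the UDF a GenericRowWithSchema, not a PersonEntity
        return person.getAge() != null && person.getAge() > 80;
    }
};
spark.udf().register("badPersonUdf", badPersonUdf, DataTypes.BooleanType);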
Article 2 mentions converting a Seq into a List and using Row's getAs() method, and article 3 shows concretely how to convert a Scala Seq into a Java List. Building on these, testing confirms that replacing List with Seq and the class (struct) with Row in the UDF signature solves the problem.
The following example filters for users with a city value greater than 80 (admittedly this could be done without a UDF1).
PersonEntity.java (focus on the personFilterUdf member):
package com.sogo.getimei.entity;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import scala.collection.JavaConverters;
import scala.collection.mutable.Seq;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* @Created by IntelliJ IDEA.
* @author: liuzhixuan
* @Date: 2020/8/6
* @Time: 16:05
* @des:
*/
@Setter
@Getter
public class PersonEntity implements Serializable {
private String name;
private Integer age;
private List<AddressEntity> address;
/**
* DataType of PersonEntity
* @return
*/
public static DataType dataType() {
List<StructField> structFieldList = new ArrayList<>(3);
// name
structFieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
// age
structFieldList.add(DataTypes.createStructField("age", DataTypes.IntegerType, true));
// address
DataType arrayDataType = DataType.fromJson(DataTypes.createArrayType(AddressEntity.dataType()).json());
structFieldList.add(DataTypes.createStructField("address", arrayDataType, true));
// final struct
String jsonStr = DataTypes.createStructType(structFieldList).json();
return DataType.fromJson(jsonStr);
}
/**
* DataType of struct<name: string, street: array<string>>
* @return
*/
public static DataType simplyDataType() {
List<StructField> structFieldList = new ArrayList<>(2);
// name
structFieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
// street
DataType arrayDataType = DataType.fromJson(DataTypes.createArrayType(DataTypes.StringType).json());
structFieldList.add(DataTypes.createStructField("street", arrayDataType, true));
// final struct
String jsonStr = DataTypes.createStructType(structFieldList).json();
return DataType.fromJson(jsonStr);
}
/**
 * parse: build a struct (Row) from a raw tab-separated string
*/
public static UDF1<String, Row> personParseUdf = new UDF1<String, Row>() {
@Override
public Row call(String s) throws Exception {
PersonEntity personEntity = PersonEntity.parse(s);
List<Row> rowList = new ArrayList<>();
for (AddressEntity addressEntity : personEntity.getAddress()) {
rowList.add(RowFactory.create(addressEntity.getStreet(), addressEntity.getCity()));
}
return RowFactory.create(personEntity.getName(), personEntity.getAge(), rowList);
}
};
/**
* filter: input struct
* @param row the person struct, received as a Row
* @return true if any city value exceeds 80
*/
public static UDF1<Row, Boolean> personFilterUdf = new UDF1<Row, Boolean>() {
@Override
public Boolean call(Row row) throws Exception {
// use Seq instead of List
// use Row instead of java class
Seq<Row> addressRowSeq = row.getAs("address");
// transform Seq to List
List<Row> addressRowList = JavaConverters.seqAsJavaList(addressRowSeq);
for (Row addressRow : addressRowList) {
String street = addressRow.getAs("street"); // read for illustration; not used by the filter
String city = addressRow.getAs("city");
if (city != null && Integer.parseInt(city) > 80) {
return true;
}
}
return false;
}
};
/**
* filter: input struct, output struct
* @param row the person struct, received as a Row
* @return a simplified struct<name, street>, or null if nothing matches
*/
public static UDF1<Row, Row> personChangeUdf = new UDF1<Row, Row>() {
@Override
public Row call(Row row) throws Exception {
String name = row.getAs("name");
// use Seq instead of List
// use Row instead of java class
Seq<Row> addressRowSeq = row.getAs("address");
// transform Seq to List
List<Row> addressRowList = JavaConverters.seqAsJavaList(addressRowSeq);
// store the street which can match the condition
List<String> resStreetRowList = new ArrayList<>();
for (Row addressRow : addressRowList) {
String street = addressRow.getAs("street");
String city = addressRow.getAs("city");
if (!StringUtils.isEmpty(city) && city.contains("8")) {
resStreetRowList.add(street);
}
}
if (resStreetRowList.isEmpty()) {
return null;
} else {
return RowFactory.create(name, resStreetRowList);
}
}
};
public static PersonEntity parse(String str) {
if (StringUtils.isEmpty(str)) {
return null;
}
String[] fields = str.split("\t", -1);
PersonEntity personEntity = new PersonEntity();
personEntity.setName(fields[0]);
personEntity.setAge(Integer.valueOf(fields[1]));
List<AddressEntity> address = new ArrayList<>();
String[] fieldsAddress = fields[2].split(",", -1);
for (String s : fieldsAddress) {
String[] add = s.split(":", -1);
address.add(new AddressEntity(add[0], add[1]));
}
personEntity.setAddress(address);
return personEntity;
}
}
AddressEntity.java
package com.sogo.getimei.entity;
import lombok.Getter;
import lombok.Setter;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* @Created by IntelliJ IDEA.
* @author: liuzhixuan
* @Date: 2020/8/6
* @Time: 16:08
* @des:
*/
@Setter
@Getter
public class AddressEntity implements Serializable {
private String street;
private String city;
public AddressEntity() {}
public AddressEntity(String street, String city) {
this.street = street;
this.city = city;
}
public static DataType dataType() {
List<StructField> structFieldList = new ArrayList<>(2);
structFieldList.add(DataTypes.createStructField("street", DataTypes.StringType, true));
structFieldList.add(DataTypes.createStructField("city", DataTypes.StringType, true));
String jsonStr = DataTypes.createStructType(structFieldList).json();
// assemble into a struct DataType
return DataType.fromJson(jsonStr);
}
}
The data and schema of studyDs:
+----+---+----------------------------------+
|name|age|address |
+----+---+----------------------------------+
|liu1|90 |[[Chn, 99], [Math, 98], [Eng, 97]]|
|liu2|80 |[[Chn, 89], [Math, 88], [Eng, 87]]|
|liu3|70 |[[Chn, 79], [Math, 78], [Eng, 77]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
+----+---+----------------------------------+
root
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- address: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- street: string (nullable = true)
| | |-- city: string (nullable = true)
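For completeness, studyDs can be produced from raw tab-separated lines with the personParseUdf defined above. A minimal sketch, assuming rawDs is a Dataset<Row> with a single string column named value (e.g. loaded via spark.read().text(...)), where each line looks like name\tage\tstreet1:city1,street2:city2:
// register the parse UDF with the full struct type as its return type
spark.udf().register("personParseUdf", PersonEntity.personParseUdf, PersonEntity.dataType());
Dataset<Row> studyDs = rawDs
        .selectExpr("personParseUdf(value) AS person")
        .selectExpr("person.name AS name", "person.age AS age", "person.address AS address");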
// register the UDF
spark.udf().register("personFilterUdf", PersonEntity.personFilterUdf, DataTypes.BooleanType);
// wrap the three columns into a single struct column with struct()
Dataset<Row> studyFilterDs = studyDs.select(struct(col("name"), col("age"), col("address")).alias("person"))
.filter("personFilterUdf(person)");
studyFilterDs.show(20, 0);
studyFilterDs.printSchema();
Users with a city value greater than 80 are returned correctly:
+----------------------------------------------+
|person |
+----------------------------------------------+
|[liu1, 90, [[Chn, 99], [Math, 98], [Eng, 97]]]|
|[liu2, 80, [[Chn, 89], [Math, 88], [Eng, 87]]]|
+----------------------------------------------+
root
|-- person: struct (nullable = false)
| |-- name: string (nullable = true)
| |-- age: integer (nullable = true)
| |-- address: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- street: string (nullable = true)
| | | |-- city: string (nullable = true)
Next, focus on the personChangeUdf member and the simplyDataType() method in PersonEntity.java.
studyDs is the same as above.
// register the UDF
spark.udf().register("personChangeUdf", PersonEntity.personChangeUdf, PersonEntity.simplyDataType());
// process the data
Dataset<Row> studyFilterDs = studyDs.select(struct(col("name"), col("age"), col("address")).alias("person"))
.selectExpr("personChangeUdf(person) as simplePerson")
.filter(col("simplePerson").isNotNull())
.selectExpr("simplePerson.name as name", "simplePerson.street as street");
studyFilterDs.show(20, 0);
studyFilterDs.printSchema();
The streets whose city value contains an "8" are returned correctly:
+----+----------------+
|name|street |
+----+----------------+
|liu1|[Math] |
|liu2|[Chn, Math, Eng]|
|liu3|[Math] |
|liu4|[Math] |
+----+----------------+
root
|-- name: string (nullable = true)
|-- street: array (nullable = true)
| |-- element: string (containsNull = true)
Error description:
scala.collection.mutable.WrappedArray$ofRef cannot be cast to scala.collection.immutable.Seq
Solution:
This is most likely an import problem. Replace
import scala.collection.immutable.Seq;
with
import scala.collection.mutable.Seq;
import scala.collection.JavaConverters;
// convert to a java.util.List
Seq<String> seqString = ...;
List<String> listString = JavaConverters.seqAsJavaList(seqString);
// convert a Scala Map to a java.util.Map
JavaConverters.mapAsJavaMap(...);
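Putting both conversions together inside a UDF, a self-contained sketch (the column names tags and props are hypothetical):
// assumes the input Row has an array<string> column "tags" and a map<string,string> column "props"
public Boolean call(Row row) throws Exception {
    Seq<String> tags = row.getAs("tags");
    List<String> tagList = JavaConverters.seqAsJavaList(tags);
    scala.collection.Map<String, String> props = row.getAs("props");
    java.util.Map<String, String> propMap = JavaConverters.mapAsJavaMap(props);
    return !tagList.isEmpty() && !propMap.isEmpty();
}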
The key to feeding a complex structure into a UDF1 is handling the type conversions between Scala and Java. In general, replacing List with Seq and the class (struct) with Row in the UDF signature is enough to solve the problem.
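As an aside, Row also exposes index-based getters that return Java collections directly, which sidesteps the manual Seq conversion; a minimal sketch, assuming the field order name, age, address of the person struct shown above:
// inside UDF1<Row, ...>.call(Row row)
String name = row.getString(0);
Integer age = row.isNullAt(1) ? null : row.getInt(1);
List<Row> addressList = row.getList(2); // array<struct> comes back directly as java.util.List<Row>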
1. Spark UDF1 返回复杂结构 (Spark UDF1 returning a complex structure): https://cloud.tencent.com/developer/article/1674399