Spark UDF1: Passing Complex Structures as Input

Original article
Author: mikeLiu
Last modified: 2020-08-10 10:45:42
Column: 技术学习 (Technical Learning)

Preface

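When processing Parquet data with Spark in Java, you will inevitably run into struct columns and nested structs. A Spark UDF, however, cannot directly accept a java.util.List or a Java class (struct) as its input parameter. This article presents one way to pass complex structures into a Java Spark UDF1.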

public class PersonEntity {
    private String name;
    private Integer age;
    private List<AddressEntity> address;
}
public class AddressEntity {
    private String street;
    private String city;
}

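Below, PersonEntity is first used as the UDF1 input and Boolean as the UDF1 output, to show how a UDF1 takes a complex structure. Then, combined with the approach from reference 1 (returning complex structures from a UDF1), a modified PersonEntity object is returned, to show that a UDF1 can handle real processing logic.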

Complex input, primitive output

Using PersonEntity directly as the UDF1 input type, as in UDF1<PersonEntity, Boolean>, produces the following errors:

// Error when the input is declared as a Java class
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to com.sogo.getimei.entity.PersonEntity
// Error when the input is declared as a Java List
scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.List

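Reference 2 suggests converting the Seq to a List and reading fields with Row's getAs() method, and reference 3 shows exactly how to convert a Scala Seq to a Java List. Testing on that basis shows that declaring Seq instead of List, and Row instead of the class (struct), solves the problem.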
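
One way to see why the first error occurs: Java generics are erased at run time, so the caller can hand any object to call(), and the cast to the declared input type only happens inside the UDF object. The sketch below reproduces that mechanism in plain Java, without Spark. Udf1, FakeRow and Person are hypothetical stand-ins (for Spark's UDF1, Row and the PersonEntity class); it illustrates the idea under those assumptions and is not Spark's actual invocation code.

```java
class ErasureDemo {
    /** Hypothetical stand-in for org.apache.spark.sql.api.java.UDF1. */
    interface Udf1<T, R> {
        R call(T input) throws Exception;
    }

    /** Hypothetical stand-in for Spark's Row: just an array of field values. */
    static final class FakeRow {
        final Object[] values;
        FakeRow(Object... values) { this.values = values; }
    }

    /** Plays the role of PersonEntity. */
    static final class Person {
        String name;
    }

    // Declares the entity class as input: compiles, but fails at run time.
    static final Udf1<Person, Boolean> DECLARES_PERSON = new Udf1<Person, Boolean>() {
        @Override
        public Boolean call(Person p) {
            return p.name != null;
        }
    };

    // Declares the row type that the caller actually passes in.
    static final Udf1<FakeRow, Boolean> DECLARES_ROW = new Udf1<FakeRow, Boolean>() {
        @Override
        public Boolean call(FakeRow row) {
            return row.values[0] != null;
        }
    };

    /** The "engine" ignores the declared generic type and always passes a FakeRow. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    static String invoke(Udf1 udf) {
        try {
            return String.valueOf(udf.call(new FakeRow("liu1", 90)));
        } catch (ClassCastException e) {
            return "ClassCastException";
        } catch (Exception e) {
            return "other failure: " + e;
        }
    }

    public static void main(String[] args) {
        System.out.println(invoke(DECLARES_PERSON)); // the hidden cast to Person fails
        System.out.println(invoke(DECLARES_ROW));    // the declared type matches the value
    }
}
```

Running main prints ClassCastException for the first UDF and true for the second, which mirrors why declaring Row as the input type works.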

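The example below filters for users that have an address with city > 80 (this could also be done without a UDF1, of course).

Implementing UDF1<Row, Boolean>

PersonEntity.java (only the personFilterUdf member matters here)
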
package com.sogo.getimei.entity;

import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import scala.collection.JavaConverters;
import scala.collection.mutable.Seq;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/6
 * @Time: 16:05
 * @des:
 */
@Setter
@Getter
public class PersonEntity implements Serializable {
    private String name;
    private Integer age;
    private List<AddressEntity> address;

    /**
     * DataType of PersonEntity
     * @return
     */
    public static DataType dataType() {
        List<StructField> structFieldList = new ArrayList<>(3);
        // name
        structFieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        // age
        structFieldList.add(DataTypes.createStructField("age", DataTypes.IntegerType, true));
        // address
        DataType arrayDataType = DataType.fromJson(DataTypes.createArrayType(AddressEntity.dataType()).json());
        structFieldList.add(DataTypes.createStructField("address", arrayDataType, true));
        // final struct
        String jsonStr = DataTypes.createStructType(structFieldList).json();
        return DataType.fromJson(jsonStr);
    }

    /**
     * DataType of struct<name: string, street: array<string>>
     * @return
     */
    public static DataType simplyDataType() {
        List<StructField> structFieldList = new ArrayList<>(2);
        // name
        structFieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        // street (array of the matching street names)
        DataType arrayDataType = DataType.fromJson(DataTypes.createArrayType(DataTypes.StringType).json());
        structFieldList.add(DataTypes.createStructField("street", arrayDataType, true));
        // final struct
        String jsonStr = DataTypes.createStructType(structFieldList).json();
        return DataType.fromJson(jsonStr);
    }

    /***
     * parse: output struct
     */
    public static UDF1<String, Row> personParseUdf = new UDF1<String, Row>() {
        @Override
        public Row call(String s) throws Exception {
            PersonEntity personEntity = PersonEntity.parse(s);
            List<Row> rowList = new ArrayList<>();
            for (AddressEntity addressEntity : personEntity.getAddress()) {
                rowList.add(RowFactory.create(addressEntity.getStreet(), addressEntity.getCity()));
            }
            return RowFactory.create(personEntity.getName(), personEntity.getAge(), rowList);
        }
    };

    /**
     * Filter UDF: takes the person struct as a Row and returns true
     * if any address has a numeric city value greater than 80.
     */
    public static UDF1<Row, Boolean> personFilterUdf = new UDF1<Row, Boolean>() {
        @Override
        public Boolean call(Row row) throws Exception {
            // Declare Seq instead of List, and Row instead of the Java class
            Seq<Row> addressRowSeq = row.getAs("address");
            if (addressRowSeq == null) {
                return false;
            }
            // Convert the Scala Seq to a Java List
            List<Row> addressRowList = JavaConverters.seqAsJavaList(addressRowSeq);
            for (Row addressRow : addressRowList) {
                String city = addressRow.getAs("city");
                // Skip empty or non-numeric values instead of failing the whole job
                if (StringUtils.isNumeric(city) && Integer.parseInt(city) > 80) {
                    return true;
                }
            }
            return false;
        }
    };

    /**
     * Input struct, output struct: keeps the streets whose city contains "8";
     * returns null when no street matches.
     */
    public static UDF1<Row, Row> personChangeUdf = new UDF1<Row, Row>() {
        @Override
        public Row call(Row row) throws Exception {
            String name = row.getAs("name");
            // use Seq instead of List
            // use Row instead of java class
            Seq<Row> addressRowSeq = row.getAs("address");
            // transform Seq to List
            List<Row> addressRowList = JavaConverters.seqAsJavaList(addressRowSeq);
            // store the street which can match the condition
            List<String> resStreetRowList = new ArrayList<>();
            for (Row addressRow : addressRowList) {
                String street = addressRow.getAs("street");
                String city = addressRow.getAs("city");
                if (!StringUtils.isEmpty(city) && city.contains("8")) {
                    resStreetRowList.add(street);
                }
            }
            if (resStreetRowList.isEmpty()) {
                return null;
            } else {
                return RowFactory.create(name, resStreetRowList);
            }
        }
    };


    public static PersonEntity parse(String str) {
        if (StringUtils.isEmpty(str)) {
            return null;
        }
        String[] fields = str.split("\t", -1);
        PersonEntity personEntity = new PersonEntity();
        personEntity.setName(fields[0]);
        personEntity.setAge(Integer.valueOf(fields[1]));
        List<AddressEntity> address = new ArrayList<>();
        String[] fieldsAddress = fields[2].split(",", -1);
        for (String s : fieldsAddress) {
            String[] add = s.split(":", -1);
            address.add(new AddressEntity(add[0], add[1]));
        }
        personEntity.setAddress(address);
        return personEntity;
    }
}

AddressEntity.java

package com.sogo.getimei.entity;

import lombok.Getter;
import lombok.Setter;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

/**
 * @Created by IntelliJ IDEA.
 * @author: liuzhixuan
 * @Date: 2020/8/6
 * @Time: 16:08
 * @des:
 */
@Setter
@Getter
public class AddressEntity implements Serializable {
    private String street;
    private String city;

    public AddressEntity() {}

    public AddressEntity(String street, String city) {
        this.street = street;
        this.city = city;
    }

    public static DataType dataType() {
        List<StructField> structFieldList = new ArrayList<>(2);
        structFieldList.add(DataTypes.createStructField("street", DataTypes.StringType, true));
        structFieldList.add(DataTypes.createStructField("city", DataTypes.StringType, true));
        String jsonStr = DataTypes.createStructType(structFieldList).json();
        // Assemble the fields into a struct DataType
        return DataType.fromJson(jsonStr);
    }
}

Test

Test data

The data and schema of studyDs are as follows:

+----+---+----------------------------------+
|name|age|address                           |
+----+---+----------------------------------+
|liu1|90 |[[Chn, 99], [Math, 98], [Eng, 97]]|
|liu2|80 |[[Chn, 89], [Math, 88], [Eng, 87]]|
|liu3|70 |[[Chn, 79], [Math, 78], [Eng, 77]]|
|liu4|60 |[[Chn, 69], [Math, 68], [Eng, 67]]|
+----+---+----------------------------------+

root
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- address: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- street: string (nullable = true)
 |    |    |-- city: string (nullable = true)
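
How studyDs is built is not shown here. Judging from PersonEntity.parse(), each raw record appears to be one tab-separated line of name, age and an address list, where the addresses are comma-separated street:city pairs. The sketch below replays only that splitting logic in plain Java; ParseFormatDemo and parseAddresses are hypothetical names, and the sample line is an assumption reconstructed from the first row of the table above.

```java
import java.util.ArrayList;
import java.util.List;

class ParseFormatDemo {
    /** Splits the third field ("street:city,street:city,...") into [street, city] pairs. */
    static List<String[]> parseAddresses(String line) {
        String[] fields = line.split("\t", -1);
        List<String[]> addresses = new ArrayList<>();
        for (String pair : fields[2].split(",", -1)) {
            addresses.add(pair.split(":", -1));
        }
        return addresses;
    }

    public static void main(String[] args) {
        // Assumed raw form of the first row of studyDs.
        String line = "liu1\t90\tChn:99,Math:98,Eng:97";
        for (String[] address : parseAddresses(line)) {
            System.out.println("street=" + address[0] + ", city=" + address[1]);
        }
    }
}
```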

Test program

// Register the UDF
spark.udf().register("personFilterUdf", PersonEntity.personFilterUdf, DataTypes.BooleanType);
// Wrap the columns into a single struct column with struct()
Dataset<Row> studyFilterDs = studyDs.select(struct(col("name"), col("age"), col("address")).alias("person"))
        .filter("personFilterUdf(person)");
studyFilterDs.show(20, 0);
studyFilterDs.printSchema();

Test result

The users with city > 80 are returned correctly:

+----------------------------------------------+
|person                                        |
+----------------------------------------------+
|[liu1, 90, [[Chn, 99], [Math, 98], [Eng, 97]]]|
|[liu2, 80, [[Chn, 89], [Math, 88], [Eng, 87]]]|
+----------------------------------------------+

root
 |-- person: struct (nullable = false)
 |    |-- name: string (nullable = true)
 |    |-- age: integer (nullable = true)
 |    |-- address: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- street: string (nullable = true)
 |    |    |    |-- city: string (nullable = true)
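
As a cross-check that does not need a Spark session, the rule in personFilterUdf can be replayed in plain Java on the city values from the test data. FilterReplay and its method names are hypothetical, and the map stands in for the name and address columns; this is a sketch of the predicate only, not of the Spark job.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class FilterReplay {
    /** Same rule as personFilterUdf: true if any city parses to a number greater than 80. */
    static boolean anyCityAbove80(List<String> cities) {
        for (String city : cities) {
            if (Integer.parseInt(city) > 80) {
                return true;
            }
        }
        return false;
    }

    /** Test data from studyDs, reduced to name -> city values. */
    static Map<String, List<String>> testData() {
        Map<String, List<String>> data = new LinkedHashMap<>();
        data.put("liu1", Arrays.asList("99", "98", "97"));
        data.put("liu2", Arrays.asList("89", "88", "87"));
        data.put("liu3", Arrays.asList("79", "78", "77"));
        data.put("liu4", Arrays.asList("69", "68", "67"));
        return data;
    }

    /** Names of the rows that pass the filter, in input order. */
    static List<String> keptNames() {
        List<String> kept = new ArrayList<>();
        for (Map.Entry<String, List<String>> entry : testData().entrySet()) {
            if (anyCityAbove80(entry.getValue())) {
                kept.add(entry.getKey());
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        System.out.println(keptNames()); // prints [liu1, liu2]
    }
}
```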

Complex input, complex output

Implementing UDF1<Row, Row>

See the personChangeUdf member and the simplyDataType() method in PersonEntity.java.

Test

Test data

studyDs is the same as above.

Test program

// Register the UDF
spark.udf().register("personChangeUdf", PersonEntity.personChangeUdf, PersonEntity.simplyDataType());
// Process the data
Dataset<Row> studyFilterDs = studyDs.select(struct(col("name"), col("age"), col("address")).alias("person"))
        .selectExpr("personChangeUdf(person) as simplePerson")
        .filter(col("simplePerson").isNotNull())
        .selectExpr("simplePerson.name as name", "simplePerson.street as street");
studyFilterDs.show(20, 0);
studyFilterDs.printSchema();

Test result

The streets whose city contains "8" are returned correctly:

+----+----------------+
|name|street          |
+----+----------------+
|liu1|[Math]          |
|liu2|[Chn, Math, Eng]|
|liu3|[Math]          |
|liu4|[Math]          |
+----+----------------+

root
 |-- name: string (nullable = true)
 |-- street: array (nullable = true)
 |    |-- element: string (containsNull = true)
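
The null branch of personChangeUdf is not exercised by this test data, because every person has at least one city containing "8". A minimal plain-Java sketch of the same rule shows both branches. ChangeReplay and streetsOrNull are hypothetical names, the {street, city} string pairs stand in for the address rows, and the second input is made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;

class ChangeReplay {
    /** Same rule as personChangeUdf: collect streets whose city contains "8", or null if none. */
    static List<String> streetsOrNull(String[][] addresses) {
        List<String> streets = new ArrayList<>();
        for (String[] address : addresses) {
            String street = address[0];
            String city = address[1];
            if (city != null && !city.isEmpty() && city.contains("8")) {
                streets.add(street);
            }
        }
        return streets.isEmpty() ? null : streets;
    }

    public static void main(String[] args) {
        String[][] liu1 = {{"Chn", "99"}, {"Math", "98"}, {"Eng", "97"}};
        // Made-up row with no "8" in any city: the UDF would return null for it.
        String[][] noMatch = {{"Chn", "59"}, {"Math", "57"}};
        System.out.println(streetsOrNull(liu1));    // prints [Math]
        System.out.println(streetsOrNull(noMatch)); // prints null
    }
}
```

In the Spark pipeline, rows for which the UDF returns null are then dropped by filter(col("simplePerson").isNotNull()).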

Key points

Array-to-Seq cast error

Error description

scala.collection.mutable.WrappedArray$ofRef cannot be cast to scala.collection.immutable.Seq

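Solution

This is most likely an import problem: as the class name in the error shows, the value passed in is a scala.collection.mutable.WrappedArray, which is not an immutable Seq. Replace
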
import scala.collection.immutable.Seq;

with

import scala.collection.mutable.Seq;

Converting a Scala Seq to a Java List

import scala.collection.JavaConverters;
// Convert to a Java List
Seq<String> seqString =  ...;
List<String> listString = JavaConverters.seqAsJavaList(seqString);
// Convert to a Java Map
JavaConverters.mapAsJavaMap(map<...,...>);

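Summary

The key to passing complex structures into a UDF1 is handling the conversion between Scala and Java types. In most cases, declaring Seq instead of List, and Row instead of the class (struct), solves the problem.

References

1 Spark UDF1: Returning Complex Structures (Spark UDF1 返回复杂结构) https://cloud.tencent.com/developer/article/1674399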

2 scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.ArrayList https://stackoverflow.com/questions/40764957/spark-java-lang-classcastexception-scala-collection-mutable-wrappedarrayofref

3 Convert from scala.collection.Seq<String> to java.util.List<String> in Java code https://stackoverflow.com/questions/17737631/convert-from-scala-collection-seqstring-to-java-util-liststring-in-java-code

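Original content notice: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

In case of infringement, contact cloudcommunity@tencent.com for removal.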
