我想出了如何获得一个Spark来提前返回已知的固定数量的列。但是,Spark如何返回每个行可能不同的任意数量的列呢?
我在Java17中使用Spark3.3.0,假设我有一个包含100万人的DataFrame。对于每个人,我想查找每年的工资(例如从一个数据库),但每个人可能有不同年份的薪水。如果我知道有2020年和2021年的时间,我会这样做:
StructType salarySchema = createStructType(List.of(createStructField("salary2020",
createDecimalType(12, 2), true), createStructField("salary2021",
createDecimalType(12, 2), true)));
UserDefinedFunction lookupSalariesForId = udf((String id) -> {
// TODO look up salaries
return RowFactory.create(salary2020, salary2021);
}, salarySchema).asNondeterministic();
df = df.withColumn("salaries", lookupSalariesForId.apply(col("id")))
.select("*", "salaries.*");
这是Spark的迂回方式,可以将多个值从UDF加载到单个列中,然后将它们拆分为单独的列。
那么,如果一个人只有2003年和2004年的工资,而另一个人的薪水是2007年、2008年和2009年的,那该怎么办?我想为第一个人创建salary2003
和salary2004
列,然后为第二个人创建salary2007
、salary2008
、salary2009
列。我要怎么用UDF来实现呢?(我知道如何动态创建一个数组,以便通过RowFactory.create()
传回。问题是与UDF模式相关的模式是在UDF逻辑之外定义的。)
还是有什么更好的方法来解决火花呢?我是否应该为每个可能的薪资年创建一个单独的查找DataFrame,只包含个人in和一列,然后以某种方式加入它们,就像我们在关系数据库世界中所做的那样?但是,一个单独的DataFrame会给我什么好处,难道我不会回到原点去构建它吗?当然,我可以在Java中手动构建它,但是我无法从Spark引擎、并行执行器等获得好处。
简而言之,根据每个行的标识符(),为现有的DataFrame中的每一行动态添加任意数量的列的最佳方式是什么?
发布于 2022-09-10 12:29:23
您可以从UDF返回一个映射:
MapType salarySchema2 = createMapType(StringType, createDecimalType(12, 2));
UserDefinedFunction lookupSalariesForId2 = udf((String id) -> {
//the Java map containg the result of the UDF
Map<String, BigDecimal> result = new HashMap<String, BigDecimal>();
//Generate some random test data
Random r = new Random();
int years = r.nextInt(10) + 1; // max 10 years
int startYear = r.nextInt(5) + 2010;
for (int i = 0; i < years; i++) {
result.put("salary" + (startYear + i), new BigDecimal(r.nextDouble() * 1000));
}
return result;
}, salarySchema2).asNondeterministic();
df = df.withColumn("salaries2", lookupSalariesForId2.apply(col("id"))).cache();
df.show(false);
输出:
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|id |salaries2 |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|1 |{salary2014 -> 333.74} |
|2 |{salary2010 -> 841.83, salary2011 -> 764.24, salary2012 -> 703.35, salary2013 -> 727.06, salary2014 -> 314.52} |
|3 |{salary2012 -> 770.90, salary2013 -> 790.92} |
|4 |{salary2011 -> 696.24, salary2012 -> 420.56, salary2013 -> 566.10, salary2014 -> 160.99} |
|5 |{salary2011 -> 60.59, salary2012 -> 313.57, salary2013 -> 770.82, salary2014 -> 641.90, salary2015 -> 776.13, salary2016 -> 145.28, salary2017 -> 216.02}|
|6 |{salary2011 -> 842.02, salary2012 -> 565.32} |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
第二行中出现cache
的原因是第二部分:使用一些sql函数可以获得映射列中所有键的(sql函数)集合。然后,可以使用该集合为每年创建一个列:
Collection<String> years = JavaConverters.asJavaCollection((WrappedArray<String>)
df.withColumn("years", functions.map_keys(col("salaries2")))
.agg(functions.array_sort(
functions.array_distinct(
functions.flatten(
functions.collect_set(col("years"))))))
.first().get(0));
List<Column> salaries2 = years.stream().map((year) ->
col("salaries2").getItem(year).alias(year)).collect(Collectors.toList());
salaries2.add(0, col("id"));
df.select(salaries2.toArray(new Column[0])).show();
输出:
+---+----------+----------+----------+----------+----------+----------+----------+----------+
| id|salary2010|salary2011|salary2012|salary2013|salary2014|salary2015|salary2016|salary2017|
+---+----------+----------+----------+----------+----------+----------+----------+----------+
| 1| null| null| null| null| 333.74| null| null| null|
| 2| 841.83| 764.24| 703.35| 727.06| 314.52| null| null| null|
| 3| null| null| 770.90| 790.92| null| null| null| null|
| 4| null| 696.24| 420.56| 566.10| 160.99| null| null| null|
| 5| null| 60.59| 313.57| 770.82| 641.90| 776.13| 145.28| 216.02|
| 6| null| 842.02| 565.32| null| null| null| null| null|
+---+----------+----------+----------+----------+----------+----------+----------+----------+
收集所有地图中的所有年份可能需要在大型数据集上花费一些时间,因为Spark必须先处理所有UDF调用,然后收集映射键。
https://stackoverflow.com/questions/73672625
复制