首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >一个数组列与另一个(布尔)数组列的子集

一个数组列与另一个(布尔)数组列的子集
EN

Stack Overflow用户
提问于 2019-04-23 00:06:25
回答 2查看 684关注 0票数 3

我有一个这样的数据帧(在Pyspark 2.3.1中):

代码语言:javascript
复制
from pyspark.sql import Row

my_data = spark.createDataFrame([
  Row(a=[9, 3, 4], b=['a', 'b', 'c'], mask=[True, False, False]),
  Row(a=[7, 2, 6, 4], b=['w', 'x', 'y', 'z'], mask=[True, False, True, False])
])
my_data.show(truncate=False)
#+------------+------------+--------------------------+
#|a           |b           |mask                      |
#+------------+------------+--------------------------+
#|[9, 3, 4]   |[a, b, c]   |[true, false, false]      |
#|[7, 2, 6, 4]|[w, x, y, z]|[true, false, true, false]|
#+------------+------------+--------------------------+

现在我想使用mask列来设置ab列的子集:

代码语言:javascript
复制
my_desired_output = spark.createDataFrame([
  Row(a=[9], b=['a']),
  Row(a=[7, 6], b=['w', 'y'])
])
my_desired_output.show(truncate=False)
#+------+------+
#|a     |b     |
#+------+------+
#|[9]   |[a]   |
#|[7, 6]|[w, y]|
#+------+------+

实现这一点的“惯用”方法是什么?我目前的解决方案涉及到底层RDD上的map-ing和Numpy的子集,这似乎并不优雅:

代码语言:javascript
复制
import numpy as np

def subset_with_mask(row):
    mask = np.asarray(row.mask)
    a_masked = np.asarray(row.a)[mask].tolist()
    b_masked = np.asarray(row.b)[mask].tolist()
    return Row(a=a_masked, b=b_masked)

my_desired_output = spark.createDataFrame(my_data.rdd.map(subset_with_mask))

这是最好的方法吗,或者有没有更好的(不那么冗长和/或更有效)我可以使用Spark SQL工具做的事情?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-04-23 00:35:07

一种选择是使用UDF,您可以根据数组中的数据类型选择性地对其进行专门化:

代码语言:javascript
复制
import numpy as np
import pyspark.sql.functions as F
import pyspark.sql.types as T

def _mask_list(lst, mask):
    return np.asarray(lst)[mask].tolist()

mask_array_int = F.udf(_mask_list, T.ArrayType(T.IntegerType()))
mask_array_str = F.udf(_mask_list, T.ArrayType(T.StringType()))

my_desired_output = my_data
my_desired_output = my_desired_output.withColumn(
    'a', mask_array_int(F.col('a'), F.col('mask'))
)
my_desired_output = my_desired_output.withColumn(
    'b', mask_array_str(F.col('b'), F.col('mask'))
)
票数 2
EN

Stack Overflow用户

发布于 2019-04-23 03:19:37

在前面的答案中提到的UDF可能是在Spark 2.4中添加数组函数之前的方法。为了完整起见,这里有一个2.4版本之前的“纯SQL”实现。

代码语言:javascript
复制
from pyspark.sql.functions import *

df = my_data.withColumn("row", monotonically_increasing_id())

df1 = df.select("row", posexplode("a").alias("pos", "a"))
df2 = df.select("row", posexplode("b").alias("pos", "b"))
df3 = df.select("row", posexplode("mask").alias("pos", "mask"))

df1\
    .join(df2, ["row", "pos"])\
    .join(df3, ["row", "pos"])\
    .filter("mask")\
    .groupBy("row")\
    .agg(collect_list("a").alias("a"), collect_list("b").alias("b"))\
    .select("a", "b")\
    .show()

输出:

代码语言:javascript
复制
+------+------+
|     a|     b|
+------+------+
|[7, 6]|[w, y]|
|   [9]|   [a]|
+------+------+
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55797337

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档