首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在GraphFrame上聚合AggregateMessages时,如何保留所有元素?

在GraphFrame上聚合AggregateMessages时,如何保留所有元素?
EN

Stack Overflow用户
提问于 2018-04-07 00:40:10
回答 2查看 202关注 0票数 1

假设我有以下图表:

代码语言:javascript
运行
复制
scala> v.show()
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|           null|
|BBB|           null|
|QQQ|           null|
|DDD|           null|
|FFF|           null|
|EEE|           null|
|AAA|           null|
|GGG|           null|
+---+---------------+


scala> e.show()
+---+---+---+
| iD|src|dst|
+---+---+---+
|  1|CCC|AAA| 
|  2|CCC|BBB| 
...
+---+---+---+

我想运行一个聚合,以获取从目标顶点发送到源顶点的所有消息(而不仅仅是总和、第一个、最后一个等)。因此,我想要运行的命令类似于:

代码语言:javascript
运行
复制
g.aggregateMessages.sendToSrc(AM.edge("id")).agg(all(AM.msg).as("downstreamEdges")).show()

除了函数all不存在(据我所知不存在)。输出将类似于:

代码语言:javascript
运行
复制
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|         [1, 2]|
... 
+---+---------------+

我可以将上面的函数与firstlast一起使用,而不是(不存在的) all,但他们只会给我

代码语言:javascript
运行
复制
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              1|
... 
+---+---------------+

代码语言:javascript
运行
复制
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              2|
... 
+---+---------------+

分别使用。我怎么能保存所有的条目呢?(可能有很多,不只是1和2,还有1,2,23,45,等等)。谢谢。

EN

回答 2

Stack Overflow用户

发布于 2018-04-07 10:23:31

我对this answer进行了修改,使其具有以下特性:

代码语言:javascript
运行
复制
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.graphframes.lib.AggregateMessages

class KeepAllString extends UserDefinedAggregateFunction {
  private val AM = AggregateMessages

  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", StringType) :: Nil)

  // This is the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true), nullable = true) :: Nil
  )

  // This is the output type of your aggregatation function.
  override def dataType: DataType = ArrayType(StringType,true)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Seq[String]()


  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getAs[Seq[String]](0) ++ Seq(input.getAs[String](0))

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Seq[String]](0) ++ buffer2.getAs[Seq[String]](0)

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = buffer.getAs[Seq[String]](0)
}

我上面的all方法就是:val all = new KeepAllString()

但是如何让它成为通用的,这样对于BigDecimal,Timestamp,等等,我可以做一些类似的事情:

val allTimestamp = new KeepAll[Timestamp]()

票数 0
EN

Stack Overflow用户

发布于 2019-11-26 00:56:51

我使用聚合函数collect_set()解决了类似的问题

代码语言:javascript
运行
复制
 agg = gx.aggregateMessages(
            f.collect_set(AM.msg).alias("aggMess"),
            sendToSrc=AM.edge("id")
            sendToDst=None)

另一个(具有重复项)将是collect_list()

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49697450

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档