
Spark SQL Source Code Study Series 01: ParseTree

Original article
Author: 百万大虾
Last modified: 2022-11-09 11:52:57
Published in the column: Spark SQL源码研读 (Spark SQL Source Code Study)

Antlr

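Concepts

ANTLR stands for ANother Tool for Language Recognition.

It is a powerful parser generator that can be used to read, process, execute, or translate structured text or binary files.

A language application built with ANTLR processes its input in two phases.

Phase 1, lexical analysis: the input text is converted into tokens. A token carries at least two pieces of information: its type and the text it matched.

Phase 2, syntax analysis: the parser recognizes sentence structure in the incoming tokens. The parser generated by ANTLR builds a parse tree, which records how the parser recognized the structure of the input and which components make up that structure.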
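The first phase can be illustrated with a minimal sketch in plain Scala. It does not use ANTLR: `TinyLexer`, `Token`, and the token type names are hypothetical and exist only to show that each token pairs a type with the text it matched.

```scala
import scala.util.matching.Regex

// Toy lexer (not ANTLR): shows that a token is a type plus the matched text.
object TinyLexer {
  final case class Token(tokenType: String, text: String)

  // Hypothetical token types, tried in this order at every input position.
  private val rules: Seq[(String, Regex)] = Seq(
    "WS"     -> """\s+""".r,
    "NUMBER" -> """\d+""".r,
    "ID"     -> """[A-Za-z_][A-Za-z0-9_]*""".r,
    "OP"     -> """[+\-*/=()]""".r
  )

  def tokenize(input: String): List[Token] = {
    val tokens = List.newBuilder[Token]
    var pos = 0
    while (pos < input.length) {
      val rest = input.substring(pos)
      // Keep the first rule (in declaration order) that matches a prefix of the rest.
      val hit = rules.flatMap { case (tpe, re) =>
        re.findPrefixOf(rest).map(Token(tpe, _))
      }.headOption
      hit match {
        case Some(tok) =>
          if (tok.tokenType != "WS") tokens += tok // whitespace is dropped
          pos += tok.text.length
        case None =>
          throw new IllegalArgumentException(
            s"Unexpected character '${input.charAt(pos)}' at index $pos")
      }
    }
    tokens.result()
  }
}
```

Calling `TinyLexer.tokenize("a = 12 + b")` yields five tokens, each with a type such as `ID` or `NUMBER` and its text. A real ANTLR lexer is generated from a grammar file instead of being written by hand.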

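ANTLR generates a lexer and a parser automatically. ANTLR 3 could also generate a tree parser; ANTLR 4 dropped tree grammars, and that role is covered by the listener and visitor mechanisms described below.

Lexer: analyzes an otherwise meaningless character stream and turns it into discrete tokens for the parser to consume.

Parser: organizes the tokens it receives into the structures that the grammar rules allow.

Tree parser (ANTLR 3): traverses the abstract syntax tree produced by parsing and performs operations on it.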

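Traversal modes

ANTLR 4 provides two ways to traverse a parse tree, the Listener mode and the Visitor mode:

  • In Listener mode the callbacks are invoked automatically by the walker object that ANTLR provides. In Visitor mode you must traverse the children yourself through explicit visit calls; if you forget to call visit on a node's children, that subtree is never visited.
  • Listener callbacks cannot return a value, whereas Visitor methods can return any custom type. A Listener therefore has to keep intermediate values in variables of its own, while a Visitor can return the computed value directly.
  • Listener mode is event-driven: an event fires and you react to it. Entering a stat node triggers enterStat(), and leaving it triggers exitStat(). ANTLR's built-in tree walker fires this sequence of callbacks, such as enterStat and exitStat, on the Listener.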
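The contrast between the two modes can be sketched without ANTLR. Everything in the block below (`TraversalDemo`, `Node`, `walk`, `EvalListener`) is a hypothetical stand-in and not ANTLR's API; it only mirrors the three points above: a walker drives the listener, the listener stores intermediate values itself, and the visitor returns values and recurses explicitly.

```scala
import scala.collection.mutable

object TraversalDemo {
  sealed trait Node
  final case class Num(value: Int) extends Node
  final case class Add(left: Node, right: Node) extends Node

  // Listener style: callbacks return Unit and are fired by a walker.
  trait Listener {
    def enterAdd(n: Add): Unit = ()
    def exitAdd(n: Add): Unit = ()
    def visitNum(n: Num): Unit = ()
  }

  // Plays the role of a built-in tree walker: it reaches every child automatically.
  def walk(listener: Listener, node: Node): Unit = node match {
    case n: Num => listener.visitNum(n)
    case a: Add =>
      listener.enterAdd(a)
      walk(listener, a.left)
      walk(listener, a.right)
      listener.exitAdd(a)
  }

  // A listener cannot return values, so intermediate results live on a stack.
  final class EvalListener extends Listener {
    private val stack = mutable.ArrayBuffer.empty[Int]
    override def visitNum(n: Num): Unit = stack += n.value
    override def exitAdd(n: Add): Unit = {
      val right = stack.remove(stack.length - 1)
      val left = stack.remove(stack.length - 1)
      val sum = left + right
      stack += sum
    }
    def result: Int = stack.last
  }

  def evalWithListener(node: Node): Int = {
    val listener = new EvalListener
    walk(listener, node)
    listener.result
  }

  // Visitor style: each call returns a value and must recurse into children itself.
  def visit(node: Node): Int = node match {
    case Num(v) => v
    case Add(l, r) => visit(l) + visit(r) // dropping visit(r) would skip that subtree
  }
}
```

Both `evalWithListener` and `visit` compute the same sum for a given tree; the difference lies in who drives the traversal and where the intermediate values are kept.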

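Summary

  1. The parser returns a tree of context objects: ParseTree tree = parser.stat();
  2. Calling visitor.visit(tree) invokes the context's accept method, here StatContext.accept;
  3. The context then calls the visitor's concrete method, such as visitAddSub. When implementing a visitor method, remember to keep visiting any child contexts.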
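The three steps above form a double dispatch, which can be sketched as follows. The classes are hand-written stand-ins, assuming a grammar with an integer rule and an add/subtract rule; `DispatchDemo`, `IntContext`, and `EvalVisitor` are hypothetical names, and real context and visitor classes would be generated by ANTLR.

```scala
object DispatchDemo {
  trait Visitor[T] {
    // Step 2: visit delegates to the context's accept method.
    def visit(tree: ParseTree): T = tree.accept(this)
    def visitAddSub(ctx: AddSubContext): T
    def visitInt(ctx: IntContext): T
  }

  trait ParseTree {
    def accept[T](visitor: Visitor[T]): T
  }

  final class IntContext(val text: String) extends ParseTree {
    override def accept[T](visitor: Visitor[T]): T = visitor.visitInt(this)
  }

  final class AddSubContext(val left: ParseTree, val op: Char, val right: ParseTree)
      extends ParseTree {
    // Step 3: the context calls back into the matching visitor method.
    override def accept[T](visitor: Visitor[T]): T = visitor.visitAddSub(this)
  }

  object EvalVisitor extends Visitor[Int] {
    override def visitInt(ctx: IntContext): Int = ctx.text.toInt

    override def visitAddSub(ctx: AddSubContext): Int = {
      // Child contexts are not visited automatically: keep descending explicitly.
      val l = visit(ctx.left)
      val r = visit(ctx.right)
      if (ctx.op == '+') l + r else l - r
    }
  }
}
```

In this sketch step 1, obtaining the tree from a parser, is replaced by building the context objects by hand.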

Note: to learn the ANTLR grammar syntax, see the book 《ANTLR权威指南》 (The Definitive ANTLR 4 Reference).

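SQL parsing

Spark SQL defines its SQL grammar with ANTLR 4, uses the generated lexer and parser to perform lexical and syntax analysis, and finally converts the SQL text into an abstract syntax tree.

.g4 files

The grammar files are located at: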

src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4

src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

SqlBaseLexer.g4 is the lexer grammar and SqlBaseParser.g4 is the parser grammar; Spark SQL parses SQL using these two files.

Call chain

SparkSession.sql() --> AbstractSqlParser.parsePlan() --> AbstractSqlParser.parse()

/** Creates LogicalPlan for a given SQL string. */
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
  val ctx = parser.singleStatement()
  withOrigin(ctx, Some(sqlText)) {
    astBuilder.visitSingleStatement(ctx) match {
      case plan: LogicalPlan => plan
      case _ =>
        val position = Origin(None, None)
        throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, position)
    }
  }
}

protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
  logDebug(s"Parsing command: $command")

  val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
  lexer.removeErrorListeners()
  lexer.addErrorListener(ParseErrorListener)

  val tokenStream = new CommonTokenStream(lexer)
  val parser = new SqlBaseParser(tokenStream)
  parser.addParseListener(PostProcessor)
  parser.addParseListener(UnclosedCommentProcessor(command, tokenStream))
  parser.removeErrorListeners()
  parser.addErrorListener(ParseErrorListener)
  parser.setErrorHandler(new SparkParserErrorStrategy())
  parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
  parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
  parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords

  try {
    try {
      // first, try parsing with potentially faster SLL mode
      parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
      // toResult is the closure supplied by the caller, for example the braces body of parsePlan;
      // apply it to the parser and return its result
      toResult(parser)
    }
    catch {
      case e: ParseCancellationException =>
        // if we fail, parse with LL mode
        tokenStream.seek(0) // rewind input stream
        parser.reset()

        // Try Again.
        parser.getInterpreter.setPredictionMode(PredictionMode.LL)
        toResult(parser)
    }
  }
  catch {
    case e: ParseException if e.command.isDefined =>
      throw e
    case e: ParseException =>
      throw e.withCommand(command)
    case e: AnalysisException =>
      val position = Origin(e.line, e.startPosition)
      throw new ParseException(Option(command), e.message, position, position,
        e.errorClass, e.messageParameters)
  }
}
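Two ideas in the listing above are easy to miss: `parse` is curried, so the braces body of `parsePlan` is passed in as `toResult`, and parsing is attempted twice, first in a fast mode and then in a slower fallback mode. The sketch below reproduces only that shape. `ToyParser` and `FastModeFailed` are hypothetical stand-ins for the ANTLR parser and for `ParseCancellationException`, and the "tricky" rule is invented purely to trigger the fallback.

```scala
object TwoStageParseDemo {
  // Stand-in for the exception that signals the fast mode gave up.
  final class FastModeFailed(input: String)
      extends RuntimeException(s"fast mode failed: $input")

  // Stand-in for a parser with a switchable prediction mode.
  final class ToyParser(val input: String) {
    var mode: String = "SLL"

    def statement(): String = {
      // Pretend the fast mode cannot handle inputs containing "tricky".
      if (mode == "SLL" && input.contains("tricky")) throw new FastModeFailed(input)
      s"plan($input) via $mode"
    }
  }

  // Same shape as parse[T](command)(toResult): setup and retry live here,
  // while the caller decides what to do with the parser.
  def parse[T](command: String)(toResult: ToyParser => T): T = {
    val parser = new ToyParser(command)
    try {
      parser.mode = "SLL" // first, try the potentially faster mode
      toResult(parser)
    } catch {
      case _: FastModeFailed =>
        parser.mode = "LL" // retry with the more general mode
        toResult(parser)
    }
  }

  // The braces body is the toResult closure.
  def parsePlan(sql: String): String = parse(sql) { parser =>
    parser.statement()
  }
}
```

With this sketch, `parsePlan("SELECT 1")` succeeds in the first attempt, while an input containing "tricky" goes through the retry path. The real method additionally rewinds the token stream and resets the parser before retrying.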

ParserInterface

Class relationships

Method walkthrough

/**
 * Interface for a parser.
 */
@DeveloperApi
trait ParserInterface {
  /**
   * Parse a string to a [[LogicalPlan]].
   */
  // Parses a string into a logical plan
  @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
  def parsePlan(sqlText: String): LogicalPlan

  /**
   * Parse a string to an [[Expression]].
   */
  // Parses a string into an expression
  @throws[ParseException]("Text cannot be parsed to an Expression")
  def parseExpression(sqlText: String): Expression

  /**
   * Parse a string to a [[TableIdentifier]].
   */
  // Parses a string into a table identifier
  @throws[ParseException]("Text cannot be parsed to a TableIdentifier")
  def parseTableIdentifier(sqlText: String): TableIdentifier

  /**
   * Parse a string to a [[FunctionIdentifier]].
   */
  // Parses a string into a function identifier
  @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
  def parseFunctionIdentifier(sqlText: String): FunctionIdentifier

  /**
   * Parse a string to a multi-part identifier.
   */
  // Parses a string into a multi-part identifier
  @throws[ParseException]("Text cannot be parsed to a multi-part identifier")
  def parseMultipartIdentifier(sqlText: String): Seq[String]

  /**
   * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
   * of field definitions which will preserve the correct Hive metadata.
   */
  // Parses a string into a struct type (schema)
  @throws[ParseException]("Text cannot be parsed to a schema")
  def parseTableSchema(sqlText: String): StructType

  /**
   * Parse a string to a [[DataType]].
   */
  // Parses a string into a DataType
  @throws[ParseException]("Text cannot be parsed to a DataType")
  def parseDataType(sqlText: String): DataType

  /**
   * Parse a query string to a [[LogicalPlan]].
   */
  @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
  def parseQuery(sqlText: String): LogicalPlan
}

AbstractSqlParser

Base SQL parsing infrastructure

Method walkthrough

/**
 * Base SQL parsing infrastructure.
 */
abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with Logging {

  /** Creates/Resolves DataType for a given SQL string. */
  override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
    astBuilder.visitSingleDataType(parser.singleDataType())
  }

  /** Creates Expression for a given SQL string. */
  override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
    val ctx = parser.singleExpression()
    withOrigin(ctx, Some(sqlText)) {
      astBuilder.visitSingleExpression(ctx)
    }
  }

  /** Creates TableIdentifier for a given SQL string. */
  override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
    astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
  }

  /** Creates FunctionIdentifier for a given SQL string. */
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
    parse(sqlText) { parser =>
      astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
    }
  }

  /** Creates a multi-part identifier for a given SQL string */
  override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
    parse(sqlText) { parser =>
      astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
    }
  }

  /**
   * Creates StructType for a given SQL string, which is a comma separated list of field
   * definitions which will preserve the correct Hive metadata.
   */
  override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser =>
    astBuilder.visitSingleTableSchema(parser.singleTableSchema())
  }

  /** Creates LogicalPlan for a given SQL string of query. */
  override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    val ctx = parser.query()
    withOrigin(ctx, Some(sqlText)) {
      astBuilder.visitQuery(ctx)
    }
  }

  /** Creates LogicalPlan for a given SQL string. */
  // unResolved_02
  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    val ctx = parser.singleStatement()
    withOrigin(ctx, Some(sqlText)) {
      // unResolved_04
      astBuilder.visitSingleStatement(ctx) match {
        case plan: LogicalPlan => plan
        case _ =>
          val position = Origin(None, None)
          throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, position)
      }
    }
  }

  /** Get the builder (visitor) which converts a ParseTree into an AST. */
  protected def astBuilder: AstBuilder
  // unResolved_03
  protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
    logDebug(s"Parsing command: $command")
    // lexer
    val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)
    // token stream
    val tokenStream = new CommonTokenStream(lexer)
    // parser
    val parser = new SqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.addParseListener(UnclosedCommentProcessor(command, tokenStream))
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)
    parser.setErrorHandler(new SparkParserErrorStrategy())
    parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
    parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
    parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords

    try {
      try {
        // first, try parsing with potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        // apply the caller's closure to the parser and return its result
        toResult(parser)
      }
      catch {
        case e: ParseCancellationException =>
          // if we fail, parse with LL mode
          tokenStream.seek(0) // rewind input stream
          parser.reset()

          // Try Again.
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    }
    catch {
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        throw e.withCommand(command)
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position,
          e.errorClass, e.messageParameters)
    }
  }
}
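The outer `catch` in `parse` makes sure that a parse error always carries the SQL text: an exception that already has a command is rethrown unchanged, and one without a command gets it attached. A minimal sketch of that idea follows, assuming a simplified exception type. `ToyParseException` is hypothetical and much smaller than Spark's `ParseException`.

```scala
object ErrorContextDemo {
  // Simplified stand-in: only a message and an optional command.
  final class ToyParseException(val command: Option[String], val msg: String)
      extends RuntimeException(msg) {
    def withCommand(cmd: String): ToyParseException =
      new ToyParseException(Some(cmd), msg)
  }

  def parse[T](command: String)(body: String => T): T =
    try body(command)
    catch {
      // Already carries the command: rethrow unchanged.
      case e: ToyParseException if e.command.isDefined => throw e
      // Otherwise attach the command being parsed.
      case e: ToyParseException => throw e.withCommand(command)
    }

  // Helper for demonstration: runs a parse that throws `thrown` and reports
  // which command ends up attached to the exception that escapes.
  def commandOfFailure(sql: String, thrown: ToyParseException): Option[String] =
    try {
      parse[Unit](sql) { _ => throw thrown }
      None
    } catch {
      case e: ToyParseException => e.command
    }
}
```

The guard on the first case matters when parse calls are nested: the innermost command is preserved and is not overwritten by an outer one.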

ParserUtils

Method walkthrough

/**
 * A collection of utility methods for use during the parsing process.
 */
object ParserUtils {

  val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
  val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
  val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
  val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r

  /** Get the command which created the token. */
  // Returns the full text of the command that produced the token
  def command(ctx: ParserRuleContext): String = {
    val stream = ctx.getStart.getInputStream
    stream.getText(Interval.of(0, stream.size() - 1))
  }
  // Throws an "operation not allowed" parse error
  def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = {
    throw QueryParsingErrors.operationNotAllowedError(message, ctx)
  }
  // Throws an error if a clause appears more than once
  def checkDuplicateClauses[T](
      nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = {
    if (nodes.size() > 1) {
      throw QueryParsingErrors.duplicateClausesError(clauseName, ctx)
    }
  }

  /** Check if duplicate keys exist in a set of key-value pairs. */
  // Checks whether any key appears more than once
  def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = {
    keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) =>
      throw QueryParsingErrors.duplicateKeysError(key, ctx)
    }
  }

  /** Get the code that creates the given node. */
  // Returns the source text that produced the given node
  def source(ctx: ParserRuleContext): String = {
    val stream = ctx.getStart.getInputStream
    stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
  }

  /** Get all the text which comes after the given rule. */
  // Returns all text that comes after the given rule
  def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)

  /** Get all the text which comes after the given token. */
  // Returns all text that comes after the given token
  def remainder(token: Token): String = {
    val stream = token.getInputStream
    val interval = Interval.of(token.getStopIndex + 1, stream.size() - 1)
    stream.getText(interval)
  }

  /**
   * Get all the text which between the given start and end tokens.
   * When we need to extract everything between two tokens including all spaces we should use
   * this method instead of defined a named Antlr4 rule for .*?,
   * which somehow parse "a b" -> "ab" in some cases
   */
  def interval(start: Token, end: Token): String = {
    val interval = Interval.of(start.getStopIndex + 1, end.getStartIndex - 1)
    start.getInputStream.getText(interval)
  }

  /** Convert a string token into a string. */
  // Converts a token into a string
  def string(token: Token): String = unescapeSQLString(token.getText)

  /** Convert a string node into a string. */
  // Converts a node into a string
  def string(node: TerminalNode): String = unescapeSQLString(node.getText)

  /** Convert a string node into a string without unescaping. */
  def stringWithoutUnescape(node: TerminalNode): String = {
    // STRING parser rule forces that the input always has quotes at the starting and ending.
    node.getText.slice(1, node.getText.size - 1)
  }

  /** Collect the entries if any. */
  def entry(key: String, value: Token): Seq[(String, String)] = {
    Option(value).toSeq.map(x => key -> string(x))
  }

  /** Get the origin (line and position) of the token. */
  def position(token: Token): Origin = {
    val opt = Option(token)
    Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine))
  }

  def positionAndText(
      startToken: Token,
      stopToken: Token,
      sqlText: String,
      objectType: Option[String],
      objectName: Option[String]): Origin = {
    val startOpt = Option(startToken)
    val stopOpt = Option(stopToken)
    Origin(
      line = startOpt.map(_.getLine),
      startPosition = startOpt.map(_.getCharPositionInLine),
      startIndex = startOpt.map(_.getStartIndex),
      stopIndex = stopOpt.map(_.getStopIndex),
      sqlText = Some(sqlText),
      objectType = objectType,
      objectName = objectName)
  }

  /** Validate the condition. If it doesn't throw a parse exception. */
  def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
    if (!f) {
      throw new ParseException(message, ctx)
    }
  }

  /**
   * Register the origin of the context. Any TreeNode created in the closure will be assigned the
   * registered origin. This method restores the previously set origin after completion of the
   * closure.
   */
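  // Registers the origin of the parser rule context.
  // Any tree node created inside the closure is assigned the registered origin.
  // After the closure completes, the previously set origin is restored.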
  def withOrigin[T](ctx: ParserRuleContext, sqlText: Option[String] = None)(f: => T): T = {
    val current = CurrentOrigin.get
    val text = sqlText.orElse(current.sqlText)
    if (text.isEmpty) {
      CurrentOrigin.set(position(ctx.getStart))
    } else {
      CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get,
        current.objectType, current.objectName))
    }
    try {
      f
    } finally {
      CurrentOrigin.set(current)
    }
  }

  /** Unescape backslash-escaped string enclosed by quotes. */
  def unescapeSQLString(b: String): String = {
    val sb = new StringBuilder(b.length())
    // Appends the unescaped form of an escaped character to sb
    def appendEscapedChar(n: Char): Unit = {
      n match {
        case '0' => sb.append('\u0000')
        case '\'' => sb.append('\'')
        case '"' => sb.append('\"')
        case 'b' => sb.append('\b')
        case 'n' => sb.append('\n')
        case 'r' => sb.append('\r')
        case 't' => sb.append('\t')
        case 'Z' => sb.append('\u001A')
        case '\\' => sb.append('\\')
        // The following 2 lines are exactly what MySQL does TODO: why do we do this?
        case '%' => sb.append("\\%")
        case '_' => sb.append("\\_")
        case _ => sb.append(n)
      }
    }

    if (b.startsWith("r") || b.startsWith("R")) {
      b.substring(2, b.length - 1)
    } else {
      // Skip the first and last quotations enclosing the string literal.
      val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)

      while (charBuffer.remaining() > 0) {
        charBuffer match {
          case U16_CHAR_PATTERN(cp) =>
            // \u0000 style 16-bit unicode character literals.
            sb.append(Integer.parseInt(cp, 16).toChar)
            charBuffer.position(charBuffer.position() + 6)
          case U32_CHAR_PATTERN(cp) =>
            // \U00000000 style 32-bit unicode character literals.
            // Use Long to treat codePoint as unsigned in the range of 32-bit.
            val codePoint = JLong.parseLong(cp, 16)
            if (codePoint < 0x10000) {
              sb.append((codePoint & 0xFFFF).toChar)
            } else {
              val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800
              val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00
              sb.append(highSurrogate.toChar)
              sb.append(lowSurrogate.toChar)
            }
            charBuffer.position(charBuffer.position() + 10)
          case OCTAL_CHAR_PATTERN(cp) =>
            // \000 style character literals.
            sb.append(Integer.parseInt(cp, 8).toChar)
            charBuffer.position(charBuffer.position() + 4)
          case ESCAPED_CHAR_PATTERN(c) =>
            // escaped character literals.
            appendEscapedChar(c.charAt(0))
            charBuffer.position(charBuffer.position() + 2)
          case _ =>
            // non-escaped character literals.
            sb.append(charBuffer.get())
        }
      }
      sb.toString()
    }
  }

  /** the column name pattern in quoted regex without qualifier */
  val escapedIdentifier = "`((?s).+)`".r

  /** the column name pattern in quoted regex with qualifier */
  val qualifiedEscapedIdentifier = ("((?s).+)" + """.""" + "`((?s).+)`").r

  /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
  implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
    /**
     * Create a plan using the block of code when the given context exists. Otherwise return the
     * original plan.
     */
    def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
      if (ctx != null) {
        f
      } else {
        plan
      }
    }

    /**
     * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
     * passed function. The original plan is returned when the context does not exist.
     */
    // Transforms the plan with the passed function if the passed context exists;
    // otherwise returns the original plan
    def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
      if (ctx != null) {
        f(ctx, plan)
      } else {
        plan
      }
    }
  }
}
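Among these utilities, `withOrigin` follows a set, run, restore pattern: it records the current origin, installs a new one while the closure runs, and puts the old one back in a `finally` block. The sketch below shows that pattern in isolation, assuming the current origin is held in a thread-local. `OriginDemo` and its reduced `Origin` are hypothetical and omit the fields and token handling of the real code.

```scala
object OriginDemo {
  // Reduced stand-in for Origin: only line and start position.
  final case class Origin(line: Option[Int] = None, startPosition: Option[Int] = None)

  // Assumption: one current origin per thread.
  private val current = new ThreadLocal[Origin] {
    override def initialValue(): Origin = Origin()
  }

  def currentOrigin: Origin = current.get()

  // Install `origin` while `f` runs, then restore whatever was set before,
  // even if `f` throws.
  def withOrigin[T](origin: Origin)(f: => T): T = {
    val previous = current.get()
    current.set(origin)
    try f finally current.set(previous)
  }
}
```

Because the previous value is restored in `finally`, nested calls unwind correctly and an exception inside the closure does not leak the inner origin to later code.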

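Original-content statement: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.

For infringement concerns, contact cloudcommunity@tencent.com to request removal.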
