ANTLR stands for "Another Tool for Language Recognition".
It is a powerful parser generator that can read, process, execute, and translate structured text or binary files.
Recognition proceeds in two phases. Phase one, lexical analysis, converts the input text into lexical symbols (tokens); every token carries at least two pieces of information: its type and the text it matched. Phase two, syntax analysis, recognizes statement structure from the token stream; the parser that ANTLR generates builds a parse tree, which records how the parser recognized the structure of the input statement and what the parts of that structure are. A small tokenization sketch follows.
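As a concrete illustration of phase one, here is a minimal sketch that tokenizes a short query with the SqlBaseLexer that ANTLR generates from Spark's grammar (introduced below). It assumes the generated classes and the ANTLR runtime are on the classpath, and that feeding the raw character stream is fine because the keywords in this input are already uppercase (Spark itself wraps the stream in UpperCaseCharStream, as we will see).

import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
import org.apache.spark.sql.catalyst.parser.SqlBaseLexer

object TokenizeDemo {
  def main(args: Array[String]): Unit = {
    // lexing: character stream -> tokens
    val lexer = new SqlBaseLexer(CharStreams.fromString("SELECT a FROM t"))
    val tokens = new CommonTokenStream(lexer)
    tokens.fill() // force the lexer to consume the whole input
    tokens.getTokens.forEach { t =>
      // every token carries at least a type and the matched text
      println(s"${SqlBaseLexer.VOCABULARY.getDisplayName(t.getType)} -> '${t.getText}'")
    }
  }
}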
ANTLR can automatically generate a lexical analyzer (Lexer), a syntax analyzer (Parser), and a tree parser (Tree Parser).
Lexer: scans the otherwise unstructured character stream and translates it into discrete tokens for the parser to consume.
Parser: organizes the incoming tokens into the structures permitted by the grammar rules.
Tree Parser: traverses the tree produced by syntax analysis and performs operations on it. (Tree grammars are an ANTLR v3 concept; ANTLR4 replaced them with the listener and visitor mechanisms described next.)
ANTLR4 offers two ways to traverse the parse tree: the listener pattern, in which a tree walker fires enter/exit callbacks as it passes over each node, and the visitor pattern, in which visit methods are invoked explicitly, giving control over both the descent and the return values. Spark SQL's AstBuilder, shown later, is a visitor; a minimal visitor sketch follows the note below.
Note: to study ANTLR grammars in depth, see the book The Definitive ANTLR 4 Reference.
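Here is the minimal visitor sketch promised above. It extends the visitor base class that ANTLR generates from SqlBaseParser.g4; the generated name used here, SqlBaseParserBaseVisitor, matches recent Spark versions but should be treated as an assumption that may vary across releases.

import org.apache.spark.sql.catalyst.parser.{SqlBaseParser, SqlBaseParserBaseVisitor}

// Returns the raw text matched by the statement under the singleStatement rule.
class StatementTextVisitor extends SqlBaseParserBaseVisitor[String] {
  override def visitSingleStatement(ctx: SqlBaseParser.SingleStatementContext): String =
    ctx.statement().getText
}

The listener style, by contrast, extends the generated *BaseListener class and reacts to enter/exit callbacks while an ANTLR ParseTreeWalker drives the traversal.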
Spark SQL defines its SQL grammar rules with ANTLR4, which performs the lexical and syntax analysis and ultimately converts the SQL text into an abstract syntax tree.
The grammar files live at the following paths in the Spark source tree:
src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
SqlBaseLexer.g4 is the lexer grammar and SqlBaseParser.g4 is the parser grammar; Spark SQL parses SQL through these two files.
The call chain is: SparkSession.sql() --> AbstractSqlParser.parsePlan() --> AbstractSqlParser.parse(). Both parsePlan and the underlying parse method are implemented in AbstractSqlParser, whose full listing appears below.
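Before reading the implementation, a quick way to see parsePlan in action is CatalystSqlParser, a ready-made concrete AbstractSqlParser that ships with Catalyst; a minimal sketch:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// parsePlan turns SQL text into an *unresolved* logical plan: table and column
// names are not checked against any catalog at this stage.
val plan = CatalystSqlParser.parsePlan("SELECT a FROM t WHERE a > 1")
println(plan.treeString)

The contract that parsePlan and its sibling methods fulfill is the ParserInterface trait: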
/**
* Interface for a parser.
*/
@DeveloperApi
trait ParserInterface {
/**
* Parse a string to a [[LogicalPlan]].
*/
@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parsePlan(sqlText: String): LogicalPlan
/**
* Parse a string to an [[Expression]].
*/
@throws[ParseException]("Text cannot be parsed to an Expression")
def parseExpression(sqlText: String): Expression
/**
* Parse a string to a [[TableIdentifier]].
*/
@throws[ParseException]("Text cannot be parsed to a TableIdentifier")
def parseTableIdentifier(sqlText: String): TableIdentifier
/**
* Parse a string to a [[FunctionIdentifier]].
*/
@throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
/**
* Parse a string to a multi-part identifier.
*/
@throws[ParseException]("Text cannot be parsed to a multi-part identifier")
def parseMultipartIdentifier(sqlText: String): Seq[String]
/**
* Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
* of field definitions which will preserve the correct Hive metadata.
*/
@throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType
/**
* Parse a string to a [[DataType]].
*/
@throws[ParseException]("Text cannot be parsed to a DataType")
def parseDataType(sqlText: String): DataType
/**
* Parse a query string to a [[LogicalPlan]].
*/
@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parseQuery(sqlText: String): LogicalPlan
}
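All of these methods can be exercised directly through CatalystSqlParser; a small sketch (the results are unresolved, since no catalog is consulted):

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

val expr   = CatalystSqlParser.parseExpression("a + 1")                // Expression
val table  = CatalystSqlParser.parseTableIdentifier("db.tbl")          // TableIdentifier
val parts  = CatalystSqlParser.parseMultipartIdentifier("a.b.c")       // Seq("a", "b", "c")
val schema = CatalystSqlParser.parseTableSchema("id INT, name STRING") // StructType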
The base SQL parsing infrastructure:
/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with Logging {
/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}
/** Creates Expression for a given SQL string. */
override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
val ctx = parser.singleExpression()
withOrigin(ctx, Some(sqlText)) {
astBuilder.visitSingleExpression(ctx)
}
}
/** Creates TableIdentifier for a given SQL string. */
override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
}
/** Creates FunctionIdentifier for a given SQL string. */
override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
parse(sqlText) { parser =>
astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
}
}
/** Creates a multi-part identifier for a given SQL string */
override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
parse(sqlText) { parser =>
astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
}
}
/**
* Creates StructType for a given SQL string, which is a comma separated list of field
* definitions which will preserve the correct Hive metadata.
*/
override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser =>
astBuilder.visitSingleTableSchema(parser.singleTableSchema())
}
/** Creates LogicalPlan for a given SQL string of query. */
override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
val ctx = parser.query()
withOrigin(ctx, Some(sqlText)) {
astBuilder.visitQuery(ctx)
}
}
/** Creates LogicalPlan for a given SQL string. */
// unResolved_02
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
val ctx = parser.singleStatement()
withOrigin(ctx, Some(sqlText)) {
// unResolved_04
astBuilder.visitSingleStatement(ctx) match {
case plan: LogicalPlan => plan
case _ =>
val position = Origin(None, None)
throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, position)
}
}
}
/** Get the builder (visitor) which converts a ParseTree into an AST. */
protected def astBuilder: AstBuilder
// unResolved_03
protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
logDebug(s"Parsing command: $command")
// lexical analyzer: character stream -> tokens
val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
// token stream
val tokenStream = new CommonTokenStream(lexer)
// syntax analyzer: tokens -> parse tree
val parser = new SqlBaseParser(tokenStream)
parser.addParseListener(PostProcessor)
parser.addParseListener(UnclosedCommentProcessor(command, tokenStream))
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
parser.setErrorHandler(new SparkParserErrorStrategy())
parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords
try {
try {
// first, try parsing with potentially faster SLL mode
parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
// hand the parser to toResult, the closure supplied by the caller
// (for parsePlan, that closure is the block that visits singleStatement)
toResult(parser)
}
catch {
case e: ParseCancellationException =>
// if we fail, parse with LL mode
tokenStream.seek(0) // rewind input stream
parser.reset()
// Try Again.
parser.getInterpreter.setPredictionMode(PredictionMode.LL)
toResult(parser)
}
}
catch {
case e: ParseException if e.command.isDefined =>
throw e
case e: ParseException =>
throw e.withCommand(command)
case e: AnalysisException =>
val position = Origin(e.line, e.startPosition)
throw new ParseException(Option(command), e.message, position, position,
e.errorClass, e.messageParameters)
}
}
}
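Two details of parse are worth calling out: it first attempts SLL prediction, which is faster but can reject some valid inputs, and falls back to full LL prediction only when SLL throws ParseCancellationException; and genuine failures surface as a ParseException carrying the offending command and its position. A sketch of the failure path:

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

try {
  CatalystSqlParser.parsePlan("SELEC 1") // typo: SELEC instead of SELECT
} catch {
  case e: ParseException => println(e.getMessage) // message includes line and position
}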
/**
* A collection of utility methods for use during the parsing process.
*/
object ParserUtils {
val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r
/** Get the command which created the token. */
def command(ctx: ParserRuleContext): String = {
val stream = ctx.getStart.getInputStream
stream.getText(Interval.of(0, stream.size() - 1))
}
// Throws a parse error for operations that are not allowed
def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = {
throw QueryParsingErrors.operationNotAllowedError(message, ctx)
}
// Throws if the same clause appears more than once
def checkDuplicateClauses[T](
nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = {
if (nodes.size() > 1) {
throw QueryParsingErrors.duplicateClausesError(clauseName, ctx)
}
}
/** Check if duplicate keys exist in a set of key-value pairs. */
def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = {
keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) =>
throw QueryParsingErrors.duplicateKeysError(key, ctx)
}
}
/** Get the code that creates the given node. */
def source(ctx: ParserRuleContext): String = {
val stream = ctx.getStart.getInputStream
stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
}
/** Get all the text which comes after the given rule. */
def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
/** Get all the text which comes after the given token. */
def remainder(token: Token): String = {
val stream = token.getInputStream
val interval = Interval.of(token.getStopIndex + 1, stream.size() - 1)
stream.getText(interval)
}
/**
* Get all the text between the given start and end tokens.
* When we need to extract everything between two tokens, including all whitespace, we should
* use this method instead of defining a named ANTLR4 rule for .*?,
* which in some cases parses "a b" as "ab".
*/
def interval(start: Token, end: Token): String = {
val interval = Interval.of(start.getStopIndex + 1, end.getStartIndex - 1)
start.getInputStream.getText(interval)
}
/** Convert a string token into a string. */
def string(token: Token): String = unescapeSQLString(token.getText)
/** Convert a string node into a string. */
def string(node: TerminalNode): String = unescapeSQLString(node.getText)
/** Convert a string node into a string without unescaping. */
def stringWithoutUnescape(node: TerminalNode): String = {
// STRING parser rule forces that the input always has quotes at the starting and ending.
node.getText.slice(1, node.getText.size - 1)
}
/** Collect the entries if any. */
def entry(key: String, value: Token): Seq[(String, String)] = {
Option(value).toSeq.map(x => key -> string(x))
}
/** Get the origin (line and position) of the token. */
def position(token: Token): Origin = {
val opt = Option(token)
Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine))
}
def positionAndText(
startToken: Token,
stopToken: Token,
sqlText: String,
objectType: Option[String],
objectName: Option[String]): Origin = {
val startOpt = Option(startToken)
val stopOpt = Option(stopToken)
Origin(
line = startOpt.map(_.getLine),
startPosition = startOpt.map(_.getCharPositionInLine),
startIndex = startOpt.map(_.getStartIndex),
stopIndex = stopOpt.map(_.getStopIndex),
sqlText = Some(sqlText),
objectType = objectType,
objectName = objectName)
}
/** Validate the condition; if it does not hold, throw a parse exception. */
def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
if (!f) {
throw new ParseException(message, ctx)
}
}
/**
* Register the origin of the context. Any TreeNode created in the closure will be assigned the
* registered origin. This method restores the previously set origin after completion of the
* closure.
*/
def withOrigin[T](ctx: ParserRuleContext, sqlText: Option[String] = None)(f: => T): T = {
val current = CurrentOrigin.get
val text = sqlText.orElse(current.sqlText)
if (text.isEmpty) {
CurrentOrigin.set(position(ctx.getStart))
} else {
CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get,
current.objectType, current.objectName))
}
try {
f
} finally {
CurrentOrigin.set(current)
}
}
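A sketch of what withOrigin provides. It assumes ctx is a ParserRuleContext obtained from a real parse run (hypothetical here); any TreeNode built inside the closure, such as a Literal, gets stamped with the context's position.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.ParserUtils

// `ctx` is assumed to come from an actual parse.
val lit = ParserUtils.withOrigin(ctx) { Literal(1) }
// the node now remembers where in the SQL text it came from
assert(lit.origin.line == Option(ctx.getStart.getLine))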
/** Unescape backslash-escaped string enclosed by quotes. */
def unescapeSQLString(b: String): String = {
val sb = new StringBuilder(b.length())
// Maps the character that follows a backslash to the character it denotes and appends it to sb
def appendEscapedChar(n: Char): Unit = {
n match {
case '0' => sb.append('\u0000')
case '\'' => sb.append('\'')
case '"' => sb.append('\"')
case 'b' => sb.append('\b')
case 'n' => sb.append('\n')
case 'r' => sb.append('\r')
case 't' => sb.append('\t')
case 'Z' => sb.append('\u001A')
case '\\' => sb.append('\\')
// The following 2 lines are exactly what MySQL does TODO: why do we do this?
case '%' => sb.append("\\%")
case '_' => sb.append("\\_")
case _ => sb.append(n)
}
}
if (b.startsWith("r") || b.startsWith("R")) {
b.substring(2, b.length - 1)
} else {
// Skip the first and last quotations enclosing the string literal.
val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)
while (charBuffer.remaining() > 0) {
charBuffer match {
case U16_CHAR_PATTERN(cp) =>
// \u0000 style 16-bit unicode character literals.
sb.append(Integer.parseInt(cp, 16).toChar)
charBuffer.position(charBuffer.position() + 6)
case U32_CHAR_PATTERN(cp) =>
// \U00000000 style 32-bit unicode character literals.
// Use Long to treat codePoint as unsigned in the range of 32-bit.
val codePoint = JLong.parseLong(cp, 16)
if (codePoint < 0x10000) {
sb.append((codePoint & 0xFFFF).toChar)
} else {
val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800
val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00
sb.append(highSurrogate.toChar)
sb.append(lowSurrogate.toChar)
}
charBuffer.position(charBuffer.position() + 10)
case OCTAL_CHAR_PATTERN(cp) =>
// \000 style character literals.
sb.append(Integer.parseInt(cp, 8).toChar)
charBuffer.position(charBuffer.position() + 4)
case ESCAPED_CHAR_PATTERN(c) =>
// escaped character literals.
appendEscapedChar(c.charAt(0))
charBuffer.position(charBuffer.position() + 2)
case _ =>
// non-escaped character literals.
sb.append(charBuffer.get())
}
}
sb.toString()
}
}
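A quick sketch of unescapeSQLString: the argument is the literal exactly as the lexer matched it, enclosing quotes included, and the escape sequences inside get decoded.

import org.apache.spark.sql.catalyst.parser.ParserUtils

// "\\n" in Scala source is a literal backslash plus n, i.e. the SQL escape \n
println(ParserUtils.unescapeSQLString("'hello\\nworld'")) // hello and world on two lines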
/** the column name pattern in quoted regex without qualifier */
val escapedIdentifier = "`((?s).+)`".r
/** the column name pattern in quoted regex with qualifier */
val qualifiedEscapedIdentifier = ("((?s).+)" + """.""" + "`((?s).+)`").r
/** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
/**
* Create a plan using the block of code when the given context exists. Otherwise return the
* original plan.
*/
def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f
} else {
plan
}
}
/**
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
* passed function. The original plan is returned when the context does not exist.
*/
def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f(ctx, plan)
} else {
plan
}
}
}
}
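To close, a self-contained sketch of the optional sugar, with null standing in for an absent parse context:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.ParserUtils._
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, OneRowRelation}

val base: LogicalPlan = OneRowRelation()
val limitCtx: AnyRef = null // pretend the optional LIMIT clause was absent
// the context is null, so the block is skipped and the original plan comes back
val plan = base.optional(limitCtx) { Limit(Literal(10), base) }
assert(plan == base)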