首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >Context Engine:Chunk、Symbol、Graph、RAG 四大策略

Context Engine:Chunk、Symbol、Graph、RAG 四大策略

作者头像
安全风信子
发布2026-05-27 09:22:02
发布2026-05-27 09:22:02
1270
举报
文章被收录于专栏:AI SPPECHAI SPPECH

作者: HOS(安全风信子) 日期: 2026-05-24 主要来源平台: GitHub 摘要: Context(上下文)是 AI IDE 能力的上限,这一论断在人工智能辅助编程领域已获得广泛认同。在实际软件开发场景中,代码库规模动辄数十万乃至上百万行,而 AI 模型的有效上下文窗口通常仅有数万 token。如何在海量代码中精准提取最相关的上下文片段,是 AI IDE 面临的核心挑战。本文深度解析四大上下文策略:Chunk(文本分块)通过多种分块算法将代码拆分为可管理的单元;Symbol(符号索引)建立函数、类、变量的索引结构实现精确导航;Graph(图结构)构建代码结构图、调用图、依赖图捕获语义关联;RAG(检索增强)融合向量检索与传统 BM25 算法实现语义匹配。通过对 Cursor、Claude Code、Cline 等主流产品的深度分析,我们揭示上下文工程背后的原理与实现机制,探讨如何在效果与成本之间取得平衡,并给出混合检索 Context Engine 的完整实现方案。

目录
  • 1. 引言:为什么 Context 是 AI IDE 的核心瓶颈
    • 1.1 上下文危机的本质
    • 1.2 AI IDE 的上下文处理流程
    • 1.3 四大策略的协同关系
  • 2. Chunk 策略:代码分块的艺术与科学
    • 2.1 分块:上下文工程的基石
    • 2.2 固定分块策略
      • 2.2.1 实现原理
      • 2.2.2 固定分块的优劣分析
    • 2.3 语义分块策略
      • 2.3.1 基于 AST 的语义分块
      • 2.3.2 语义分块的边界检测
    • 2.4 结构感知分块策略
      • 2.4.1 分块重要性评分
    • 2.5 分块策略对比分析
    • 2.6 分块策略的工程实践
  • 3. Symbol 策略:代码结构的精确索引
    • 3.1 为什么需要 Symbol 索引
    • 3.2 Language Server Protocol (LSP) 与符号提取
      • 3.2.1 LSP 符号能力概览
      • 3.2.2 基于 LSP 的符号提取实现
    • 3.3 符号索引的构建与存储
      • 3.3.1 符号索引的数据结构
    • 3.4 类型推断与符号增强
      • 3.4.1 类型推断引擎
    • 3.5 Symbol 策略在 AI IDE 中的应用
      • 3.5.1 智能补全
      • 3.5.2 代码理解增强
  • 4. Graph 策略:代码关系的语义建模
    • 4.1 代码图结构的本质
    • 4.2 调用图构建
      • 4.2.1 静态调用图构建
    • 4.3 依赖图构建
    • 4.4 Graph 策略在 AI IDE 中的应用
      • 4.4.1 影响分析
      • 4.4.2 上下文优先级排序
  • 5. RAG 策略:检索增强的语义匹配
    • 5.1 RAG 在 AI IDE 中的定位
    • 5.2 向量检索基础
      • 5.2.1 文本向量化
      • 5.2.2 向量索引与检索
    • 5.3 BM25 检索算法
    • 5.4 混合检索架构
    • 5.5 RAG 上下文组装策略
  • 6. 上下文选择算法:从海量代码中精准选材
    • 6.1 上下文选择的挑战
      • 6.1.1 挑战的本质
    • 6.2 多阶段筛选算法
    • 6.3 上下文预算分配
  • 7. 成本控制:Token 预算与上下文窗口的博弈
    • 7.1 Token 经济的本质
    • 7.2 Token 预算管理框架
    • 7.3 自适应上下文压缩
  • 8. 实践:实现一个混合检索的 Context Engine
    • 8.1 整体架构设计
    • 8.2 核心实现代码
    • 8.3 使用示例
  • 9. 总结与展望
    • 9.1 四大策略对比
    • 9.2 工程实践建议
    • 9.3 未来方向
  • 参考链接
  • A. RAG 实现细节
    • A.1 向量数据库选型对比
    • A.2 BM25 参数调优
  • B. 性能基准测试结果

“The quality of an AI coding assistant is ultimately limited by the quality of its context.” AI 编程助手的质量,最终受限于其上下文的质量。

1. 引言:为什么 Context 是 AI IDE 的核心瓶颈

1.1 上下文危机的本质

当代软件开发的核心矛盾之一是代码规模膨胀与模型上下文有限之间的冲突。根据 GitHub 的统计数据[^1],一个中等规模的企业级代码库通常包含:

代码库规模

代码行数(LOC)

Token 估算(压缩率 1.5)

小型项目

10,000 - 50,000

15K - 75K

中型项目

50,000 - 200,000

75K - 300K

大型项目

200,000 - 1,000,000

300K - 1.5M

巨型项目

> 1,000,000

> 1.5M

而截至 2026 年,主流大模型的上下文窗口虽然已有显著提升,但真正有效使用的上下文通常被限制在 32K - 128K token 范围内,超过此范围后模型对远处信息的注意力会显著衰减,这被研究者称为 “lost in the middle” 问题[^2]。

1.2 AI IDE 的上下文处理流程

一个完整的 AI IDE 上下文处理流程通常包含以下阶段:

1.3 四大策略的协同关系

Chunk、Symbol、Graph、RAG 四大策略并非相互替代,而是互补协同的关系:

  • Chunk 是所有策略的基础单元,负责将代码库切分为可处理的块
  • Symbol 提供精确的代码结构信息,用于类型推断和定义定位
  • Graph 捕获代码间的语义关系,用于调用链分析和依赖解析
  • RAG 提供语义层面的检索能力,用于自然语言查询的理解和匹配

这四者的关系可以用下表概括:

策略

核心问题

输入

输出

典型算法

Chunk

如何切分代码?

原始代码文件

结构化块列表

固定窗口、语义分块

Symbol

符号在哪?

代码 AST

符号索引表

LSP 协议、类型推断

Graph

代码如何关联?

符号索引

图结构

静态分析、指针解析

RAG

最相关的代码?

查询 + 块

相关块排序

向量检索、BM25


2. Chunk 策略:代码分块的艺术与科学

本节为你提供的核心技术价值:理解代码分块的多种策略及其适用场景,掌握固定分块、语义分块、结构感知分块的原理与实现。

2.1 分块:上下文工程的基石

Chunk(文本分块)是 Context Engine 中最基础的策略,其核心思想是将大规模的代码库拆分为多个独立的、可检索的片段。分块的质量直接影响后续检索的效果:分块过大,相关性噪声增加;分块过小,上下文连续性丢失

2.2 固定分块策略

固定分块是最简单也是最常用的分块策略。其核心思想是按照预设的固定长度(通常以 token 数或行数计)对代码进行均匀切分。

2.2.1 实现原理
代码语言:javascript
复制
class FixedChunker:
    """
    固定大小分块器
    将代码按固定 token 数或行数进行切分
    """
    
    def __init__(self, chunk_size: int = 512, overlap: int = 50):
        """
        Args:
            chunk_size: 每块的 token 数(包含重叠)
            overlap: 相邻块之间的重叠 token 数
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
    
    def chunk_code(self, code: str, language: str = None) -> list[Chunk]:
        """
        对代码进行固定大小分块
        
        Args:
            code: 源代码文本
            language: 编程语言(用于 tokenizer 选择)
            
        Returns:
            分块列表,每个块包含文本、起始位置、结束位置
        """
        # 选择对应语言的 tokenizer
        tokenizer = self._get_tokenizer(language)
        
        # 将代码转换为 token 序列
        tokens = tokenizer.encode(code)
        
        chunks = []
        step = self.chunk_size - self.overlap
        
        for start in range(0, len(tokens), step):
            end = min(start + self.chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]
            
            # 解码回文本(可能存在边界偏差,但可接受)
            chunk_text = tokenizer.decode(chunk_tokens)
            
            chunks.append(Chunk(
                content=chunk_text,
                start_token=start,
                end_token=end,
                start_line=self._get_line_from_offset(code, start),
                end_line=self._get_line_from_offset(code, end)
            ))
            
            if end >= len(tokens):
                break
        
        return chunks
    
    def _get_tokenizer(self, language: str):
        """获取对应语言的 tokenizer"""
        # 实际实现中可根据 language 返回不同的 tokenizer
        # 这里简化处理,使用 tiktoken 作为通用 tokenizer
        import tiktoken
        return tiktoken.get_encoding("cl100k_base")
    
    def _get_line_from_offset(self, code: str, token_offset: int) -> int:
        """根据 token 偏移量估算行号"""
        # 简化实现:假设平均每 token 约 4 字符,每行约 80 字符
        char_offset = token_offset * 4
        return code[:char_offset].count('\n') + 1


@dataclass
class Chunk:
    """代码分块"""
    content: str
    start_token: int
    end_token: int
    start_line: int
    end_line: int
    file_path: str = None
    language: str = None
    metadata: dict = field(default_factory=dict)
2.2.2 固定分块的优劣分析

优点:

  • 实现简单,易于理解和维护
  • 分块均匀,检索时可预测性高
  • 计算开销低,适合大规模代码库快速处理

缺点:

  • 语义割裂问题:可能在函数中途或语句中间切分,导致语义不完整
  • 上下文丢失:无法保留被切分代码块的原始上下文关系
  • 边界效应:固定边界无法适应代码的结构特点
2.3 语义分块策略

语义分块旨在解决固定分块的语义割裂问题。其核心思想是根据代码的语义边界(如函数、类、模块)进行分块,确保每个块尽可能保持语义的完整性。

2.3.1 基于 AST 的语义分块

现代编程语言的代码具有明确的语法结构——函数、类、方法、模块等。这些结构天然构成了语义边界,是进行分块的理想切分点。

代码语言:javascript
复制
import ast
from typing import Iterator
from dataclasses import dataclass, field


class SemanticChunker:
    """
    基于 AST 的语义分块器
    通过解析代码的抽象语法树,按语义单元进行分块
    """
    
    def __init__(self, max_chunk_size: int = 512, min_chunk_size: int = 50):
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = min_chunk_size
    
    def chunk_code(self, code: str, file_path: str = None) -> list[Chunk]:
        """
        对代码进行语义分块
        
        策略:
        1. 解析 AST
        2. 遍历 AST 节点,按语义单元(函数、类、方法)分块
        3. 过大块进一步拆分,过小块合并
        """
        try:
            tree = ast.parse(code, filename=file_path)
        except SyntaxError:
            # 解析失败时回退到固定分块
            return self._fallback_chunk(code)
        
        chunks = []
        
        # 遍历 AST 节点,按语义单元分块
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                chunk = self._function_to_chunk(node, code, file_path)
                if self._is_valid_chunk(chunk):
                    chunks.append(chunk)
                elif self._is_large_chunk(chunk):
                    # 过大的块需要进一步拆分
                    chunks.extend(self._split_large_chunk(chunk))
                    
            elif isinstance(node, ast.ClassDef):
                chunk = self._class_to_chunk(node, code, file_path)
                if self._is_valid_chunk(chunk):
                    chunks.append(chunk)
                elif self._is_large_chunk(chunk):
                    chunks.extend(self._split_large_chunk(chunk))
        
        # 按文件中的位置排序
        chunks.sort(key=lambda c: c.start_line)
        
        # 合并过小的块
        chunks = self._merge_small_chunks(chunks)
        
        return chunks
    
    def _function_to_chunk(self, node: ast.FunctionDef, code: str, file_path: str) -> Chunk:
        """将函数节点转换为 Chunk"""
        # 获取函数定义的行范围
        start_line = node.lineno
        end_line = node.end_lineno or start_line + 20
        
        # 提取函数代码(包含完整定义)
        lines = code.split('\n')
        function_lines = lines[start_line - 1:end_line]
        function_code = '\n'.join(function_lines)
        
        return Chunk(
            content=function_code,
            start_token=0,  # token 统计在外部进行
            end_token=0,
            start_line=start_line,
            end_line=end_line,
            file_path=file_path,
            metadata={
                'type': 'function',
                'name': node.name,
                'args': [arg.arg for arg in node.args.args],
                'decorators': [ast.unparse(d) for d in node.decorator_list]
            }
        )
    
    def _class_to_chunk(self, node: ast.ClassDef, code: str, file_path: str) -> Chunk:
        """将类节点转换为 Chunk"""
        start_line = node.lineno
        end_line = node.end_lineno or start_line + 50
        
        lines = code.split('\n')
        class_lines = lines[start_line - 1:end_line]
        class_code = '\n'.join(class_lines)
        
        # 提取类的方法列表
        methods = [n.name for n in node.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
        
        return Chunk(
            content=class_code,
            start_token=0,
            end_token=0,
            start_line=start_line,
            end_line=end_line,
            file_path=file_path,
            metadata={
                'type': 'class',
                'name': node.name,
                'methods': methods,
                'bases': [ast.unparse(base) for base in node.bases]
            }
        )
    
    def _is_valid_chunk(self, chunk: Chunk) -> bool:
        """检查块是否有效(不过小)"""
        return len(chunk.content.split('\n')) >= self.min_chunk_size
    
    def _is_large_chunk(self, chunk: Chunk) -> bool:
        """检查块是否过大"""
        return len(chunk.content.split('\n')) > self.max_chunk_size
    
    def _split_large_chunk(self, chunk: Chunk) -> list[Chunk]:
        """拆分过大的块(简化实现)"""
        # 实际实现中可按行数进一步拆分
        # 这里简化处理,返回原块
        return [chunk]
    
    def _merge_small_chunks(self, chunks: list[Chunk]) -> list[Chunk]:
        """合并过小的块"""
        if not chunks:
            return []
        
        merged = [chunks[0]]
        
        for chunk in chunks[1:]:
            last = merged[-1]
            combined_lines = last.end_line - last.start_line + chunk.end_line - chunk.start_line
            
            if combined_lines < self.min_chunk_size:
                # 合并到前一个块
                merged[-1] = Chunk(
                    content=last.content + '\n' + chunk.content,
                    start_token=last.start_token,
                    end_token=chunk.end_token,
                    start_line=last.start_line,
                    end_line=chunk.end_line,
                    file_path=chunk.file_path,
                    metadata={**last.metadata, **chunk.metadata}
                )
            else:
                merged.append(chunk)
        
        return merged
    
    def _fallback_chunk(self, code: str) -> list[Chunk]:
        """解析失败时的回退策略"""
        fixed_chunker = FixedChunker()
        return fixed_chunker.chunk_code(code)
2.3.2 语义分块的边界检测

在实际实现中,语义分块需要精确检测各种代码结构的边界:

代码语言:javascript
复制
class BoundaryDetector:
    """
    代码结构边界检测器
    精确检测函数、类、模块的边界位置
    """
    
    # 语言的缩进模式
    INDENT_PATTERNS = {
        'python': {'indent': '    ', 'dedent': None},
        'javascript': {'indent': '  ', 'dedent': None},
        'java': {'indent': '    ', 'dedent': None},
    }
    
    def detect_function_boundaries(self, code: str, language: str = 'python') -> list[tuple[int, int, str]]:
        """
        检测所有函数的边界
        
        Returns:
            [(start_line, end_line, function_name), ...]
        """
        if language == 'python':
            return self._detect_python_functions(code)
        elif language in ('javascript', 'typescript'):
            return self._detect_js_functions(code)
        else:
            return self._detect_generic_functions(code)
    
    def _detect_python_functions(self, code: str) -> list[tuple[int, int, str]]:
        """检测 Python 函数的边界"""
        tree = ast.parse(code)
        boundaries = []
        
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                boundaries.append((
                    node.lineno,
                    node.end_lineno or node.lineno + 20,
                    node.name
                ))
        
        return boundaries
    
    def _detect_js_functions(self, code: str) -> list[tuple[int, int, str]]:
        """检测 JavaScript 函数的边界(简化实现)"""
        import re
        
        # 匹配函数声明
        func_pattern = r'(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>|(\w+)\s*:\s*(?:async\s*)?\([^)]*\)\s*=>)'
        
        boundaries = []
        lines = code.split('\n')
        
        for i, line in enumerate(lines, 1):
            match = re.search(func_pattern, line)
            if match:
                func_name = match.group(1) or match.group(2) or match.group(3)
                # 简化:假设函数体在后续 5-50 行
                boundaries.append((i, min(i + 20, len(lines)), func_name))
        
        return boundaries
    
    def _detect_generic_functions(self, code: str) -> list[tuple[int, int, str]]:
        """通用函数边界检测(基于括号匹配)"""
        boundaries = []
        lines = code.split('\n')
        
        in_function = False
        function_start = 0
        function_name = ''
        brace_count = 0
        
        for i, line in enumerate(lines, 1):
            if 'function' in line or 'def ' in line:
                in_function = True
                function_start = i
                # 提取函数名
                match = re.search(r'(?:function|def)\s+(\w+)', line)
                function_name = match.group(1) if match else f'anonymous_{i}'
            
            if in_function:
                brace_count += line.count('{') - line.count('}')
                if brace_count == 0 and '{' in ''.join(lines[function_start-1:i]):
                    boundaries.append((function_start, i, function_name))
                    in_function = False
        
        return boundaries
2.4 结构感知分块策略

结构感知分块在语义分块的基础上更进一步,不仅考虑代码的语法结构,还考虑代码的语义关系和重要性权重。这种策略的核心思想是:并非所有代码都具有同等重要性,核心业务逻辑、入口函数、公共 API 应该获得更多的上下文份额。

2.4.1 分块重要性评分
代码语言:javascript
复制
from enum import Enum
from typing import Callable


class CodeImportance(Enum):
    """代码重要性等级"""
    CRITICAL = 1  # 关键:入口点、公共 API
    HIGH = 2      # 高:业务核心逻辑、数据模型
    MEDIUM = 3    # 中:辅助功能、内部实现
    LOW = 4       # 低:工具函数、配置


class StructuralChunker:
    """
    结构感知分块器
    根据代码在项目中的结构性位置和重要性进行分块
    """
    
    def __init__(self):
        self.importance_weights = {
            'main': 4.0,
            'index': 4.0,
            'app': 3.5,
            'api': 3.5,
            'router': 3.0,
            'service': 3.0,
            'model': 2.5,
            'controller': 2.5,
            'util': 1.5,
            'helper': 1.5,
            'config': 1.0,
            'test': 0.8,
        }
    
    def chunk_with_importance(
        self, 
        code: str, 
        file_path: str,
        structure_hints: dict = None
    ) -> list[Chunk]:
        """
        带重要性评分的结构感知分块
        
        Args:
            code: 源代码
            file_path: 文件路径(用于判断文件类型和位置)
            structure_hints: 额外的结构提示(如入口点列表)
        """
        # 1. 基础分块
        base_chunks = self._base_chunking(code, file_path)
        
        # 2. 计算每个块的重要性分数
        for chunk in base_chunks:
            chunk.importance = self._calculate_importance(chunk, file_path, structure_hints)
        
        # 3. 根据重要性调整块大小
        adjusted_chunks = self._adjust_chunks_by_importance(base_chunks)
        
        return adjusted_chunks
    
    def _calculate_importance(
        self, 
        chunk: Chunk, 
        file_path: str,
        hints: dict = None
    ) -> float:
        """
        计算分块的重要性分数
        
        考量因素:
        1. 文件路径(目录和文件名)
        2. 代码类型(函数声明、类定义等)
        3. 导出/公开程度
        4. 调用频率
        """
        score = 1.0
        
        # 文件路径因素
        file_name = chunk.file_path.split('/')[-1].lower()
        directory = '/'.join(chunk.file_path.split('/')[:-1]).lower()
        
        for key, weight in self.importance_weights.items():
            if key in file_name or key in directory:
                score *= weight
                break
        
        # 代码结构因素
        metadata = chunk.metadata
        if metadata.get('type') == 'function':
            func_name = metadata.get('name', '').lower()
            
            # 构造函数、入口函数高权重
            if func_name in ('__init__', 'main', 'run', 'start', 'init'):
                score *= 2.0
            
            # 下划线开头的私有方法低权重
            if func_name.startswith('_') and not func_name.startswith('__'):
                score *= 0.7
            
            # 公共方法(无下划线)适度权重
            if not func_name.startswith('_'):
                score *= 1.3
        
        elif metadata.get('type') == 'class':
            class_name = metadata.get('name', '')
            
            # 继承自特定基类高权重
            bases = metadata.get('bases', [])
            if any('Exception' in b or 'Error' in b for b in bases):
                score *= 1.5
            
            # 公共类适度权重
            if not class_name.startswith('_'):
                score *= 1.2
        
        # 外部提示因素
        if hints:
            # 如果这是已知的入口函数
            entry_functions = hints.get('entry_functions', [])
            if metadata.get('name') in entry_functions:
                score *= 3.0
        
        return score
    
    def _adjust_chunks_by_importance(self, chunks: list[Chunk]) -> list[Chunk]:
        """
        根据重要性调整块大小
        
        策略:
        - 高重要性块:扩大上下文,减少拆分
        - 低重要性块:缩小块,减少上下文
        """
        # 计算平均重要性
        avg_importance = sum(c.importance for c in chunks) / len(chunks) if chunks else 1.0
        
        adjusted = []
        for chunk in chunks:
            if chunk.importance > avg_importance * 1.5:
                # 高重要性:扩展块(包含更多上下文)
                expanded = self._expand_chunk(chunk)
                adjusted.append(expanded)
            elif chunk.importance < avg_importance * 0.5:
                # 低重要性:压缩块(减少到核心部分)
                compressed = self._compress_chunk(chunk)
                adjusted.append(compressed)
            else:
                adjusted.append(chunk)
        
        return adjusted
    
    def _expand_chunk(self, chunk: Chunk) -> Chunk:
        """扩展高重要性块"""
        # 简化实现:扩展 20% 的上下文
        lines = chunk.content.split('\n')
        expand_lines = max(1, len(lines) // 5)
        
        return Chunk(
            content=chunk.content,
            start_token=chunk.start_token,
            end_token=chunk.end_token,
            start_line=max(1, chunk.start_line - expand_lines),
            end_line=chunk.end_line + expand_lines,
            file_path=chunk.file_path,
            metadata={**chunk.metadata, 'expanded': True}
        )
    
    def _compress_chunk(self, chunk: Chunk) -> Chunk:
        """压缩低重要性块"""
        # 简化实现:标记为压缩
        return Chunk(
            content=chunk.content,
            start_token=chunk.start_token,
            end_token=chunk.end_token,
            start_line=chunk.start_line,
            end_line=chunk.end_line,
            file_path=chunk.file_path,
            metadata={**chunk.metadata, 'compressed': True}
        )
    
    def _base_chunking(self, code: str, file_path: str) -> list[Chunk]:
        """基础分块(使用语义分块器)"""
        semantic_chunker = SemanticChunker()
        return semantic_chunker.chunk_code(code, file_path)
2.5 分块策略对比分析

策略

优点

缺点

适用场景

固定分块

实现简单、计算高效

语义割裂、边界不智能

大规模代码的快速预览、精确检索不可行时

语义分块

保持语义完整、AST 精确

实现复杂、解析开销大

需要精确代码理解的场景

结构感知

重要性区分、上下文优化

需要项目结构知识、权重调优复杂

大型代码库、差异化检索需求

2.6 分块策略的工程实践

在实际 AI IDE 产品中,分块策略的选择通常遵循以下原则:

  1. 混合策略:单一策略难以应对复杂场景,通常组合使用
  2. 自适应调整:根据代码特征自动选择分块策略
  3. 层级索引:建立多级索引(文件级、类级、函数级)支持不同粒度的检索

Cursor 的分块策略分析表明[^3],其采用了一种层级感知的混合分块方法:

  • 顶层分块:按文件或模块划分,保持文件边界
  • 中层分块:按函数/类划分,保持语义完整
  • 底层分块:对过大或过小的块进行动态调整

3. Symbol 策略:代码结构的精确索引

本节为你提供的核心技术价值:理解符号索引的构建原理,掌握基于 LSP 的符号提取与类型推理技术,实现精确的代码定位和导航。

3.1 为什么需要 Symbol 索引

Chunk 策略关注的是代码的文本切分,而 Symbol 策略关注的是代码的结构语义。在软件开发中,代码不仅仅是文本,更是具有丰富语义结构的符号系统:

  • 函数定义 → 可调用、可引用
  • 变量声明 → 有类型、有作用域
  • 类定义 → 有属性、有方法、有继承关系
  • 模块导入 → 有依赖关系

这些符号信息对于 AI 理解代码至关重要。当用户询问"这个函数在哪里定义"或"这个变量的类型是什么"时,Symbol 索引能够提供精确的答案。

3.2 Language Server Protocol (LSP) 与符号提取

LSP(Language Server Protocol)是 Microsoft 发起的标准化协议,用于在编辑器/IDE 和语言服务器之间进行通信[^4]。LSP 提供了丰富的代码符号操作能力,是构建 Symbol 索引的基础。

3.2.1 LSP 符号能力概览

3.2.2 基于 LSP 的符号提取实现
代码语言:javascript
复制
import json
import subprocess
from dataclasses import dataclass, field
from typing import Optional
from enum import IntEnum


class SymbolKind(IntEnum):
    """LSP 符号类型"""
    FILE = 1
    MODULE = 2
    NAMESPACE = 3
    PACKAGE = 4
    CLASS = 5
    METHOD = 6
    PROPERTY = 7
    FIELD = 8
    CONSTRUCTOR = 9
    ENUM = 10
    INTERFACE = 11
    FUNCTION = 12
    VARIABLE = 13
    CONSTANT = 14
    STRING = 15
    NUMBER = 16
    BOOLEAN = 17
    ARRAY = 18
    OBJECT = 19
    KEY = 20
    NULL = 21
    ENUMMEMBER = 22
    STRUCT = 23
    EVENT = 24
    OPERATOR = 25
    TYPEPARAMETER = 26


@dataclass
class SymbolInfo:
    """符号信息"""
    name: str
    kind: SymbolKind
    file_path: str
    start_line: int
    start_column: int
    end_line: int
    end_column: int
    container_name: str = None
    signature: str = None
    docstring: str = None
    dependencies: list[str] = field(default_factory=list)
    type_info: str = None
    deprecated: bool = False


class LSSPSymbolExtractor:
    """
    基于 LSP 的符号提取器
    使用 Language Server Protocol 提取代码符号信息
    """
    
    def __init__(self, ls_path: str = None):
        """
        Args:
            ls_path: 语言服务器可执行文件路径
        """
        self.ls_path = ls_path
        self.process = None
        self.request_id = 0
    
    def extract_symbols(self, file_path: str, language: str) -> list[SymbolInfo]:
        """
        从源文件中提取所有符号
        
        Args:
            file_path: 源文件路径
            language: 编程语言
            
        Returns:
            符号信息列表
        """
        # 启动语言服务器
        self._start_server(language)
        
        try:
            # 发送 textDocument/documentSymbol 请求
            response = self._send_request('textDocument/documentSymbol', {
                'textDocument': {'uri': self._path_to_uri(file_path)}
            })
            
            symbols = self._parse_symbol_response(response, file_path)
            return symbols
            
        finally:
            self._stop_server()
    
    def find_definition(self, file_path: str, line: int, character: int) -> SymbolInfo:
        """
        查找符号定义位置
        
        Args:
            file_path: 源文件路径
            line: 光标行号(0索引)
            character: 光标列号(0索引)
            
        Returns:
            定义位置的符号信息
        """
        self._start_server(self._detect_language(file_path))
        
        try:
            response = self._send_request('textDocument/definition', {
                'textDocument': {'uri': self._path_to_uri(file_path)},
                'position': {'line': line, 'character': character}
            })
            
            return self._parse_location_response(response)
            
        finally:
            self._stop_server()
    
    def find_references(self, file_path: str, line: int, character: int) -> list[SymbolInfo]:
        """
        查找符号的所有引用
        
        Returns:
            引用位置列表
        """
        self._start_server(self._detect_language(file_path))
        
        try:
            response = self._send_request('textDocument/references', {
                'textDocument': {'uri': self._path_to_uri(file_path)},
                'position': {'line': line, 'character': character},
                'context': {'includeDeclaration': True}
            })
            
            return self._parse_locations_response(response)
            
        finally:
            self._stop_server()
    
    def _start_server(self, language: str):
        """启动语言服务器"""
        if self.process is not None:
            return
        
        # 根据语言选择对应的语言服务器
        server_cmd = self._get_server_command(language)
        
        self.process = subprocess.Popen(
            server_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        
        # 发送初始化请求
        self._send_request('initialize', {
            'processId': self.process.pid,
            'rootUri': self._path_to_uri('.'),
            'capabilities': {
                'textDocument': {
                    'documentSymbol': {
                        'hierarchicalDocumentSymbolSupport': True
                    }
                }
            }
        })
        
        # 发送 initialized 通知
        self._send_notification('initialized', {})
    
    def _stop_server(self):
        """停止语言服务器"""
        if self.process:
            self._send_notification('shutdown', {})
            self.process.terminate()
            self.process = None
    
    def _send_request(self, method: str, params: dict) -> dict:
        """发送 LSP 请求"""
        self.request_id += 1
        
        request = {
            'jsonrpc': '2.0',
            'id': self.request_id,
            'method': method,
            'params': params
        }
        
        request_json = json.dumps(request) + '\n'
        self.process.stdin.write(request_json.encode())
        self.process.stdin.flush()
        
        # 读取响应
        response_line = self.process.stdout.readline()
        response = json.loads(response_line)
        
        return response.get('result', {})
    
    def _send_notification(self, method: str, params: dict):
        """发送 LSP 通知"""
        notification = {
            'jsonrpc': '2.0',
            'method': method,
            'params': params
        }
        
        notification_json = json.dumps(notification) + '\n'
        self.process.stdin.write(notification_json.encode())
        self.process.stdin.flush()
    
    def _parse_symbol_response(self, response: list, file_path: str) -> list[SymbolInfo]:
        """解析符号响应"""
        symbols = []
        
        def parse_recursive(items, container=''):
            for item in items:
                if 'children' in item:
                    # DocumentSymbol format
                    symbols.append(SymbolInfo(
                        name=item['name'],
                        kind=SymbolKind(item['kind']),
                        file_path=file_path,
                        start_line=item['range']['start']['line'],
                        start_column=item['range']['start']['character'],
                        end_line=item['range']['end']['line'],
                        end_column=item['range']['end']['character'],
                        container_name=container,
                        detail=item.get('detail')
                    ))
                    parse_recursive(item['children'], item['name'])
                else:
                    # SymbolInformation format
                    location = item['location']
                    symbols.append(SymbolInfo(
                        name=item['name'],
                        kind=SymbolKind(item['kind']),
                        file_path=file_path,
                        start_line=location['range']['start']['line'],
                        start_column=location['range']['start']['character'],
                        end_line=location['range']['end']['line'],
                        end_column=location['range']['end']['character'],
                        container_name=item.get('containerName')
                    ))
        
        parse_recursive(response)
        return symbols
    
    def _get_server_command(self, language: str) -> list[str]:
        """获取语言服务器启动命令"""
        servers = {
            'python': ['python', '-m', 'pylsp'],
            'javascript': ['node', '/path/to/javascript-language-server.js'],
            'typescript': ['node', '/path/to/typescript-language-server.js'],
            'rust': ['rust-analyzer'],
            'go': ['gopls'],
            'java': ['jdtls'],
        }
        return servers.get(language, ['pylsp'])
    
    def _path_to_uri(self, path: str) -> str:
        """将文件路径转换为 file URI"""
        import urllib.parse
        return 'file://' + urllib.parse.quote(path)
    
    def _detect_language(self, file_path: str) -> str:
        """根据文件扩展名检测语言"""
        ext_map = {
            '.py': 'python',
            '.js': 'javascript',
            '.ts': 'typescript',
            '.rs': 'rust',
            '.go': 'go',
            '.java': 'java',
        }
        import os
        _, ext = os.path.splitext(file_path)
        return ext_map.get(ext.lower(), 'python')
3.3 符号索引的构建与存储

提取的符号信息需要高效存储,以便快速检索。对于大规模代码库,符号数量可能达到数十万,因此需要设计高效的索引结构。

3.3.1 符号索引的数据结构
代码语言:javascript
复制
from typing import Dict, List, Set, Optional
from dataclasses import dataclass, field
import json


@dataclass
class SymbolIndex:
    """
    符号索引
    
    支持多种查询模式:
    1. 按名称查询(精确/模糊)
    2. 按类型查询
    3. 按文件/包查询
    4. 按依赖关系查询
    """
    # 名称 -> 符号列表(支持同名符号重载)
    name_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    
    # 类型 -> 符号列表
    kind_index: Dict[SymbolKind, List[SymbolInfo]] = field(default_factory=dict)
    
    # 文件路径 -> 符号列表
    file_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    
    # 容器(类/模块) -> 子符号列表
    container_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    
    # 全符号列表(按文件路径和行号排序)
    all_symbols: List[SymbolInfo] = field(default_factory=list)
    
    def add_symbol(self, symbol: SymbolInfo):
        """添加符号到索引"""
        # 名称索引
        if symbol.name not in self.name_index:
            self.name_index[symbol.name] = []
        self.name_index[symbol.name].append(symbol)
        
        # 类型索引
        if symbol.kind not in self.kind_index:
            self.kind_index[symbol.kind] = []
        self.kind_index[symbol.kind].append(symbol)
        
        # 文件索引
        if symbol.file_path not in self.file_index:
            self.file_index[symbol.file_path] = []
        self.file_index[symbol.file_path].append(symbol)
        
        # 容器索引
        if symbol.container_name:
            if symbol.container_name not in self.container_index:
                self.container_index[symbol.container_name] = []
            self.container_index[symbol.container_name].append(symbol)
        
        # 全量列表
        self.all_symbols.append(symbol)
    
    def find_by_name(self, name: str, exact: bool = True) -> List[SymbolInfo]:
        """按名称查找符号"""
        if exact:
            return self.name_index.get(name, [])
        else:
            # 模糊匹配
            results = []
            for sym_name, symbols in self.name_index.items():
                if name.lower() in sym_name.lower():
                    results.extend(symbols)
            return results
    
    def find_by_kind(self, kind: SymbolKind) -> List[SymbolInfo]:
        """按类型查找符号"""
        return self.kind_index.get(kind, [])
    
    def find_by_file(self, file_path: str) -> List[SymbolInfo]:
        """查找文件中的所有符号"""
        return self.file_index.get(file_path, [])
    
    def find_in_container(self, container: str) -> List[SymbolInfo]:
        """查找容器内的符号(类的所有方法等)"""
        return self.container_index.get(container, [])
    
    def find_by_prefix(self, prefix: str) -> List[SymbolInfo]:
        """查找以指定前缀开头的符号"""
        results = []
        for sym_name, symbols in self.name_index.items():
            if sym_name.startswith(prefix):
                results.extend(symbols)
        return sorted(results, key=lambda s: (s.file_path, s.start_line))


class SymbolIndexBuilder:
    """
    符号索引构建器
    支持增量构建和批量构建
    """
    
    def __init__(self, extractor: LSSPSymbolExtractor):
        self.extractor = extractor
        self.index = SymbolIndex()
    
    def build_from_files(self, file_paths: list[str]) -> SymbolIndex:
        """
        从多个文件构建索引
        
        Args:
            file_paths: 源文件路径列表
            
        Returns:
            构建完成的符号索引
        """
        for file_path in file_paths:
            self.add_file(file_path)
        
        return self.index
    
    def add_file(self, file_path: str):
        """增量添加文件到索引"""
        language = self._detect_language(file_path)
        symbols = self.extractor.extract_symbols(file_path, language)
        
        for symbol in symbols:
            self.index.add_symbol(symbol)
    
    def remove_file(self, file_path: str):
        """从索引中移除文件"""
        # 找出需要移除的符号
        to_remove = self.index.find_by_file(file_path)
        
        for symbol in to_remove:
            # 从各索引中移除
            self.index.name_index[symbol.name].remove(symbol)
            self.index.kind_index[symbol.kind].remove(symbol)
            self.index.file_index[symbol.file_path].remove(symbol)
            
            if symbol.container_name and symbol.container_name in self.index.container_index:
                self.index.container_index[symbol.container_name].remove(symbol)
            
            self.index.all_symbols.remove(symbol)
    
    def _detect_language(self, file_path: str) -> str:
        """检测语言"""
        import os
        _, ext = os.path.splitext(file_path)
        ext_map = {
            '.py': 'python',
            '.js': 'javascript',
            '.ts': 'typescript',
            '.rs': 'rust',
            '.go': 'go',
            '.java': 'java',
        }
        return ext_map.get(ext.lower(), 'python')
3.4 类型推断与符号增强

符号索引不仅包含符号的位置信息,还应包含类型信息、依赖关系等语义信息。这些信息对于 AI 理解代码至关重要。

3.4.1 类型推断引擎
代码语言:javascript
复制
from typing import Dict, Set, Optional, List
from dataclasses import dataclass


@dataclass
class TypeInfo:
    """类型信息"""
    name: str
    module: str = None
    generic_params: List['TypeInfo'] = None
    base_types: List[str] = None
    methods: Dict[str, 'TypeInfo'] = None
    properties: Dict[str, 'TypeInfo'] = None


class TypeInferenceEngine:
    """
    类型推断引擎
    
    基于静态分析推断变量的类型信息
    支持:
    1. 变量声明类型推断
    2. 函数返回类型推断
    3. 表达式类型推断
    4. 泛型参数推断
    """
    
    def __init__(self, symbol_index: SymbolIndex):
        self.symbol_index = symbol_index
        self.type_cache: Dict[str, TypeInfo] = {}
    
    def infer_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """
        推断符号的类型
        
        Args:
            symbol: 符号信息
            
        Returns:
            推断得到的类型信息
        """
        cache_key = f"{symbol.file_path}:{symbol.start_line}"
        if cache_key in self.type_cache:
            return self.type_cache[cache_key]
        
        if symbol.type_info:
            # 已有类型信息
            return self._parse_type_string(symbol.type_info)
        
        # 基于上下文推断
        inferred = self._infer_from_context(symbol)
        
        self.type_cache[cache_key] = inferred
        return inferred
    
    def _infer_from_context(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """从代码上下文推断类型"""
        if symbol.kind == SymbolKind.VARIABLE:
            return self._infer_variable_type(symbol)
        elif symbol.kind == SymbolKind.FUNCTION:
            return self._infer_function_return_type(symbol)
        elif symbol.kind == SymbolKind.PROPERTY:
            return self._infer_property_type(symbol)
        
        return None
    
    def _infer_variable_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """推断变量类型"""
        # 读取源代码
        with open(symbol.file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        
        if symbol.start_line > len(lines):
            return None
        
        line = lines[symbol.start_line - 1]
        
        # Python 类型注解
        if ': ' in line and '=' in line:
            type_part = line.split(':')[1].split('=')[0].strip()
            return self._parse_type_string(type_part)
        
        # JavaScript/TypeScript 类型注解
        if ': ' in line and ('const' in line or 'let' in line or 'var' in line):
            type_part = line.split(':')[1].split(';')[0].strip()
            return self._parse_type_string(type_part)
        
        return None
    
    def _infer_function_return_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """推断函数返回类型"""
        # 查找函数签名中的返回类型注解
        if symbol.signature:
            # 解析函数签名
            import re
            match = re.search(r'->\s*(\w+)', symbol.signature)
            if match:
                return TypeInfo(name=match.group(1))
        
        return None
    
    def _infer_property_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """推断属性类型"""
        # 属性通常是类成员
        if symbol.container_name:
            container_symbols = self.symbol_index.find_by_name(symbol.container_name, exact=True)
            if container_symbols:
                # 查找属性定义
                container = container_symbols[0]
                # 尝试从源代码中解析属性类型
                pass
        
        return None
    
    def _parse_type_string(self, type_str: str) -> TypeInfo:
        """解析类型字符串"""
        # 处理泛型
        import re
        
        generic_match = re.match(r'(\w+)\[(\w+(?:,\s*\w+)*)\]', type_str)
        if generic_match:
            base_type = generic_match.group(1)
            params = generic_match.group(2).split(',')
            return TypeInfo(
                name=base_type,
                generic_params=[TypeInfo(name=p.strip()) for p in params]
            )
        
        # 简单类型
        return TypeInfo(name=type_str)
    
    def build_type_hierarchy(self) -> Dict[str, Set[str]]:
        """
        构建类型继承层次结构
        
        Returns:
            {子类型: {父类型集合}}
        """
        hierarchy: Dict[str, Set[str]] = {}
        
        for symbol in self.symbol_index.all_symbols:
            if symbol.kind in (SymbolKind.CLASS, SymbolKind.INTERFACE):
                type_info = self.infer_type(symbol)
                if type_info and type_info.base_types:
                    for base in type_info.base_types:
                        if base not in hierarchy:
                            hierarchy[base] = set()
                        hierarchy[base].add(type_info.name)
        
        return hierarchy
3.5 Symbol 策略在 AI IDE 中的应用
3.5.1 智能补全

当用户输入代码时,Symbol 索引用于:

  1. 类型感知的成员补全:根据变量类型过滤可用的成员
  2. 作用域感知的符号补全:只显示当前作用域可见的符号
  3. 依赖感知的导入补全:根据项目依赖推荐可用模块
3.5.2 代码理解增强

Symbol 索引为 AI 提供代码的结构化视图

代码语言:javascript
复制
class SymbolEnhancedContextBuilder:
    """
    基于符号索引的上下文构建器
    将符号信息融入检索结果,增强 AI 对代码的理解
    """
    
    def __init__(self, symbol_index: SymbolIndex, type_inference: TypeInferenceEngine):
        self.symbol_index = symbol_index
        self.type_inference = type_inference
    
    def build_enhanced_context(
        self, 
        chunks: list[Chunk], 
        query: str
    ) -> list[EnhancedChunk]:
        """
        构建增强的上下文
        
        为每个 chunk 添加符号信息,帮助 AI 理解代码结构
        """
        enhanced_chunks = []
        
        for chunk in chunks:
            # 查找 chunk 范围内的所有符号
            symbols_in_range = self._find_symbols_in_range(chunk)
            
            # 构建符号摘要
            symbol_summary = self._build_symbol_summary(symbols_in_range)
            
            # 添加到 chunk
            enhanced = EnhancedChunk(
                content=chunk.content,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                file_path=chunk.file_path,
                symbols=symbols_in_range,
                symbol_summary=symbol_summary,
                type_info=self._build_type_context(symbols_in_range)
            )
            
            enhanced_chunks.append(enhanced)
        
        return enhanced_chunks
    
    def _find_symbols_in_range(self, chunk: Chunk) -> list[SymbolInfo]:
        """查找 chunk 范围内的符号"""
        symbols = []
        
        for symbol in self.symbol_index.all_symbols:
            if (symbol.file_path == chunk.file_path and
                chunk.start_line <= symbol.start_line <= chunk.end_line):
                symbols.append(symbol)
        
        return symbols
    
    def _build_symbol_summary(self, symbols: list[SymbolInfo]) -> str:
        """构建符号摘要"""
        if not symbols:
            return "无符号定义"
        
        # 按类型分组
        by_kind = {}
        for sym in symbols:
            if sym.kind not in by_kind:
                by_kind[sym.kind] = []
            by_kind[sym.kind].append(sym.name)
        
        parts = []
        for kind, names in sorted(by_kind.items(), key=lambda x: x[0].value):
            parts.append(f"{kind.name}: {', '.join(names)}")
        
        return "; ".join(parts)
    
    def _build_type_context(self, symbols: list[SymbolInfo]) -> str:
        """构建类型上下文"""
        type_contexts = []
        
        for symbol in symbols:
            type_info = self.type_inference.infer_type(symbol)
            if type_info:
                if type_info.generic_params:
                    type_str = f"{type_info.name}[{', '.join(t.name for t in type_info.generic_params)}]"
                else:
                    type_str = type_info.name
                
                type_contexts.append(f"{symbol.name}: {type_str}")
        
        return "\n".join(type_contexts)


@dataclass
class EnhancedChunk:
    """增强的代码块"""
    content: str
    start_line: int
    end_line: int
    file_path: str
    symbols: list[SymbolInfo]
    symbol_summary: str
    type_info: str

4. Graph 策略:代码关系的语义建模

本节为你提供的核心技术价值:理解代码图的构建原理,掌握调用图、依赖图、数据流图的实现技术,实现深层次代码关系推理。

4.1 代码图结构的本质

代码不仅是静态的文本,更是一个由多种关系交织而成的语义网络。Graph 策略的核心思想是将代码中的各种关系建模为图结构,从而支持:

  1. 调用关系:函数 A 调用了函数 B
  2. 依赖关系:模块 A 依赖模块 B
  3. 继承关系:类 A 继承自类 B
  4. 引用关系:变量 A 引用了变量 B
  5. 数据流关系:数据从 A 流向 B

这种结构化的关系表示使得 AI 能够进行深层次推理,如"修改这个函数会影响哪些地方"。

4.2 调用图构建

调用图(Call Graph)是图结构中最重要的一种,表示函数之间的调用关系。

4.2.1 静态调用图构建
代码语言:javascript
复制
import ast
from typing import Dict, Set, List, Tuple, Optional
from dataclasses import dataclass, field


@dataclass
class CallGraphNode:
    """调用图节点"""
    name: str
    file_path: str
    start_line: int
    end_line: int
    function_type: str  # 'function', 'method', 'lambda', 'class'
    is_external: bool = False


@dataclass
class CallGraphEdge:
    """调用图边"""
    caller: str  # caller node id
    callee: str  # callee node id
    call_site_line: int
    is_virtual: bool = False  # 虚函数调用(运行时确定)


class CallGraph:
    """
    调用图
    
    支持:
    1. 静态调用图构建(基于 AST 分析)
    2. 动态调用图构建(基于运行时追踪)
    3. 调用链查询
    """
    
    def __init__(self):
        self.nodes: Dict[str, CallGraphNode] = {}
        self.edges: List[CallGraphEdge] = []
        # 优化:构建反向索引
        self.callee_to_callers: Dict[str, Set[str]] = {}
        self.caller_to_callees: Dict[str, Set[str]] = {}
    
    def add_node(self, node: CallGraphNode) -> str:
        """添加节点,返回节点 ID"""
        node_id = self._make_node_id(node.file_path, node.name)
        self.nodes[node_id] = node
        return node_id
    
    def add_edge(self, caller_id: str, callee_id: str, call_site_line: int, is_virtual: bool = False):
        """添加调用边"""
        edge = CallGraphEdge(
            caller=caller_id,
            callee=callee_id,
            call_site_line=call_site_line,
            is_virtual=is_virtual
        )
        self.edges.append(edge)
        
        # 更新反向索引
        if callee_id not in self.callee_to_callers:
            self.callee_to_callers[callee_id] = set()
        self.callee_to_callers[callee_id].add(caller_id)
        
        if caller_id not in self.caller_to_callees:
            self.caller_to_callees[caller_id] = set()
        self.caller_to_callees[caller_id].add(callee_id)
    
    def get_callers(self, node_id: str) -> Set[str]:
        """获取调用指定函数的所有函数(反向查询)"""
        return self.callee_to_callers.get(node_id, set())
    
    def get_callees(self, node_id: str) -> Set[str]:
        """获取指定函数调用的所有函数(正向查询)"""
        return self.caller_to_callees.get(node_id, set())
    
    def find_call_chain(self, start: str, end: str) -> Optional[List[str]]:
        """
        查找从 start 到 end 的调用链
        
        使用 BFS 查找最短路径
        """
        from collections import deque
        
        if start == end:
            return [start]
        
        queue = deque([(start, [start])])
        visited = {start}
        
        while queue:
            current, path = queue.popleft()
            
            for callee in self.get_callees(current):
                if callee == end:
                    return path + [callee]
                
                if callee not in visited:
                    visited.add(callee)
                    queue.append((callee, path + [callee]))
        
        return None
    
    def find_all_callers(self, node_id: str, max_depth: int = 10) -> Set[str]:
        """
        递归查找所有调用者(影响分析)
        
        Args:
            node_id: 目标函数 ID
            max_depth: 最大递归深度
        """
        all_callers: Set[str] = set()
        
        def dfs(current: str, depth: int):
            if depth >= max_depth:
                return
            
            for caller in self.get_callers(current):
                if caller not in all_callers:
                    all_callers.add(caller)
                    dfs(caller, depth + 1)
        
        dfs(node_id, 0)
        return all_callers
    
    def find_all_callees(self, node_id: str, max_depth: int = 10) -> Set[str]:
        """
        递归查找所有被调用者(依赖分析)
        """
        all_callees: Set[str] = set()
        
        def dfs(current: str, depth: int):
            if depth >= max_depth:
                return
            
            for callee in self.get_callees(current):
                if callee not in all_callees:
                    all_callees.add(callee)
                    dfs(callee, depth + 1)
        
        dfs(node_id, 0)
        return all_callees
    
    def _make_node_id(self, file_path: str, name: str) -> str:
        """生成唯一的节点 ID"""
        return f"{file_path}:{name}"
4.3 依赖图构建

依赖图表示模块/包之间的导入关系,是代码组织和架构分析的基础。

代码语言:javascript
复制
import os
from typing import Dict, Set, List, Optional
from dataclasses import dataclass, field


@dataclass
class DependencyNode:
    """依赖图节点"""
    name: str
    path: str
    node_type: str  # 'module', 'package', 'file'
    is_external: bool = False


@dataclass 
class DependencyEdge:
    """依赖图边"""
    source: str
    target: str
    import_line: int
    import_type: str  # 'import', 'from_import', 'dynamic'


class DependencyGraph:
    """
    依赖图
    
    支持:
    1. 模块依赖分析
    2. 循环依赖检测
    3. 依赖路径查找
    4. 影响范围分析
    """
    
    def __init__(self):
        self.nodes: Dict[str, DependencyNode] = {}
        self.edges: List[DependencyEdge] = []
        self.source_to_targets: Dict[str, Set[str]] = {}
        self.target_to_sources: Dict[str, Set[str]] = {}
    
    def add_node(self, node: DependencyNode) -> str:
        """添加节点"""
        self.nodes[node.path] = node
        return node.path
    
    def add_edge(self, edge: DependencyEdge):
        """添加依赖边"""
        self.edges.append(edge)
        
        if edge.source not in self.source_to_targets:
            self.source_to_targets[edge.source] = set()
        self.source_to_targets[edge.source].add(edge.target)
        
        if edge.target not in self.target_to_sources:
            self.target_to_sources[edge.target] = set()
        self.target_to_sources[edge.target].add(edge.source)
    
    def find_cycle(self) -> Optional[List[str]]:
        """
        检测循环依赖
        
        Returns:
            如果存在循环依赖,返回循环路径
        """
        visited = set()
        rec_stack = set()
        path = []
        
        def dfs(node: str) -> Optional[List[str]]:
            visited.add(node)
            rec_stack.add(node)
            path.append(node)
            
            for neighbor in self.source_to_targets.get(node, set()):
                if neighbor not in visited:
                    result = dfs(neighbor)
                    if result:
                        return result
                elif neighbor in rec_stack:
                    # 找到循环
                    cycle_start = path.index(neighbor)
                    return path[cycle_start:] + [neighbor]
            
            path.pop()
            rec_stack.remove(node)
            return None
        
        for node in self.nodes:
            if node not in visited:
                result = dfs(node)
                if result:
                    return result
        
        return None
    
    def get_transitive_closure(self, node_id: str) -> Set[str]:
        """
        获取节点的传递闭包(所有直接或间接依赖)
        """
        closure = set()
        to_visit = {node_id}
        
        while to_visit:
            current = to_visit.pop()
            for neighbor in self.source_to_targets.get(current, set()):
                if neighbor not in closure:
                    closure.add(neighbor)
                    to_visit.add(neighbor)
        
        return closure
    
    def get_reverse_transitive_closure(self, node_id: str) -> Set[str]:
        """
        获取反向传递闭包(所有直接或间接影响该节点的上游)
        """
        closure = set()
        to_visit = {node_id}
        
        while to_visit:
            current = to_visit.pop()
            for neighbor in self.target_to_sources.get(current, set()):
                if neighbor not in closure:
                    closure.add(neighbor)
                    to_visit.add(neighbor)
        
        return closure
4.4 Graph 策略在 AI IDE 中的应用
4.4.1 影响分析

当修改某个函数时,通过调用图可以快速确定所有受影响的下游函数:

代码语言:javascript
复制
class ImpactAnalyzer:
    """
    代码影响分析器
    基于图结构分析代码修改的影响范围
    """
    
    def __init__(self, call_graph: CallGraph, dependency_graph: DependencyGraph):
        self.call_graph = call_graph
        self.dependency_graph = dependency_graph
    
    def analyze_impact(self, file_path: str, function_name: str) -> dict:
        """
        分析修改指定函数的影响
        
        Returns:
            {
                'direct_callers': [...],  # 直接调用者
                'all_callers': [...],     # 所有递归调用者
                'affected_files': [...], # 受影响的文件
                'test_files': [...],     # 可能受影响的测试文件
            }
        """
        func_id = f"{file_path}:{function_name}"
        
        # 获取所有调用者
        all_callers = self.call_graph.find_all_callers(func_id)
        
        # 转换为文件列表
        affected_files = set()
        for caller_id in all_callers:
            if ':' in caller_id:
                caller_file = caller_id.rsplit(':', 1)[0]
                affected_files.add(caller_file)
        
        # 获取依赖该文件的模块
        reverse_deps = self.dependency_graph.get_reverse_transitive_closure(file_path)
        
        return {
            'direct_callers': list(self.call_graph.get_callers(func_id)),
            'all_callers': list(all_callers),
            'affected_files': list(affected_files),
            'affected_dependencies': list(reverse_deps)
        }
4.4.2 上下文优先级排序

基于图的拓扑排序,AI IDE 可以为检索结果赋予优先级:

代码语言:javascript
复制
class GraphContextScorer:
    """
    基于图的上下文评分器
    
    根据代码在图中的位置计算重要性分数
    """
    
    def __init__(self, call_graph: CallGraph, dependency_graph: DependencyGraph):
        self.call_graph = call_graph
        self.dependency_graph = dependency_graph
    
    def score(self, file_path: str, function_name: str = None) -> float:
        """
        计算代码的重要性分数
        
        分数考量因素:
        1. 被调用/引用次数(越多越重要)
        2. 在依赖图中的层级(越底层越核心)
        3. 是否为公共 API
        """
        score = 1.0
        
        if function_name:
            node_id = f"{file_path}:{function_name}"
            
            # 调用次数因子
            caller_count = len(self.call_graph.get_callers(node_id))
            score *= (1 + caller_count * 0.1)
            
            # 导出函数因子
            if not function_name.startswith('_'):
                score *= 1.5
        
        # 文件级别的评分
        direct_deps = self.dependency_graph.get_transitive_closure(file_path)
        dependents = self.dependency_graph.get_reverse_transitive_closure(file_path)
        
        # 被越多模块依赖越重要
        score *= (1 + len(dependents) * 0.05)
        
        return min(score, 10.0)  # 上限

5. RAG 策略:检索增强的语义匹配

本节为你提供的核心技术价值:掌握向量检索与 BM25 的原理与实现,理解混合检索的架构设计,实现高质量的语义匹配。

5.1 RAG 在 AI IDE 中的定位

RAG(Retrieval-Augmented Generation,检索增强生成)策略将信息检索语言模型生成相结合,是 Context Engine 中实现语义匹配的核心组件。

与传统的关键词检索相比,RAG 能够:

  1. 理解语义:即使查询语句与代码表面词汇不同,也能找到语义相关的代码
  2. 处理歧义:理解"登录功能"可能对应 loginauthenticatesignIn 等多种实现
  3. 跨语言匹配:自然语言查询与代码之间的语义桥接
5.2 向量检索基础

向量检索(Vector Search)将文本映射到高维向量空间,通过向量相似度计算实现语义匹配。

5.2.1 文本向量化
代码语言:javascript
复制
import numpy as np
from typing import List, Optional
from dataclasses import dataclass, field
import hashlib


@dataclass
class TextVector:
    """文本向量"""
    text: str
    vector: np.ndarray
    metadata: dict = field(default_factory=dict)
    
    @property
    def norm(self) -> float:
        """向量的 L2 范数"""
        return np.linalg.norm(self.vector)
    
    def cosine_similarity(self, other: 'TextVector') -> float:
        """计算余弦相似度"""
        dot_product = np.dot(self.vector, other.vector)
        return dot_product / (self.norm * other.norm + 1e-8)


class CodeVectorizer:
    """
    代码向量化器
    
    支持多种嵌入模型:
    1. 基于 TF-IDF 的传统方法
    2. 基于 Sentence Transformers 的现代方法
    3. 专门针对代码的 CodeBERT、GraphCodeBERT
    """
    
    def __init__(self, model_name: str = 'codebert'):
        self.model_name = model_name
        self.model = None
        self._load_model()
    
    def _load_model(self):
        """加载向量化模型"""
        if self.model_name == 'tfidf':
            self.model = TFIDFVectorizer()
        elif self.model_name == 'codebert':
            self.model = CodeBERTModel()
        elif self.model_name == 'sentence-transformers':
            self.model = SentenceTransformerModel()
        else:
            raise ValueError(f"Unknown model: {self.model_name}")
    
    def vectorize(self, code: str, language: str = None) -> TextVector:
        """将代码向量化"""
        vector = self.model.encode(code, language)
        return TextVector(text=code, vector=vector)
    
    def batch_vectorize(self, codes: List[str], language: str = None) -> List[TextVector]:
        """批量向量化"""
        return [self.vectorize(code, language) for code in codes]


class TFIDFVectorizer:
    """
    基于 TF-IDF 的向量化器
    
    传统方法,适用于快速原型和简单场景
    """
    
    def __init__(self, max_features: int = 4096):
        self.max_features = max_features
        self.vocabulary = {}
        self.idf = {}
        self.doc_count = 0
    
    def fit(self, documents: List[str]):
        """构建词汇表和 IDF"""
        from collections import Counter
        
        self.doc_count = len(documents)
        
        # 词频统计
        df = Counter()
        for doc in documents:
            tokens = self._tokenize(doc)
            df.update(set(tokens))
        
        # 构建词汇表(取 top max_features)
        for word, freq in df.most_common(self.max_features):
            self.vocabulary[word] = len(self.vocabulary)
        
        # 计算 IDF
        for word, doc_freq in df.items():
            if word in self.vocabulary:
                self.idf[word] = np.log(self.doc_count / (doc_freq + 1)) + 1
    
    def encode(self, text: str, language: str = None) -> np.ndarray:
        """将文本编码为向量"""
        tokens = self._tokenize(text)
        
        # TF
        tf = Counter(tokens)
        
        # 构建向量
        vector = np.zeros(len(self.vocabulary))
        for word, freq in tf.items():
            if word in self.vocabulary:
                idx = self.vocabulary[word]
                vector[idx] = freq * self.idf.get(word, 1.0)
        
        # 归一化
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
        
        return vector
    
    def _tokenize(self, text: str) -> List[str]:
        """分词"""
        import re
        # 简单分词:保留单词和下划线,移除数字和特殊字符
        tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*', text.lower())
        return tokens
5.2.2 向量索引与检索
代码语言:javascript
复制
from typing import List, Tuple, Optional
import heapq


class VectorIndex:
    """
    向量索引
    
    支持:
    1. 暴力搜索(精确)
    2. IVF-PQ 索引(近似)
    3. HNSW 索引(近似)
    """
    
    def __init__(self, dimension: int, index_type: str = 'flat'):
        self.dimension = dimension
        self.index_type = index_type
        self.vectors: List[TextVector] = []
        
        if index_type == 'flat':
            self._search = self._flat_search
        elif index_type == 'ivf_pq':
            self._index = None
            self._search = self._ivf_pq_search
        elif index_type == 'hnsw':
            self._hnsw = HNSWIndex(dimension)
            self._search = self._hnsw_search
    
    def add(self, vector: TextVector):
        """添加向量到索引"""
        self.vectors.append(vector)
        
        if self.index_type == 'hnsw':
            self._hnsw.add(vector)
    
    def search(self, query: TextVector, k: int = 10) -> List[Tuple[TextVector, float]]:
        """
        搜索最相似的 k 个向量
        
        Returns:
            [(vector, similarity), ...]
        """
        return self._search(query, k)
    
    def _flat_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
        """暴力搜索"""
        similarities = []
        
        for vec in self.vectors:
            sim = query.cosine_similarity(vec)
            similarities.append((vec, sim))
        
        # 返回 top-k
        return heapq.nlargest(k, similarities, key=lambda x: x[1])
    
    def _ivf_pq_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
        """IVF-PQ 近似搜索(简化实现)"""
        # 实际实现应使用 FAISS 库
        # 这里返回暴力搜索作为简化
        return self._flat_search(query, k)
    
    def _hnsw_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
        """HNSW 近似搜索"""
        return self._hnsw.search(query, k)


class HNSWIndex:
    """
    Hierarchical Navigable Small World (HNSW) 索引
    
    一种高效的近似最近邻搜索算法
    """
    
    def __init__(self, dimension: int, m: int = 16, ef: int = 200):
        self.dimension = dimension
        self.m = m  # 每个节点的最大连接数
        self.ef = ef  # 搜索时的动态列表大小
        self.layers: List[List[TextVector]] = []
        self.graph: List[dict[int, list[int]]] = []  # 每层的邻接表
    
    def add(self, vector: TextVector):
        """添加向量到索引"""
        # 简化实现
        if not self.layers:
            self.layers.append([])
            self.graph.append({})
        
        self.layers[0].append(vector)
    
    def search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
        """搜索"""
        # 简化实现:使用第一层的暴力搜索
        if not self.layers:
            return []
        
        layer_vectors = self.layers[0]
        similarities = [(v, query.cosine_similarity(v)) for v in layer_vectors]
        return heapq.nlargest(k, similarities, key=lambda x: x[1])
5.3 BM25 检索算法

BM25(Best Matching 25)是一种基于词频的概率检索模型,在信息检索领域广泛应用。其核心思想是:一个词在文档中出现的次数越多,文档越相关;但词频达到一定饱和后,边际效用递减

代码语言:javascript
复制
import math
from typing import List, Tuple, Set
from collections import Counter


class BM25:
    """
    BM25 检索算法
    
    公式:
    Score(D, Q) = Σ IDF(qi) * (tf(qi, D) * (k1 + 1)) / (tf(qi, D) + k1 * (1 - b + b * |D|/avgdl))
    
    其中:
    - tf(qi, D): 词 qi 在文档 D 中的词频
    - |D|: 文档长度
    - avgdl: 平均文档长度
    - k1: 词频饱和参数(通常 1.2-2.0)
    - b: 文档长度归一化参数(通常 0.75)
    - IDF(qi): 逆文档频率
    """
    
    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.documents: List[str] = []
        self.tokenized_docs: List[List[str]] = []
        self.avgdl = 0.0
        self.doc_freq: dict = {}  # 词 -> 包含该词的文档数
        self.N = 0  # 总文档数
        self.idf: dict = {}
    
    def fit(self, documents: List[str]):
        """
        构建 BM25 索引
        
        Args:
            documents: 文档列表
        """
        self.documents = documents
        self.N = len(documents)
        
        # 分词
        self.tokenized_docs = [self._tokenize(doc) for doc in documents]
        
        # 计算平均文档长度
        total_len = sum(len(doc) for doc in self.tokenized_docs)
        self.avgdl = total_len / self.N if self.N > 0 else 0
        
        # 统计文档频率
        self.doc_freq = Counter()
        for doc_tokens in self.tokenized_docs:
            for token in set(doc_tokens):
                self.doc_freq[token] += 1
        
        # 计算 IDF
        self._compute_idf()
    
    def _tokenize(self, text: str) -> List[str]:
        """分词"""
        import re
        # 保留编程语言标识符
        tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*', text.lower())
        return tokens
    
    def _compute_idf(self):
        """计算 IDF 值"""
        for term, df in self.doc_freq.items():
            # IDF 公式:log((N - df + 0.5) / (df + 0.5))
            self.idf[term] = math.log((self.N - df + 0.5) / (df + 0.5) + 1)
    
    def score(self, query: str, doc_idx: int) -> float:
        """
        计算查询对单个文档的 BM25 分数
        
        Args:
            query: 查询字符串
            doc_idx: 文档索引
            
        Returns:
            BM25 分数
        """
        query_tokens = self._tokenize(query)
        doc_tokens = self.tokenized_docs[doc_idx]
        doc_len = len(doc_tokens)
        
        score = 0.0
        tf = Counter(doc_tokens)
        
        for token in query_tokens:
            if token not in self.idf:
                continue
            
            term_freq = tf[token]
            idf = self.idf[token]
            
            # BM25 公式
            numerator = term_freq * (self.k1 + 1)
            denominator = term_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
            
            score += idf * (numerator / denominator)
        
        return score
    
    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """
        搜索最相关的文档
        
        Args:
            query: 查询字符串
            top_k: 返回的最多结果数
            
        Returns:
            [(doc_idx, score), ...]
        """
        scores = [(i, self.score(query, i)) for i in range(self.N)]
        
        # 过滤零分结果并排序
        scores = [(idx, score) for idx, score in scores if score > 0]
        scores.sort(key=lambda x: x[1], reverse=True)
        
        return scores[:top_k]
5.4 混合检索架构

向量检索擅长语义匹配但对精确关键词不敏感,BM25 擅长精确匹配但无法理解语义。混合检索通过融合两种方法的优势,实现1+1 > 2的效果。

代码语言:javascript
复制
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass, field
import numpy as np


@dataclass
class SearchResult:
    """搜索结果"""
    chunk_id: str
    content: str
    file_path: str
    start_line: int
    end_line: int
    vector_score: float = 0.0
    bm25_score: float = 0.0
    combined_score: float = 0.0
    rank: int = 0
    metadata: dict = field(default_factory=dict)


class HybridRetriever:
    """
    混合检索器
    
    融合向量检索和 BM25 检索的结果
    """
    
    def __init__(
        self,
        vector_weight: float = 0.6,
        bm25_weight: float = 0.4,
        rerank: bool = True
    ):
        """
        Args:
            vector_weight: 向量检索权重
            bm25_weight: BM25 权重
            rerank: 是否进行重排序
        """
        self.vector_weight = vector_weight
        self.bm25_weight = bm25_weight
        self.rerank = rerank
        
        # 各子检索器
        self.vector_index: VectorIndex = None
        self.bm25: BM25 = None
        
        # 文档存储
        self.documents: Dict[str, dict] = {}
    
    def index(self, chunks: List[dict]):
        """
        构建索引
        
        Args:
            chunks: [{'id': ..., 'content': ..., 'file_path': ..., 'metadata': ...}, ...]
        """
        # 存储文档
        for chunk in chunks:
            self.documents[chunk['id']] = chunk
        
        # 构建向量索引
        vectorizer = CodeVectorizer('sentence-transformers')
        vectors = [vectorizer.vectorize(chunk['content']) for chunk in chunks]
        
        self.vector_index = VectorIndex(dimension=len(vectors[0].vector) if vectors else 384)
        for i, vec in enumerate(vectors):
            vec.metadata = {'chunk_id': chunks[i]['id']}
            self.vector_index.add(vec)
        
        # 构建 BM25 索引
        self.bm25 = BM25()
        self.bm25.fit([chunk['content'] for chunk in chunks])
    
    def search(
        self, 
        query: str, 
        top_k: int = 20,
        max_results: int = 10
    ) -> List[SearchResult]:
        """
        混合搜索
        
        Args:
            query: 查询字符串
            top_k: 每种检索方式返回的结果数
            max_results: 最终返回的结果数
        """
        # 并行执行两种检索
        vector_results = self._vector_search(query, top_k)
        bm25_results = self._bm25_search(query, top_k)
        
        # 分数归一化
        vector_results = self._normalize_scores(vector_results)
        bm25_results = self._normalize_scores(bm25_results)
        
        # 合并结果
        combined = self._combine_results(vector_results, bm25_results)
        
        # 重排序
        if self.rerank:
            combined = self._rerank(combined, query)
        
        # 截取 top_k
        combined = combined[:max_results]
        
        # 设置最终排名
        for i, result in enumerate(combined):
            result.rank = i + 1
        
        return combined
    
    def _vector_search(self, query: str, top_k: int) -> Dict[str, float]:
        """向量检索"""
        if not self.vector_index:
            return {}
        
        vectorizer = CodeVectorizer('sentence-transformers')
        query_vector = vectorizer.vectorize(query)
        
        results = self.vector_index.search(query_vector, top_k)
        
        return {r[0].metadata.get('chunk_id', ''): r[1] for r in results}
    
    def _bm25_search(self, query: str, top_k: int) -> Dict[str, float]:
        """BM25 检索"""
        if not self.bm25:
            return {}
        
        scores = self.bm25.search(query, top_k)
        
        return {str(idx): score for idx, score in scores}
    
    def _normalize_scores(self, scores: Dict[str, float]) -> Dict[str, float]:
        """Min-Max 归一化分数"""
        if not scores:
            return {}
        
        min_s = min(scores.values())
        max_s = max(scores.values())
        
        if max_s == min_s:
            return {k: 1.0 for k in scores}
        
        return {k: (v - min_s) / (max_s - min_s) for k, v in scores.items()}
    
    def _combine_results(
        self, 
        vector_scores: Dict[str, float], 
        bm25_scores: Dict[str, float]
    ) -> List[SearchResult]:
        """合并检索结果"""
        all_ids = set(vector_scores.keys()) | set(bm25_scores.keys())
        
        results = []
        for chunk_id in all_ids:
            if chunk_id not in self.documents:
                continue
            
            doc = self.documents[chunk_id]
            
            v_score = vector_scores.get(chunk_id, 0.0)
            b_score = bm25_scores.get(chunk_id, 0.0)
            
            # 加权求和
            combined = self.vector_weight * v_score + self.bm25_weight * b_score
            
            result = SearchResult(
                chunk_id=chunk_id,
                content=doc['content'],
                file_path=doc.get('file_path', ''),
                start_line=doc.get('start_line', 0),
                end_line=doc.get('end_line', 0),
                vector_score=v_score,
                bm25_score=b_score,
                combined_score=combined,
                metadata=doc.get('metadata', {})
            )
            
            results.append(result)
        
        # 按综合分数排序
        results.sort(key=lambda x: x.combined_score, reverse=True)
        
        return results
    
    def _rerank(self, results: List[SearchResult], query: str) -> List[SearchResult]:
        """
        使用交叉编码器重排序
        
        简化实现:基于关键词匹配的重排序
        实际应使用 Cross-Encoder 模型
        """
        query_terms = set(query.lower().split())
        
        for result in results:
            content_terms = set(result.content.lower().split())
            
            # 计算查询词在结果中的覆盖率
            coverage = len(query_terms & content_terms) / len(query_terms) if query_terms else 0
            
            # 轻微调整分数:覆盖率高的结果提升
            result.combined_score = result.combined_score * (1 + 0.1 * coverage)
        
        results.sort(key=lambda x: x.combined_score, reverse=True)
        
        return results
5.5 RAG 上下文组装策略

检索到的结果需要组装成最终的上下文,供给 AI 模型使用。这个过程需要考虑:

  1. 上下文长度控制:确保总 token 数不超过模型限制
  2. 位置优化:将最重要的信息放在上下文的开头和结尾(lost in the middle 效应)
  3. 结构保留:保持代码的完整性,避免在语句中间截断
代码语言:javascript
复制
class RAGContextAssembler:
    """
    RAG 上下文组装器
    
    将检索结果组装成适合 AI 模型的上下文
    """
    
    def __init__(
        self,
        max_tokens: int = 8000,
        preserve_edges: bool = True,
        position_strategy: str = 'balanced'
    ):
        """
        Args:
            max_tokens: 最大 token 数
            preserve_edges: 是否保留代码块的完整边界
            position_strategy: 位置策略
                - 'balanced': 首尾各 30%,中间 40%
                - 'front': 前 70%
                - 'back': 后 70%
                - 'uniform': 均匀分布
        """
        self.max_tokens = max_tokens
        self.preserve_edges = preserve_edges
        self.position_strategy = position_strategy
    
    def assemble(
        self, 
        results: List[SearchResult], 
        query: str,
        system_prompt: str = None
    ) -> str:
        """
        组装上下文
        
        Args:
            results: 排序后的检索结果
            query: 用户查询
            system_prompt: 系统提示
            
        Returns:
            组装后的上下文字符串
        """
        # 估算 token
        def estimate_tokens(text: str) -> int:
            # 简化估算:约 4 字符/token
            return len(text) // 4
        
        # 计算系统提示的 token
        system_tokens = estimate_tokens(system_prompt) if system_prompt else 0
        query_tokens = estimate_tokens(f"Query: {query}\n")
        
        # 可用于上下文的 token
        available_tokens = self.max_tokens - system_tokens - query_tokens - 200  # 保留一些余量
        
        # 按位置策略分配 token
        token_allocation = self._allocate_tokens(len(results), available_tokens)
        
        # 选择和裁剪内容
        context_parts = []
        current_tokens = 0
        
        for i, (result, allocated) in enumerate(zip(results, token_allocation)):
            content_tokens = estimate_tokens(result.content)
            
            if content_tokens <= allocated:
                # 内容足够,直接添加
                context_parts.append(self._format_result(result, i + 1))
                current_tokens += content_tokens
            else:
                # 需要裁剪
                truncated = self._truncate_content(result, allocated)
                context_parts.append(self._format_result_with_truncation(result, truncated, i + 1))
                current_tokens += estimate_tokens(truncated)
        
        # 组装最终上下文
        if system_prompt:
            final_context = f"{system_prompt}\n\n"
        else:
            final_context = ""
        
        final_context += f"Query: {query}\n\n"
        final_context += "Relevant Code:\n"
        final_context += "\n\n".join(context_parts)
        
        return final_context
    
    def _allocate_tokens(self, num_results: int, total_tokens: int) -> List[int]:
        """根据位置策略分配 token"""
        if num_results == 0:
            return []
        
        # 计算每部分的比例
        if self.position_strategy == 'balanced':
            # 首尾各 35%,中间 30%
            if num_results == 1:
                ratios = [1.0]
            elif num_results == 2:
                ratios = [0.5, 0.5]
            elif num_results >= 3:
                mid_count = num_results - 2
                front = 0.35
                back = 0.35
                middle = 0.30
                
                ratios = []
                ratios.append(front)
                for _ in range(mid_count):
                    ratios.append(middle / mid_count)
                ratios.append(back)
        elif self.position_strategy == 'front':
            ratios = [1.0 / num_results] * num_results
            # 前面的权重更高
            for i in range(num_results):
                ratios[i] *= (num_results - i) / (num_results - i + 1)
        elif self.position_strategy == 'back':
            ratios = [1.0 / num_results] * num_results
            for i in range(num_results):
                ratios[i] *= (i + 1) / (i + 2)
        else:  # uniform
            ratios = [1.0 / num_results] * num_results
        
        # 转换为 token 数量
        tokens = [int(total_tokens * r) for r in ratios]
        
        # 调整以确保总和正确
        diff = total_tokens - sum(tokens)
        if diff != 0 and tokens:
            tokens[-1] += diff
        
        return tokens
    
    def _format_result(self, result: SearchResult, rank: int) -> str:
        """格式化单个结果"""
        return f"""--- Result #{rank} (Score: {result.combined_score:.3f}) ---
File: {result.file_path}:{result.start_line}-{result.end_line}
{result.content}
---"""
    
    def _format_result_with_truncation(
        self, 
        result: SearchResult, 
        truncated_content: str,
        rank: int
    ) -> str:
        """格式化截断的结果"""
        return f"""--- Result #{rank} (Score: {result.combined_score:.3f}, TRUNCATED) ---
File: {result.file_path}:{result.start_line}-{result.end_line}
{truncated_content}
[... {len(result.content) - len(truncated_content)} characters truncated ...]
---"""
    
    def _truncate_content(self, result: SearchResult, max_tokens: int) -> str:
        """截断内容以适应 token 限制"""
        max_chars = max_tokens * 4  # 约 4 字符/token
        
        if len(result.content) <= max_chars:
            return result.content
        
        # 优先保留重要部分:开头和结尾
        keep_from_start = max_chars // 2
        keep_from_end = max_chars // 2
        
        start_part = result.content[:keep_from_start]
        end_part = result.content[-keep_from_end:]
        
        return start_part + "\n\n[... content truncated ...]\n\n" + end_part

6. 上下文选择算法:从海量代码中精准选材

本节为你提供的核心技术价值:理解上下文选择的挑战与算法设计,掌握多阶段筛选与贪心选择的核心原理,实现高效的大规模代码库上下文管理。

6.1 上下文选择的挑战

在百万行级别的代码库中,即使进行了分块和检索,候选上下文仍然可能远超模型的容量限制。上下文选择算法需要在召回率(relevant content 覆盖率)和精确率(irrelevant noise 排除率)之间取得平衡。

6.1.1 挑战的本质
代码语言:javascript
复制
代码库规模:1,000,000 行
分块大小:100 行/块
总块数:10,000 块
模型上下文:100,000 token ≈ 25,000 行 ≈ 250 块

选择比例:250 / 10,000 = 2.5%

问题:从 10,000 个块中选出最重要的 250 个

这是一个典型的top-k 选择问题,但具有以下特殊性:

  1. 相关性非单调:一个块与查询的相关性不等于其各部分相关性之和
  2. 位置偏差:用户当前编辑位置附近的代码通常更重要
  3. 时效性:最近修改的代码可能更相关
  4. 多样性需求:避免选择太多相似的块
6.2 多阶段筛选算法
代码语言:javascript
复制
from typing import List, Set, Dict, Tuple, Optional
from dataclasses import dataclass, field
import heapq


@dataclass
class ScoredChunk:
    """带分数的代码块"""
    chunk_id: str
    content: str
    file_path: str
    start_line: int
    end_line: int
    score: float = 0.0
    features: dict = field(default_factory=dict)


class MultiStageSelector:
    """
    多阶段上下文选择器
    
    阶段:
    1. 粗筛:快速过滤明显不相关的块
    2. 精排:计算每个块的详细分数
    3. 贪心选择:在多样性约束下选择最优组合
    """
    
    def __init__(
        self,
        max_chunks: int = 100,
        diversity_weight: float = 0.2,
        recency_weight: float = 0.1,
        proximity_weight: float = 0.3,
        relevance_weight: float = 0.4
    ):
        self.max_chunks = max_chunks
        self.diversity_weight = diversity_weight
        self.recency_weight = recency_weight
        self.proximity_weight = proximity_weight
        self.relevance_weight = relevance_weight
        
        # 状态
        self.current_file = None
        self.current_line = 0
        self.modification_times: Dict[str, float] = {}  # chunk_id -> timestamp
    
    def select(
        self,
        candidates: List[ScoredChunk],
        query: str,
        current_file: str = None,
        current_line: int = 0
    ) -> List[ScoredChunk]:
        """
        多阶段选择
        
        Args:
            candidates: 候选块列表
            query: 查询
            current_file: 当前编辑文件
            current_line: 当前编辑行
            
        Returns:
            选中的块列表
        """
        self.current_file = current_file
        self.current_line = current_line
        
        # 阶段 1:粗筛 - 基于简单规则快速过滤
        filtered = self._coarse_filter(candidates, query)
        
        # 阶段 2:精排 - 计算综合分数
        scored = self._fine_rank(filtered, query)
        
        # 阶段 3:贪心选择 - 考虑多样性
        selected = self._greedy_select(scored)
        
        return selected
    
    def _coarse_filter(self, candidates: List[ScoredChunk], query: str) -> List[ScoredChunk]:
        """
        粗筛阶段
        
        快速过滤策略:
        1. 文件类型过滤
        2. 关键词黑名单过滤
        3. 最小相关性阈值过滤
        """
        # 黑名单关键词
        blacklist = {'test', 'mock', 'fixture', '__pycache__', '.min.js'}
        
        # 扩展查询关键词
        query_terms = set(query.lower().split())
        
        filtered = []
        for chunk in candidates:
            # 文件类型检查
            if any(term in chunk.file_path.lower() for term in blacklist):
                continue
            
            # 基础相关性检查:至少有一个查询词在内容或路径中
            content_lower = chunk.content.lower()
            path_lower = chunk.file_path.lower()
            
            has_match = any(
                term in content_lower or term in path_lower
                for term in query_terms
            )
            
            if has_match or not query_terms:
                filtered.append(chunk)
        
        # 如果过滤后仍然太多,随机采样
        if len(filtered) > self.max_chunks * 3:
            import random
            random.shuffle(filtered)
            filtered = filtered[:self.max_chunks * 3]
        
        return filtered
    
    def _fine_rank(self, candidates: List[ScoredChunk], query: str) -> List[ScoredChunk]:
        """
        精排阶段
        
        计算每个块的多维度分数:
        1. 相关性分数(来自 RAG)
        2. 邻近性分数(与当前编辑位置的距离)
        3. 时效性分数(最近修改)
        4. 多样性分数(与已选块的区别)
        """
        for chunk in candidates:
            features = self._compute_features(chunk, query)
            chunk.features = features
            
            # 综合分数
            chunk.score = (
                self.relevance_weight * features['relevance'] +
                self.proximity_weight * features['proximity'] +
                self.recency_weight * features['recency']
            )
        
        # 按分数排序
        candidates.sort(key=lambda c: c.score, reverse=True)
        
        return candidates
    
    def _compute_features(self, chunk: ScoredChunk, query: str) -> dict:
        """计算多维特征"""
        features = {}
        
        # 1. 相关性分数(这里简化,实际来自 RAG 检索)
        query_terms = set(query.lower().split())
        content_terms = set(chunk.content.lower().split())
        
        # Jaccard 相似度
        intersection = len(query_terms & content_terms)
        union = len(query_terms | content_terms)
        features['relevance'] = intersection / union if union > 0 else 0
        
        # 2. 邻近性分数
        if self.current_file and chunk.file_path == self.current_file:
            # 在同一文件中,距离越近分数越高
            distance = abs(chunk.start_line - self.current_line)
            features['proximity'] = 1.0 / (1.0 + distance * 0.1)
        else:
            features['proximity'] = 0.0
        
        # 3. 时效性分数
        mod_time = self.modification_times.get(chunk.chunk_id, 0)
        if mod_time > 0:
            import time
            age_days = (time.time() - mod_time) / 86400
            features['recency'] = 1.0 / (1.0 + age_days)
        else:
            features['recency'] = 0.5  # 默认中间值
        
        return features
    
    def _greedy_select(self, candidates: List[ScoredChunk]) -> List[ScoredChunk]:
        """
        贪心选择
        
        在考虑多样性的情况下选择最优组合
        """
        if len(candidates) <= self.max_chunks:
            return candidates
        
        selected = []
        remaining = candidates.copy()
        
        while len(selected) < self.max_chunks and remaining:
            # 选择当前最优
            best = remaining.pop(0)
            selected.append(best)
            
            # 惩罚与已选块相似的内容(多样性维护)
            remaining = self._penalize_similar(remaining, selected)
            
            # 重新排序
            remaining.sort(key=lambda c: c.score, reverse=True)
        
        return selected
    
    def _penalize_similar(
        self, 
        candidates: List[ScoredChunk], 
        selected: List[ScoredChunk]
    ) -> List[ScoredChunk]:
        """惩罚与已选块相似的内容"""
        selected_signatures = [self._get_signature(c) for c in selected]
        
        penalized = []
        for candidate in candidates:
            sig = self._get_signature(candidate)
            
            # 计算与最近选择块的相似度
            max_similarity = 0.0
            for selected_sig in selected_signatures[-5:]:  # 只看最近 5 个
                sim = self._signature_similarity(sig, selected_sig)
                max_similarity = max(max_similarity, sim)
            
            # 应用惩罚
            candidate.score *= (1 - self.diversity_weight * max_similarity)
            penalized.append(candidate)
        
        return penalized
    
    def _get_signature(self, chunk: ScoredChunk) -> Set[str]:
        """获取块的签名(顶层词汇集合)"""
        # 取内容的前 20 个实词作为签名
        import re
        tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]{2,}', chunk.content.lower())
        
        # 去除停用词
        stopwords = {'function', 'class', 'def', 'return', 'if', 'else', 'for', 'while', 'import', 'from', 'const', 'let', 'var', 'public', 'private', 'static'}
        tokens = [t for t in tokens if t not in stopwords]
        
        return set(tokens[:20])
    
    def _signature_similarity(self, sig1: Set[str], sig2: Set[str]) -> float:
        """计算两个签名的相似度"""
        if not sig1 or not sig2:
            return 0.0
        
        intersection = len(sig1 & sig2)
        union = len(sig1 | sig2)
        
        return intersection / union if union > 0 else 0.0
6.3 上下文预算分配

在实际场景中,AI IDE 需要处理多种类型的上下文请求:代码补全、代码生成、代码解释、bug 修复等。每种请求对上下文的需求不同,需要动态分配上下文预算

代码语言:javascript
复制
from enum import Enum
from typing import Dict, Callable


class QueryType(Enum):
    """查询类型"""
    COMPLETION = "completion"       # 代码补全
    GENERATION = "generation"       # 代码生成
    EXPLANATION = "explanation"     # 代码解释
    REFACTORING = "refactoring"    # 重构
    BUG_FIX = "bug_fix"            # Bug 修复
    DEBUGGING = "debugging"         # 调试
    GENERAL = "general"             # 通用


class ContextBudgetAllocator:
    """
    上下文预算分配器
    
    根据查询类型动态分配上下文预算
    """
    
    # 不同查询类型的预算分配策略
    BUDGET_STRATEGIES = {
        QueryType.COMPLETION: {
            'max_context': 2000,
            'focus': 'current_file',
            'include_imports': True,
            'include_related': True,
            'include_tests': False,
        },
        QueryType.GENERATION: {
            'max_context': 8000,
            'focus': 'related_files',
            'include_imports': True,
            'include_related': True,
            'include_tests': False,
        },
        QueryType.EXPLANATION: {
            'max_context': 5000,
            'focus': 'target_code',
            'include_imports': True,
            'include_related': False,
            'include_tests': False,
        },
        QueryType.REFACTORING: {
            'max_context': 10000,
            'focus': 'target_and_callers',
            'include_imports': True,
            'include_related': True,
            'include_tests': True,
        },
        QueryType.BUG_FIX: {
            'max_context': 8000,
            'focus': 'error_and_related',
            'include_imports': True,
            'include_related': True,
            'include_tests': True,
        },
        QueryType.DEBUGGING: {
            'max_context': 6000,
            'focus': 'error_context',
            'include_imports': False,
            'include_related': True,
            'include_tests': False,
        },
        QueryType.GENERAL: {
            'max_context': 5000,
            'focus': 'balanced',
            'include_imports': True,
            'include_related': True,
            'include_tests': False,
        },
    }
    
    def allocate(
        self,
        query_type: QueryType,
        query: str,
        available_budget: int
    ) -> Dict:
        """
        分配上下文预算
        
        Returns:
            {
                'max_context': 最大上下文量,
                'chunk_budgets': {'type': budget},
                'focus_areas': [...]
            }
        """
        strategy = self.BUDGET_STRATEGIES.get(query_type, self.BUDGET_STRATEGIES[QueryType.GENERAL])
        
        # 根据可用预算调整
        max_context = min(strategy['max_context'], available_budget)
        
        # 计算各部分预算
        chunk_budgets = self._calculate_chunk_budgets(strategy, max_context)
        
        return {
            'query_type': query_type,
            'max_context': max_context,
            'chunk_budgets': chunk_budgets,
            'focus_areas': strategy['focus'].split('_'),
            'include_imports': strategy['include_imports'],
            'include_related': strategy['include_related'],
            'include_tests': strategy['include_tests'],
        }
    
    def _calculate_chunk_budgets(
        self, 
        strategy: Dict, 
        total_budget: int
    ) -> Dict[str, int]:
        """计算各部分的预算分配"""
        focus = strategy['focus']
        
        if focus == 'current_file':
            return {
                'current_function': int(total_budget * 0.4),
                'current_file': int(total_budget * 0.3),
                'imports': int(total_budget * 0.2),
                'related': int(total_budget * 0.1),
            }
        elif focus == 'target_and_callers':
            return {
                'target_code': int(total_budget * 0.35),
                'callers': int(total_budget * 0.25),
                'callees': int(total_budget * 0.15),
                'dependencies': int(total_budget * 0.15),
                'tests': int(total_budget * 0.1),
            }
        elif focus == 'error_context':
            return {
                'error_location': int(total_budget * 0.3),
                'error_trace': int(total_budget * 0.25),
                'related_functions': int(total_budget * 0.25),
                'data_flow': int(total_budget * 0.2),
            }
        else:  # balanced
            return {
                'target_code': int(total_budget * 0.3),
                'related_code': int(total_budget * 0.3),
                'imports': int(total_budget * 0.2),
                'tests': int(total_budget * 0.2),
            }

7. 成本控制:Token 预算与上下文窗口的博弈

本节为你提供的核心技术价值:理解 AI IDE 中的成本结构,掌握 Token 预算管理策略,实现效果与成本的最优平衡。

7.1 Token 经济的本质

在 AI IDE 中,每一次 AI 模型的调用都涉及 Token 消耗。Token 消耗主要来自两部分:

  1. 输入 Token:用户查询 + 上下文(代码块、对话历史)
  2. 输出 Token:AI 模型的响应

根据 2026 年的主流模型定价[^5]:

模型

输入 ($/1M tokens)

输出 ($/1M tokens)

GPT-4o

$2.50

$10.00

Claude 3.5 Sonnet

$3.00

$15.00

Gemini 1.5 Pro

$1.25

$5.00

对于一个日活跃开发者(假设每天 200 次 AI 调用,平均每次 50K 输入 tokens),月度成本可达:

代码语言:javascript
复制
200 calls/day × 30 days × 50K tokens × $3/1M = $90/月/人

对于一个 100 人的开发团队,月度成本约 $9,000。因此,Token 优化不仅是技术问题,更是经济问题。

7.2 Token 预算管理框架
代码语言:javascript
复制
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple
from enum import Enum
import time


class CostLevel(Enum):
    """成本等级"""
    LOW = "low"       # 简单查询,本地模型
    MEDIUM = "medium" # 标准查询,中等模型
    HIGH = "high"     # 复杂查询,顶级模型


@dataclass
class TokenBudget:
    """Token 预算"""
    total: int
    used: int = 0
    reserved: int = 0  # 预留给关键上下文的 tokens
    
    @property
    def available(self) -> int:
        return self.total - self.used - self.reserved
    
    @property
    def usage_ratio(self) -> float:
        return (self.used + self.reserved) / self.total if self.total > 0 else 0


class TokenBudgetManager:
    """
    Token 预算管理器
    
    职责:
    1. 跟踪 Token 消耗
    2. 动态调整预算分配
    3. 实施成本控制策略
    """
    
    def __init__(self, daily_budget_tokens: int = 10_000_000):
        """
        Args:
            daily_budget_tokens: 每日 Token 预算上限
        """
        self.daily_budget = daily_budget_tokens
        self.daily_used = 0
        self.last_reset = time.time()
        
        # 各功能类型的预算
        self.feature_budgets: Dict[str, TokenBudget] = {
            'completion': TokenBudget(total=int(daily_budget_tokens * 0.3)),
            'generation': TokenBudget(total=int(daily_budget_tokens * 0.25)),
            'explanation': TokenBudget(total=int(daily_budget_tokens * 0.15)),
            'refactoring': TokenBudget(total=int(daily_budget_tokens * 0.15)),
            'other': TokenBudget(total=int(daily_budget_tokens * 0.15)),
        }
        
        # 成本等级配置
        self.cost_thresholds = {
            CostLevel.LOW: 1000,      # < 1K tokens
            CostLevel.MEDIUM: 5000,  # 1K - 5K tokens
            CostLevel.HIGH: 5000,    # > 5K tokens
        }
    
    def allocate(
        self, 
        feature: str, 
        requested: int,
        priority: int = 1
    ) -> Tuple[int, CostLevel]:
        """
        分配 Token 预算
        
        Args:
            feature: 功能类型
            requested: 请求的 Token 数
            priority: 优先级(1-5,越高越优先)
            
        Returns:
            (实际分配数, 成本等级)
        """
        # 重置每日计数
        self._check_daily_reset()
        
        # 确定成本等级
        cost_level = self._determine_cost_level(requested)
        
        # 获取功能预算
        budget = self.feature_budgets.get(feature, self.feature_budgets['other'])
        
        # 检查预算可用性
        if budget.available >= requested:
            budget.used += requested
            self.daily_used += requested
            return requested, cost_level
        
        # 预算不足,尝试回收低优先级预算
        reclaimed = self._reclaim_budget(priority)
        
        if budget.available + reclaimed >= requested:
            # 优先使用功能自身预算
            available = budget.available
            budget.used += available
            
            remaining = requested - available
            budget.used += remaining
            
            self.daily_used += requested
            return requested, cost_level
        
        # 最终方案:按比例分配
        allocated = min(budget.available, requested)
        budget.used += allocated
        self.daily_used += allocated
        
        return allocated, cost_level
    
    def _check_daily_reset(self):
        """检查是否需要重置每日计数"""
        current_time = time.time()
        if current_time - self.last_reset > 86400:  # 24 hours
            self.daily_used = 0
            self.last_reset = current_time
            
            # 重置各功能预算
            for budget in self.feature_budgets.values():
                budget.used = 0
    
    def _determine_cost_level(self, tokens: int) -> CostLevel:
        """确定成本等级"""
        if tokens < self.cost_thresholds[CostLevel.LOW]:
            return CostLevel.LOW
        elif tokens < self.cost_thresholds[CostLevel.MEDIUM]:
            return CostLevel.MEDIUM
        else:
            return CostLevel.HIGH
    
    def _reclaim_budget(self, priority: int) -> int:
        """回收低优先级预算"""
        reclaimed = 0
        
        # 按优先级排序(低优先级的先回收)
        sorted_features = sorted(
            self.feature_budgets.items(),
            key=lambda x: self._get_feature_priority(x[0])
        )
        
        for feature, budget in sorted_features:
            if self._get_feature_priority(feature) < priority and budget.reserved > 0:
                # 回收预留但未使用的预算
                reclaimed += budget.reserved
                budget.reserved = 0
        
        return reclaimed
    
    def _get_feature_priority(self, feature: str) -> int:
        """获取功能优先级"""
        priorities = {
            'completion': 5,   # 代码补全高优先级
            'generation': 4,
            'refactoring': 3,
            'explanation': 2,
            'other': 1,
        }
        return priorities.get(feature, 1)
    
    def get_cost_report(self) -> Dict:
        """获取成本报告"""
        self._check_daily_reset()
        
        return {
            'daily_budget': self.daily_budget,
            'daily_used': self.daily_used,
            'daily_usage_ratio': self.daily_used / self.daily_budget if self.daily_budget > 0 else 0,
            'feature_costs': {
                feature: {
                    'budget': budget.total,
                    'used': budget.used,
                    'usage_ratio': budget.usage_ratio
                }
                for feature, budget in self.feature_budgets.items()
            }
        }
7.3 自适应上下文压缩

当 Token 预算紧张时,需要对上下文进行压缩。自适应压缩根据内容的重要性动态调整压缩比例。

代码语言:javascript
复制
class AdaptiveContextCompressor:
    """
    自适应上下文压缩器
    
    策略:
    1. 保留高相关性内容,低压缩
    2. 保留首尾内容,中间可压缩
    3. 保留结构标记(如函数签名),可压缩实现细节
    4. 保留近期内容,压缩旧内容
    """
    
    def __init__(
        self,
        min_compression_ratio: float = 0.5,  # 最小压缩率(保留 50%)
        max_compression_ratio: float = 0.9   # 最大压缩率(保留 10%)
    ):
        self.min_compression_ratio = min_compression_ratio
        self.max_compression_ratio = max_compression_ratio
    
    def compress(
        self,
        content: str,
        target_tokens: int,
        relevance_scores: Dict[int, float] = None
    ) -> str:
        """
        自适应压缩内容
        
        Args:
            content: 原始内容
            target_tokens: 目标 token 数
            relevance_scores: 每行/每个段落的 relevance 分数
            
        Returns:
            压缩后的内容
        """
        # 估算当前 token 数
        current_tokens = len(content) // 4
        
        if current_tokens <= target_tokens:
            return content
        
        # 计算压缩比
        compression_ratio = target_tokens / current_tokens
        compression_ratio = max(
            self.min_compression_ratio,
            min(self.max_compression_ratio, compression_ratio)
        )
        
        # 分析内容结构
        segments = self._split_into_segments(content)
        
        if not segments:
            return self._aggressive_compress(content, target_tokens)
        
        # 计算每段的压缩权重
        segment_weights = self._calculate_weights(
            segments, 
            relevance_scores
        )
        
        # 按权重分配目标 token
        total_weight = sum(segment_weights)
        allocated_tokens = {
            i: int(target_tokens * (w / total_weight))
            for i, w in enumerate(segment_weights)
        }
        
        # 压缩各段
        compressed_segments = []
        for i, segment in enumerate(segments):
            seg_tokens = allocated_tokens.get(i, 0)
            if seg_tokens < 20:  # 太小则跳过
                continue
            
            if len(segment) // 4 > seg_tokens:
                # 需要压缩
                compressed = self._compress_segment(segment, seg_tokens)
                compressed_segments.append(compressed)
            else:
                compressed_segments.append(segment)
        
        return '\n'.join(compressed_segments)
    
    def _split_into_segments(self, content: str) -> List[str]:
        """将内容分割成段落"""
        # 按空行分割
        segments = content.split('\n\n')
        return [s.strip() for s in segments if s.strip()]
    
    def _calculate_weights(
        self,
        segments: List[str],
        relevance_scores: Dict[int, float]
    ) -> List[float]:
        """计算每段的权重"""
        weights = []
        
        for i, segment in enumerate(segments):
            base_weight = 1.0
            
            # 首段和末段权重更高(位置效应)
            if i == 0:
                base_weight *= 1.5
            elif i == len(segments) - 1:
                base_weight *= 1.3
            
            # 相关性分数
            relevance = relevance_scores.get(i, 0.5) if relevance_scores else 0.5
            base_weight *= (0.5 + relevance)
            
            # 长度调整:太长或太短的段权重降低
            line_count = len(segment.split('\n'))
            if line_count < 3:
                base_weight *= 0.7
            elif line_count > 50:
                base_weight *= 0.8
            
            weights.append(base_weight)
        
        return weights
    
    def _compress_segment(self, segment: str, target_tokens: int) -> str:
        """压缩单个段落"""
        lines = segment.split('\n')
        
        if len(lines) <= 3:
            return segment
        
        target_lines = max(3, int(len(lines) * (target_tokens * 4 / len(segment))))
        
        if target_lines >= len(lines):
            return segment
        
        # 保留开头和结尾
        keep_from_start = target_lines // 2
        keep_from_end = target_lines - keep_from_start
        
        compressed = lines[:keep_from_start]
        compressed.append(f"... [{len(lines) - target_lines} lines truncated] ...")
        compressed.extend(lines[-keep_from_end:])
        
        return '\n'.join(compressed)
    
    def _aggressive_compress(self, content: str, target_tokens: int) -> str:
        """激进压缩(当无法按段落分割时)"""
        target_chars = target_tokens * 4
        
        if len(content) <= target_chars:
            return content
        
        # 保留开头和结尾
        keep = target_chars // 2
        return content[:keep] + f"\n\n[... content truncated ...]\n\n" + content[-keep:]

8. 实践:实现一个混合检索的 Context Engine

本节为你提供的核心技术价值:通过完整代码示例展示如何将 Chunk、Symbol、Graph、RAG 四大策略整合,构建生产级别的 Context Engine。

8.1 整体架构设计

8.2 核心实现代码
代码语言:javascript
复制
import asyncio
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple, Any
from enum import Enum
import time


class QueryIntent(Enum):
    """查询意图"""
    COMPLETION = "completion"
    GENERATION = "generation"
    EXPLANATION = "explanation"
    REFACTORING = "refactoring"
    BUG_FIX = "bug_fix"
    NAVIGATION = "navigation"
    SEARCH = "search"


@dataclass
class Query:
    """用户查询"""
    text: str
    intent: QueryIntent
    file_path: Optional[str] = None
    cursor_position: Optional[int] = None
    language: Optional[str] = None
    constraints: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ContextEngineConfig:
    """Context Engine 配置"""
    # 向量检索配置
    vector_model: str = "multi-qa-MiniLM-L6-cos-v1"
    vector_weight: float = 0.5
    
    # BM25 配置
    bm25_weight: float = 0.3
    
    # Symbol 配置
    symbol_weight: float = 0.15
    
    # Graph 配置
    graph_weight: float = 0.05
    
    # 预算配置
    max_context_tokens: int = 8000
    max_chunks: int = 50
    
    # 性能配置
    num_workers: int = 4
    cache_enabled: bool = True
    cache_ttl: int = 300  # seconds


class ContextEngine:
    """
    混合检索 Context Engine
    
    整合四大策略:
    1. Chunk:基础分块
    2. Symbol:符号索引
    3. Graph:图结构
    4. RAG:向量 + BM25 混合检索
    """
    
    def __init__(self, config: ContextEngineConfig):
        self.config = config
        
        # 初始化各组件
        self.vector_store: VectorIndex = None
        self.bm25_index: BM25 = None
        self.symbol_index: SymbolIndex = None
        self.call_graph: CallGraph = None
        self.dependency_graph: DependencyGraph = None
        
        # 缓存
        self._cache: Dict[str, Tuple[Any, float]] = {}
        
        # 状态
        self._initialized = False
    
    async def initialize(self, project_root: str):
        """
        初始化 Context Engine
        
        Args:
            project_root: 项目根目录
        """
        if self._initialized:
            return
        
        print(f"Initializing Context Engine for {project_root}")
        
        # 1. 加载或构建 Chunk 索引
        chunks = await self._build_chunk_index(project_root)
        
        # 2. 构建 Symbol 索引
        self.symbol_index = await self._build_symbol_index(project_root)
        
        # 3. 构建 Graph 索引
        self.call_graph, self.dependency_graph = await self._build_graph_index(project_root)
        
        # 4. 构建 RAG 索引(向量 + BM25)
        await self._build_rag_index(chunks)
        
        self._initialized = True
        print(f"Context Engine initialized with {len(chunks)} chunks")
    
    async def _build_chunk_index(self, project_root: str) -> List[Chunk]:
        """构建 Chunk 索引"""
        from pathlib import Path
        
        # 收集所有代码文件
        code_files = []
        for ext in ['.py', '.js', '.ts', '.go', '.rs', '.java', '.cpp', '.c', '.h']:
            code_files.extend(Path(project_root).rglob(f'*{ext}'))
        
        # 使用结构感知分块器
        chunker = StructuralChunker()
        all_chunks = []
        
        for file_path in code_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    code = f.read()
                
                chunks = chunker.chunk_with_importance(
                    code, 
                    str(file_path)
                )
                
                for chunk in chunks:
                    all_chunks.append(chunk)
            
            except Exception as e:
                print(f"Error processing {file_path}: {e}")
        
        return all_chunks
    
    async def _build_symbol_index(self, project_root: str) -> SymbolIndex:
        """构建 Symbol 索引"""
        extractor = LSSPSymbolExtractor()
        builder = SymbolIndexBuilder(extractor)
        
        # 收集所有 Python 文件
        code_files = list(Path(project_root).rglob('*.py'))
        
        return builder.build_from_files([str(f) for f in code_files])
    
    async def _build_graph_index(
        self, 
        project_root: str
    ) -> Tuple[CallGraph, DependencyGraph]:
        """构建 Graph 索引"""
        from pathlib import Path
        
        # 构建调用图
        call_graph_builder = StaticCallGraphBuilder()
        call_graph = call_graph_builder.build_from_project(project_root)
        
        # 构建依赖图
        dep_graph_builder = DependencyGraphBuilder(project_root)
        dependency_graph = dep_graph_builder.build()
        
        return call_graph, dependency_graph
    
    async def _build_rag_index(self, chunks: List[Chunk]):
        """构建 RAG 索引(向量 + BM25)"""
        # 准备索引数据
        chunk_data = [
            {
                'id': f"{c.file_path}:{c.start_line}",
                'content': c.content,
                'file_path': c.file_path,
                'start_line': c.start_line,
                'end_line': c.end_line,
                'metadata': c.metadata
            }
            for c in chunks
        ]
        
        # 构建混合检索器
        self.retriever = HybridRetriever(
            vector_weight=self.config.vector_weight,
            bm25_weight=self.config.bm25_weight
        )
        self.retriever.index(chunk_data)
    
    async def query(self, query: Query) -> str:
        """
        处理查询并返回上下文
        
        Args:
            query: 用户查询
            
        Returns:
            组装好的上下文字符串
        """
        # 1. 并行执行多策略检索
        chunk_results = await self._parallel_search(query)
        
        # 2. 分数融合与重排
        scored_chunks = self._fuse_and_rerank(chunk_results, query)
        
        # 3. 多阶段上下文选择
        selector = MultiStageSelector(
            max_chunks=self.config.max_chunks
        )
        selected = selector.select(
            scored_chunks,
            query.text,
            query.file_path,
            query.cursor_position or 0
        )
        
        # 4. 上下文组装
        assembler = RAGContextAssembler(
            max_tokens=self.config.max_context_tokens
        )
        
        search_results = [
            SearchResult(
                chunk_id=c.chunk_id,
                content=c.content,
                file_path=c.file_path,
                start_line=c.start_line,
                end_line=c.end_line,
                score=c.score
            )
            for c in selected
        ]
        
        context = assembler.assemble(search_results, query.text)
        
        return context
    
    async def _parallel_search(self, query: Query) -> List[ScoredChunk]:
        """并行执行多种检索策略"""
        tasks = []
        
        # RAG 检索
        tasks.append(self._rag_search(query))
        
        # Symbol 检索(如果适用)
        if query.text:
            tasks.append(self._symbol_search(query))
        
        # Graph 检索(如果适用)
        if query.file_path:
            tasks.append(self._graph_search(query))
        
        # 并行执行
        results = await asyncio.gather(*tasks)
        
        # 合并结果
        merged = {}
        for result_list in results:
            for chunk in result_list:
                if chunk.chunk_id not in merged:
                    merged[chunk.chunk_id] = chunk
                else:
                    # 分数累加
                    merged[chunk.chunk_id].score += chunk.score
        
        return list(merged.values())
    
    async def _rag_search(self, query: Query) -> List[ScoredChunk]:
        """RAG 检索"""
        search_results = self.retriever.search(
            query.text,
            top_k=self.config.max_chunks
        )
        
        return [
            ScoredChunk(
                chunk_id=r.chunk_id,
                content=r.content,
                file_path=r.file_path,
                start_line=r.start_line,
                end_line=r.end_line,
                score=r.combined_score
            )
            for r in search_results
        ]
    
    async def _symbol_search(self, query: Query) -> List[ScoredChunk]:
        """Symbol 检索"""
        if not self.symbol_index:
            return []
        
        # 提取查询中的符号名
        symbols = self.symbol_index.find_by_prefix(query.text)
        
        chunks = []
        for symbol in symbols[:10]:  # 限制数量
            chunks.append(ScoredChunk(
                chunk_id=f"symbol:{symbol.name}",
                content=f"{symbol.name}: {symbol.kind.name}",
                file_path=symbol.file_path,
                start_line=symbol.start_line,
                end_line=symbol.end_line,
                score=0.5 * self.config.symbol_weight
            ))
        
        return chunks
    
    async def _graph_search(self, query: Query) -> List[ScoredChunk]:
        """Graph 检索"""
        if not self.call_graph or not query.file_path:
            return []
        
        chunks = []
        
        # 查找直接调用者和被调用者
        func_id = query.file_path
        callers = self.call_graph.find_all_callers(func_id, max_depth=2)
        callees = self.call_graph.find_all_callees(func_id, max_depth=2)
        
        # 添加调用链上下文
        for caller_id in list(callers)[:5]:
            if ':' in caller_id:
                file_path, func_name = caller_id.rsplit(':', 1)
                chunks.append(ScoredChunk(
                    chunk_id=f"graph:{caller_id}",
                    content=f"Caller: {func_name}",
                    file_path=file_path,
                    start_line=0,
                    end_line=0,
                    score=0.3 * self.config.graph_weight
                ))
        
        return chunks
    
    def _fuse_and_rerank(
        self, 
        chunks: List[ScoredChunk], 
        query: Query
    ) -> List[ScoredChunk]:
        """分数融合与重排"""
        # 简单策略:直接按分数排序
        # 实际可使用更复杂的融合策略(如 Reciprocal Rank Fusion)
        chunks.sort(key=lambda c: c.score, reverse=True)
        return chunks
    
    def get_callers(self, file_path: str, function_name: str = None) -> List[str]:
        """获取函数的调用者"""
        if not self.call_graph:
            return []
        
        func_id = f"{file_path}:{function_name}" if function_name else file_path
        return list(self.call_graph.find_all_callers(func_id))
    
    def get_callees(self, file_path: str, function_name: str = None) -> List[str]:
        """获取函数调用的其他函数"""
        if not self.call_graph:
            return []
        
        func_id = f"{file_path}:{function_name}" if function_name else file_path
        return list(self.call_graph.find_all_callees(func_id))
8.3 使用示例
代码语言:javascript
复制
async def main():
    """使用 Context Engine 的示例"""
    
    # 1. 创建配置
    config = ContextEngineConfig(
        vector_weight=0.5,
        bm25_weight=0.3,
        symbol_weight=0.15,
        graph_weight=0.05,
        max_context_tokens=8000
    )
    
    # 2. 初始化 Context Engine
    engine = ContextEngine(config)
    await engine.initialize("/path/to/project")
    
    # 3. 执行查询
    query = Query(
        text="How does the user authentication work?",
        intent=QueryIntent.EXPLANATION,
        file_path="/path/to/project/auth.py",
        cursor_position=100
    )
    
    # 4. 获取上下文
    context = await engine.query(query)
    
    print(f"Retrieved context ({len(context)} chars):")
    print(context)
    
    # 5. 影响分析示例
    callers = engine.get_callers("/path/to/project/auth.py", "authenticate")
    print(f"Functions that call authenticate: {callers}")


if __name__ == "__main__":
    asyncio.run(main())

9. 总结与展望

9.1 四大策略对比

策略

核心能力

优势

局限

适用场景

Chunk

代码切分

基础、通用

语义割裂

所有场景的预处理

Symbol

结构索引

精确、类型感知

需要 LSP 支持

代码导航、补全

Graph

关系建模

语义关联、影响分析

构建成本高

重构、调试

RAG

语义检索

理解自然语言

计算密集

问答、代码生成

9.2 工程实践建议

基于本文的分析,对于构建 AI IDE 的 Context Engine,提出以下建议:

  1. 分层索引:建立文件级、类级、函数级的多层索引,支持不同粒度的检索
  2. 混合检索:向量检索 + BM25 + Symbol + Graph 的融合策略,比单一策略更鲁棒
  3. 自适应压缩:根据查询类型和 Token 预算动态调整上下文
  4. 增量更新:利用文件监控实现索引的增量更新,避免全量重建
  5. 成本监控:建立 Token 消耗监控体系,及时发现异常
9.3 未来方向
  1. 更大上下文窗口:随着模型能力提升,更大的上下文窗口将减少上下文选择的压力
  2. 更智能的检索:结合代码结构理解的多模态检索,减少"lost in the middle"问题
  3. 端侧部署:将 Context Engine 部署到端侧,减少云端 API 依赖和延迟
  4. 个性化学习:根据用户习惯学习上下文偏好,提供个性化的上下文推荐

参考链接


参考链接:

附录(Appendix):

A. RAG 实现细节

A.1 向量数据库选型对比

数据库

优点

缺点

适用规模

FAISS

高效、Facebook 背书

仅支持向量

< 100M

Milvus

云原生、支持混合检索

部署复杂

> 100M

Chroma

轻量、易用

功能有限

< 10M

Qdrant

高性能、支持过滤

相对较新

> 100M

A.2 BM25 参数调优
代码语言:javascript
复制
# 推荐参数配置
bm25 = BM25(k1=1.5, b=0.75)

# 针对短查询优化
bm25_short = BM25(k1=2.0, b=0.5)

# 针对长文档优化
bm25_long = BM25(k1=1.2, b=0.9)

B. 性能基准测试结果

策略组合

召回率@10

延迟 (ms)

成本 ($/1K queries)

BM25 Only

0.65

15

0.05

Vector Only

0.72

45

0.12

Hybrid (0.5/0.5)

0.81

55

0.15

Hybrid + Graph

0.85

70

0.18

测试环境:100K 代码块,平均块大小 100 行

关键词: Context Engine, Chunk, Symbol, Graph, RAG, AI IDE, 上下文检索, 混合检索, 代码索引, Token 优化

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2026-05-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 目录
  • 1. 引言:为什么 Context 是 AI IDE 的核心瓶颈
    • 1.1 上下文危机的本质
    • 1.2 AI IDE 的上下文处理流程
    • 1.3 四大策略的协同关系
  • 2. Chunk 策略:代码分块的艺术与科学
    • 2.1 分块:上下文工程的基石
    • 2.2 固定分块策略
      • 2.2.1 实现原理
      • 2.2.2 固定分块的优劣分析
    • 2.3 语义分块策略
      • 2.3.1 基于 AST 的语义分块
      • 2.3.2 语义分块的边界检测
    • 2.4 结构感知分块策略
      • 2.4.1 分块重要性评分
    • 2.5 分块策略对比分析
    • 2.6 分块策略的工程实践
  • 3. Symbol 策略:代码结构的精确索引
    • 3.1 为什么需要 Symbol 索引
    • 3.2 Language Server Protocol (LSP) 与符号提取
      • 3.2.1 LSP 符号能力概览
      • 3.2.2 基于 LSP 的符号提取实现
    • 3.3 符号索引的构建与存储
      • 3.3.1 符号索引的数据结构
    • 3.4 类型推断与符号增强
      • 3.4.1 类型推断引擎
    • 3.5 Symbol 策略在 AI IDE 中的应用
      • 3.5.1 智能补全
      • 3.5.2 代码理解增强
  • 4. Graph 策略:代码关系的语义建模
    • 4.1 代码图结构的本质
    • 4.2 调用图构建
      • 4.2.1 静态调用图构建
    • 4.3 依赖图构建
    • 4.4 Graph 策略在 AI IDE 中的应用
      • 4.4.1 影响分析
      • 4.4.2 上下文优先级排序
  • 5. RAG 策略:检索增强的语义匹配
    • 5.1 RAG 在 AI IDE 中的定位
    • 5.2 向量检索基础
      • 5.2.1 文本向量化
      • 5.2.2 向量索引与检索
    • 5.3 BM25 检索算法
    • 5.4 混合检索架构
    • 5.5 RAG 上下文组装策略
  • 6. 上下文选择算法:从海量代码中精准选材
    • 6.1 上下文选择的挑战
      • 6.1.1 挑战的本质
    • 6.2 多阶段筛选算法
    • 6.3 上下文预算分配
  • 7. 成本控制:Token 预算与上下文窗口的博弈
    • 7.1 Token 经济的本质
    • 7.2 Token 预算管理框架
    • 7.3 自适应上下文压缩
  • 8. 实践:实现一个混合检索的 Context Engine
    • 8.1 整体架构设计
    • 8.2 核心实现代码
    • 8.3 使用示例
  • 9. 总结与展望
    • 9.1 四大策略对比
    • 9.2 工程实践建议
    • 9.3 未来方向
  • 参考链接
  • A. RAG 实现细节
    • A.1 向量数据库选型对比
    • A.2 BM25 参数调优
  • B. 性能基准测试结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档