
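Author: HOS (安全风信子)
Date: 2026-05-24
Primary source platform: GitHub

Abstract: Context sets the upper bound on what an AI IDE can do, a claim now widely accepted in AI-assisted programming. In real-world development, codebases routinely run to hundreds of thousands or even millions of lines, while a model's effective context window typically holds only tens of thousands of tokens. Extracting the most relevant context fragments from that volume of code is the core challenge an AI IDE faces. This article examines four context strategies in depth: Chunk (text chunking) splits code into manageable units using several chunking algorithms; Symbol (symbol indexing) builds an index of functions, classes, and variables for precise navigation; Graph (graph structures) builds code structure graphs, call graphs, and dependency graphs to capture semantic relationships; RAG (retrieval augmentation) combines vector retrieval with the classic BM25 algorithm for semantic matching. By analyzing mainstream products such as Cursor, Claude Code, and Cline, we expose the principles and mechanisms behind context engineering, discuss how to balance quality against cost, and present a complete implementation of a hybrid-retrieval Context Engine.

"The quality of an AI coding assistant is ultimately limited by the quality of its context."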
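One of the central tensions in modern software development is the conflict between ever-growing codebases and limited model context. According to GitHub statistics[^1], codebase sizes typically break down as follows:

| Codebase size | Lines of code (LOC) | Estimated tokens (ratio 1.5, i.e. LOC × 1.5) |
|---|---|---|
| Small project | 10,000 - 50,000 | 15K - 75K |
| Medium project | 50,000 - 200,000 | 75K - 300K |
| Large project | 200,000 - 1,000,000 | 300K - 1.5M |
| Very large project | > 1,000,000 | > 1.5M |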
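As of 2026, the context windows of mainstream large models have grown considerably, yet the context a model can use effectively is usually limited to the 32K - 128K token range. Beyond that range, the model's attention to distant information degrades markedly, a phenomenon researchers call the "lost in the middle" problem[^2].

A complete AI IDE context pipeline typically proceeds through several stages.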

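The four strategies, Chunk, Symbol, Graph, and RAG, complement one another rather than replace one another. The table below summarizes how they relate:

| Strategy | Core question | Input | Output | Typical techniques |
|---|---|---|---|---|
| Chunk | How should the code be split? | Raw source files | List of structured chunks | Fixed window, semantic chunking |
| Symbol | Where is a symbol? | Code AST | Symbol index table | LSP, type inference |
| Graph | How is the code connected? | Symbol index | Graph structure | Static analysis, pointer analysis |
| RAG | Which code is most relevant? | Query + chunks | Ranked list of relevant chunks | Vector retrieval, BM25 |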
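What this section covers: the main code chunking strategies and when each applies, plus the principles and implementation of fixed-size, semantic, and structure-aware chunking.

Chunk (text chunking) is the most basic strategy in a Context Engine. Its core idea is to split a large codebase into many independent, retrievable fragments. Chunk quality directly affects retrieval quality downstream: chunks that are too large add irrelevant noise, while chunks that are too small lose contextual continuity.

Fixed-size chunking is the simplest and most widely used strategy. It cuts the code evenly at a preset length, usually measured in tokens or lines.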
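```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Chunk:
    """A chunk of source code."""
    content: str
    start_token: int
    end_token: int
    start_line: int
    end_line: int
    file_path: Optional[str] = None
    language: Optional[str] = None
    metadata: dict = field(default_factory=dict)


class FixedChunker:
    """
    Fixed-size chunker.
    Splits code into windows of a fixed number of tokens.
    """

    def __init__(self, chunk_size: int = 512, overlap: int = 50):
        """
        Args:
            chunk_size: tokens per chunk (including the overlap)
            overlap: tokens shared by adjacent chunks
        """
        # overlap >= chunk_size would make the step zero or negative
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_code(self, code: str, language: Optional[str] = None) -> list[Chunk]:
        """
        Split code into fixed-size chunks.
        Args:
            code: source text
            language: programming language (used to pick a tokenizer)
        Returns:
            List of chunks, each with its text, token offsets, and line range
        """
        tokenizer = self._get_tokenizer(language)
        # disallowed_special=() lets source that contains special-token text be encoded
        tokens = tokenizer.encode(code, disallowed_special=())
        chunks = []
        step = self.chunk_size - self.overlap
        for start in range(0, len(tokens), step):
            end = min(start + self.chunk_size, len(tokens))
            # Decoding may split a multi-byte character at a chunk edge; acceptable here
            chunk_text = tokenizer.decode(tokens[start:end])
            chunks.append(Chunk(
                content=chunk_text,
                start_token=start,
                end_token=end,
                start_line=self._line_at_token(tokenizer, tokens, start),
                end_line=self._line_at_token(tokenizer, tokens, end - 1),
                language=language,
            ))
            if end >= len(tokens):
                break
        return chunks

    def _get_tokenizer(self, language: Optional[str]):
        """Return the tokenizer for a language."""
        # A real implementation could return a different tokenizer per language.
        # Simplified here: tiktoken's cl100k_base serves as a general-purpose tokenizer.
        import tiktoken
        return tiktoken.get_encoding("cl100k_base")

    @staticmethod
    def _line_at_token(tokenizer, tokens: list[int], index: int) -> int:
        """Return the 1-based line on which the token at `index` starts."""
        # Count newlines in the decoded prefix instead of guessing characters per token
        return tokenizer.decode(tokens[:index]).count('\n') + 1
```

Advantages: simple to implement and computationally cheap.

Disadvantages: it cuts across semantic units, because chunk boundaries ignore code structure.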
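To make the overlap arithmetic concrete, here is a minimal, dependency-free sketch of the window computation alone. The function name `window_bounds` and the sample sizes are illustrative assumptions, not part of any product's implementation.

```python
def window_bounds(n_tokens: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Return (start, end) token offsets for overlapping fixed-size windows."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # each window starts `step` tokens after the previous one
    bounds = []
    for start in range(0, n_tokens, step):
        end = min(start + chunk_size, n_tokens)
        bounds.append((start, end))
        if end >= n_tokens:  # the last window reached the end of the input
            break
    return bounds


bounds = window_bounds(n_tokens=1200, chunk_size=512, overlap=50)
print(bounds)  # [(0, 512), (462, 974), (924, 1200)]
```

Each window shares its first 50 tokens with the tail of the previous one, so a statement cut at one boundary still appears whole in a neighbouring chunk, provided it is shorter than the overlap.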
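Semantic chunking addresses the way fixed-size chunking cuts through meaning. Its core idea is to split along the semantic boundaries of the code (functions, classes, modules) so that each chunk stays as semantically complete as possible.

Code in modern programming languages has explicit syntactic structure: functions, classes, methods, modules, and so on. These structures form natural semantic boundaries and are ideal split points.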
```python
import ast
from typing import Optional


class SemanticChunker:
    """
    AST-based semantic chunker.
    Parses the code into an abstract syntax tree and chunks it by semantic unit.
    Chunk sizes are measured in lines.
    """

    def __init__(self, max_chunk_size: int = 512, min_chunk_size: int = 50):
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = min_chunk_size

    def chunk_code(self, code: str, file_path: Optional[str] = None) -> list[Chunk]:
        """
        Split code into semantic chunks.
        Strategy:
        1. Parse the AST
        2. Turn each top-level function or class into a chunk
        3. Split oversized chunks, merge undersized ones
        """
        try:
            # ast.parse rejects filename=None, so fall back to a placeholder
            tree = ast.parse(code, filename=file_path or "<unknown>")
        except SyntaxError:
            # Fall back to fixed-size chunking when parsing fails
            return self._fallback_chunk(code)
        chunks = []
        # Iterate over top-level nodes only. ast.walk would also yield methods and
        # nested functions, which are already contained in their parent's chunk.
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                chunk = self._function_to_chunk(node, code, file_path)
            elif isinstance(node, ast.ClassDef):
                chunk = self._class_to_chunk(node, code, file_path)
            else:
                continue
            if self._is_large_chunk(chunk):
                chunks.extend(self._split_large_chunk(chunk))
            else:
                # Small chunks are kept here and merged below rather than dropped
                chunks.append(chunk)
        # Order by position in the file
        chunks.sort(key=lambda c: c.start_line)
        return self._merge_small_chunks(chunks)

    @staticmethod
    def _node_source(node, code: str, default_span: int) -> tuple[int, int, str]:
        """Return (start_line, end_line, text) for a node, decorators included."""
        start_line = min([node.lineno] + [d.lineno for d in node.decorator_list])
        end_line = node.end_lineno or start_line + default_span
        lines = code.split('\n')
        return start_line, end_line, '\n'.join(lines[start_line - 1:end_line])

    def _function_to_chunk(self, node, code: str, file_path: Optional[str]) -> Chunk:
        """Convert a function node into a Chunk."""
        start_line, end_line, text = self._node_source(node, code, 20)
        return Chunk(
            content=text,
            start_token=0,  # token counts are computed elsewhere
            end_token=0,
            start_line=start_line,
            end_line=end_line,
            file_path=file_path,
            metadata={
                'type': 'function',
                'name': node.name,
                'args': [arg.arg for arg in node.args.args],
                'decorators': [ast.unparse(d) for d in node.decorator_list],
            },
        )

    def _class_to_chunk(self, node: ast.ClassDef, code: str, file_path: Optional[str]) -> Chunk:
        """Convert a class node into a Chunk."""
        start_line, end_line, text = self._node_source(node, code, 50)
        # Collect the names of the class's methods
        methods = [n.name for n in node.body
                   if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
        return Chunk(
            content=text,
            start_token=0,
            end_token=0,
            start_line=start_line,
            end_line=end_line,
            file_path=file_path,
            metadata={
                'type': 'class',
                'name': node.name,
                'methods': methods,
                'bases': [ast.unparse(base) for base in node.bases],
            },
        )

    def _is_large_chunk(self, chunk: Chunk) -> bool:
        """Check whether a chunk exceeds max_chunk_size lines."""
        return len(chunk.content.split('\n')) > self.max_chunk_size

    def _split_large_chunk(self, chunk: Chunk) -> list[Chunk]:
        """Split an oversized chunk into pieces of at most max_chunk_size lines."""
        lines = chunk.content.split('\n')
        pieces = []
        for offset in range(0, len(lines), self.max_chunk_size):
            part = lines[offset:offset + self.max_chunk_size]
            pieces.append(Chunk(
                content='\n'.join(part),
                start_token=0,
                end_token=0,
                start_line=chunk.start_line + offset,
                end_line=chunk.start_line + offset + len(part) - 1,
                file_path=chunk.file_path,
                metadata={**chunk.metadata, 'part': len(pieces)},
            ))
        return pieces

    def _merge_small_chunks(self, chunks: list[Chunk]) -> list[Chunk]:
        """Merge neighbouring chunks whose combined size is below min_chunk_size."""
        if not chunks:
            return []
        merged = [chunks[0]]
        for chunk in chunks[1:]:
            last = merged[-1]
            combined_lines = (last.end_line - last.start_line
                              + chunk.end_line - chunk.start_line)
            if combined_lines < self.min_chunk_size:
                # Fold into the previous chunk and remember every merged name
                names = last.metadata.get('names', [last.metadata.get('name')])
                merged[-1] = Chunk(
                    content=last.content + '\n' + chunk.content,
                    start_token=last.start_token,
                    end_token=chunk.end_token,
                    start_line=last.start_line,
                    end_line=chunk.end_line,
                    file_path=chunk.file_path,
                    metadata={'type': 'merged',
                              'names': names + [chunk.metadata.get('name')]},
                )
            else:
                merged.append(chunk)
        return merged

    def _fallback_chunk(self, code: str) -> list[Chunk]:
        """Fallback used when parsing fails."""
        return FixedChunker().chunk_code(code)
```

In a real implementation, semantic chunking has to detect the boundaries of each kind of code structure precisely:
```python
import ast
import re


class BoundaryDetector:
    """
    Code structure boundary detector.
    Locates the boundaries of functions, classes, and modules.
    """

    # Per-language indentation patterns (not used by the methods below)
    INDENT_PATTERNS = {
        'python': {'indent': ' ', 'dedent': None},
        'javascript': {'indent': ' ', 'dedent': None},
        'java': {'indent': ' ', 'dedent': None},
    }

    def detect_function_boundaries(self, code: str, language: str = 'python') -> list[tuple[int, int, str]]:
        """
        Detect the boundaries of every function.
        Returns:
            [(start_line, end_line, function_name), ...]
        """
        if language == 'python':
            return self._detect_python_functions(code)
        elif language in ('javascript', 'typescript'):
            return self._detect_js_functions(code)
        else:
            return self._detect_generic_functions(code)

    def _detect_python_functions(self, code: str) -> list[tuple[int, int, str]]:
        """Detect Python function boundaries from the AST."""
        tree = ast.parse(code)
        boundaries = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                boundaries.append((
                    node.lineno,
                    node.end_lineno or node.lineno + 20,
                    node.name,
                ))
        return boundaries

    def _detect_js_functions(self, code: str) -> list[tuple[int, int, str]]:
        """Detect JavaScript function boundaries (simplified)."""
        # Matches function declarations, arrow functions bound to a variable,
        # and arrow functions used as object properties
        func_pattern = r'(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>|(\w+)\s*:\s*(?:async\s*)?\([^)]*\)\s*=>)'
        boundaries = []
        lines = code.split('\n')
        for i, line in enumerate(lines, 1):
            match = re.search(func_pattern, line)
            if match:
                func_name = match.group(1) or match.group(2) or match.group(3)
                # Simplification: assume the body spans the next 20 lines
                boundaries.append((i, min(i + 20, len(lines)), func_name))
        return boundaries

    def _detect_generic_functions(self, code: str) -> list[tuple[int, int, str]]:
        """Generic function boundary detection based on brace matching."""
        boundaries = []
        lines = code.split('\n')
        in_function = False
        function_start = 0
        function_name = ''
        brace_count = 0
        for i, line in enumerate(lines, 1):
            # Only open a new function when not already inside one, so a nested
            # `function` keyword does not reset the start of the outer function
            if not in_function and ('function' in line or 'def ' in line):
                in_function = True
                function_start = i
                brace_count = 0
                match = re.search(r'(?:function|def)\s+(\w+)', line)
                function_name = match.group(1) if match else f'anonymous_{i}'
            if in_function:
                brace_count += line.count('{') - line.count('}')
                # Close the function once its braces balance after at least one '{'
                if brace_count == 0 and '{' in ''.join(lines[function_start - 1:i]):
                    boundaries.append((function_start, i, function_name))
                    in_function = False
        return boundaries
```

Structure-aware chunking goes one step beyond semantic chunking: it considers not only the syntactic structure of the code but also its semantic relationships and importance weights. The core idea is that not all code matters equally; core business logic, entry functions, and public APIs should receive a larger share of the context.
```python
from enum import Enum
from typing import Optional


class CodeImportance(Enum):
    """Importance levels for code."""
    CRITICAL = 1  # critical: entry points, public APIs
    HIGH = 2      # high: core business logic, data models
    MEDIUM = 3    # medium: supporting features, internal implementation
    LOW = 4       # low: utility functions, configuration


class StructuralChunker:
    """
    Structure-aware chunker.
    Chunks code according to its structural position and importance in the project.
    """

    def __init__(self):
        # Score multipliers keyed by a substring of the file name or directory
        self.importance_weights = {
            'main': 4.0,
            'index': 4.0,
            'app': 3.5,
            'api': 3.5,
            'router': 3.0,
            'service': 3.0,
            'model': 2.5,
            'controller': 2.5,
            'util': 1.5,
            'helper': 1.5,
            'config': 1.0,
            'test': 0.8,
        }

    def chunk_with_importance(
        self,
        code: str,
        file_path: str,
        structure_hints: Optional[dict] = None,
    ) -> list[Chunk]:
        """
        Structure-aware chunking with importance scores.
        Args:
            code: source code
            file_path: file path (used to judge the file's type and location)
            structure_hints: extra structural hints (such as a list of entry points)
        """
        # 1. Base chunking
        base_chunks = self._base_chunking(code, file_path)
        # 2. Score each chunk; the score lives in metadata because Chunk has
        #    no dedicated importance field
        for chunk in base_chunks:
            chunk.metadata['importance'] = self._calculate_importance(
                chunk, file_path, structure_hints)
        # 3. Adjust chunk size according to importance
        return self._adjust_chunks_by_importance(base_chunks, code)

    def _calculate_importance(
        self,
        chunk: Chunk,
        file_path: str,
        hints: Optional[dict] = None,
    ) -> float:
        """
        Compute the importance score of a chunk.
        Factors:
        1. File path (directory and file name)
        2. Kind of code (function, class)
        3. Visibility (public or private name)
        4. External hints (known entry functions)
        """
        score = 1.0
        # File path factor; chunk.file_path may be unset, so fall back to the argument
        path = (chunk.file_path or file_path or '').lower()
        file_name = path.split('/')[-1]
        directory = '/'.join(path.split('/')[:-1])
        for key, weight in self.importance_weights.items():
            if key in file_name or key in directory:
                score *= weight
                break
        # Code structure factor
        metadata = chunk.metadata
        if metadata.get('type') == 'function':
            func_name = metadata.get('name', '').lower()
            # Constructors and entry functions get a high weight
            if func_name in ('__init__', 'main', 'run', 'start', 'init'):
                score *= 2.0
            # Private functions (single leading underscore) get a low weight
            if func_name.startswith('_') and not func_name.startswith('__'):
                score *= 0.7
            # Public functions (no leading underscore) get a moderate boost
            if not func_name.startswith('_'):
                score *= 1.3
        elif metadata.get('type') == 'class':
            class_name = metadata.get('name', '')
            # Classes derived from exception types get a higher weight
            bases = metadata.get('bases', [])
            if any('Exception' in b or 'Error' in b for b in bases):
                score *= 1.5
            # Public classes get a moderate boost
            if not class_name.startswith('_'):
                score *= 1.2
        # External hint factor
        if hints:
            # Boost functions that are known entry points
            entry_functions = hints.get('entry_functions', [])
            if metadata.get('name') in entry_functions:
                score *= 3.0
        return score

    def _adjust_chunks_by_importance(self, chunks: list[Chunk], code: str) -> list[Chunk]:
        """
        Adjust chunk size according to importance.
        Strategy:
        - High-importance chunks: widen the context
        - Low-importance chunks: mark for compression
        """
        if not chunks:
            return []
        scores = [c.metadata.get('importance', 1.0) for c in chunks]
        avg_importance = sum(scores) / len(scores)
        lines = code.split('\n')
        adjusted = []
        for chunk, score in zip(chunks, scores):
            if score > avg_importance * 1.5:
                adjusted.append(self._expand_chunk(chunk, lines))
            elif score < avg_importance * 0.5:
                adjusted.append(self._compress_chunk(chunk))
            else:
                adjusted.append(chunk)
        return adjusted

    def _expand_chunk(self, chunk: Chunk, lines: list[str]) -> Chunk:
        """Widen a high-importance chunk by about 20% on each side."""
        expand_lines = max(1, (chunk.end_line - chunk.start_line + 1) // 5)
        start_line = max(1, chunk.start_line - expand_lines)
        end_line = min(len(lines), chunk.end_line + expand_lines)
        return Chunk(
            # Re-slice the source so the content matches the new line range
            content='\n'.join(lines[start_line - 1:end_line]),
            start_token=chunk.start_token,
            end_token=chunk.end_token,
            start_line=start_line,
            end_line=end_line,
            file_path=chunk.file_path,
            metadata={**chunk.metadata, 'expanded': True},
        )

    def _compress_chunk(self, chunk: Chunk) -> Chunk:
        """Compress a low-importance chunk (simplified: only flags it)."""
        return Chunk(
            content=chunk.content,
            start_token=chunk.start_token,
            end_token=chunk.end_token,
            start_line=chunk.start_line,
            end_line=chunk.end_line,
            file_path=chunk.file_path,
            metadata={**chunk.metadata, 'compressed': True},
        )

    def _base_chunking(self, code: str, file_path: str) -> list[Chunk]:
        """Base chunking (delegates to the semantic chunker)."""
        return SemanticChunker().chunk_code(code, file_path)
```

| Strategy | Advantages | Disadvantages | Best suited for |
|---|---|---|---|
| Fixed-size chunking | Simple to implement, computationally cheap | Cuts across semantic units, structure-blind boundaries | Quick previews of very large codebases, or when precise retrieval is not feasible |
| Semantic chunking | Keeps semantic units whole, exact AST boundaries | More complex to implement, parsing overhead | Scenarios that need precise code understanding |
| Structure-aware chunking | Differentiates by importance, optimizes context use | Needs knowledge of the project structure, weights are hard to tune | Large codebases, differentiated retrieval needs |
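The trade-offs in the table can be expressed as a simple dispatch rule. The sketch below is one possible reading, not a rule taken from any product: the function name `choose_chunking_strategy` is hypothetical, it only handles Python source, and the 200,000-line threshold is borrowed from the "large project" band in the size table earlier in this article.

```python
import ast


def choose_chunking_strategy(code: str, repo_loc: int,
                             large_repo_loc: int = 200_000) -> str:
    """Pick a chunking strategy name for one Python file.

    code: the file's source text
    repo_loc: total lines of code in the repository (assumed to be known)
    large_repo_loc: assumed threshold above which importance weighting pays off
    """
    try:
        ast.parse(code)
    except SyntaxError:
        # No usable AST, so only structure-blind chunking is possible
        return "fixed"
    if repo_loc >= large_repo_loc:
        # Large codebases benefit from differentiated context budgets
        return "structural"
    return "semantic"


print(choose_chunking_strategy("def f():\n    return 1\n", repo_loc=30_000))
```

A file that fails to parse falls back to fixed-size chunking, which mirrors the fallback path in the semantic chunker.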
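In real AI IDE products, the choice of chunking strategy usually follows a small set of principles.

An analysis of Cursor's chunking strategy[^3] indicates that it uses a hierarchy-aware hybrid chunking approach.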
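What this section covers: how a symbol index is built, and how LSP-based symbol extraction and type inference enable precise code location and navigation.

The Chunk strategy is concerned with splitting code as text; the Symbol strategy is concerned with its structural semantics. In software development, code is more than text: it is a system of symbols with rich semantic structure, such as functions, classes, variables, and their types.

This symbol information is essential for an AI to understand code. When a user asks "where is this function defined?" or "what is the type of this variable?", the Symbol index can give a precise answer.

LSP (Language Server Protocol) is a standardized protocol initiated by Microsoft for communication between an editor or IDE and a language server[^4]. LSP offers rich operations on code symbols and is the foundation for building a Symbol index.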
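On the wire, LSP messages are JSON-RPC payloads preceded by a `Content-Length` header and a blank line, not newline-delimited JSON. The following minimal sketch shows that framing in isolation; the helper names `encode_message` and `read_message` and the example URI are illustrative assumptions, and an in-memory stream stands in for a real server's pipes.

```python
import io
import json


def encode_message(payload: dict) -> bytes:
    """Frame a JSON-RPC payload with the LSP base-protocol header."""
    body = json.dumps(payload).encode("utf-8")
    # Content-Length counts bytes of the body, not characters
    header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
    return header + body


def read_message(stream) -> dict:
    """Read one framed message from a binary stream."""
    content_length = None
    while True:
        line = stream.readline()
        if not line:
            raise EOFError("stream closed before a full header was read")
        if line in (b"\r\n", b"\n"):
            break  # a blank line ends the header section
        name, _, value = line.partition(b":")
        if name.strip().lower() == b"content-length":
            content_length = int(value.strip())
    if content_length is None:
        raise ValueError("missing Content-Length header")
    return json.loads(stream.read(content_length).decode("utf-8"))


request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "textDocument/documentSymbol",
    "params": {"textDocument": {"uri": "file:///tmp/example.py"}},
}
wire = encode_message(request)
decoded = read_message(io.BytesIO(wire))
print(decoded == request)  # True
```

Because the length is given up front, several messages can be written back to back on the same stream and read one at a time.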

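```python
import json
import os
import subprocess
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Optional
from urllib.parse import unquote, urlparse


class SymbolKind(IntEnum):
    """LSP symbol kinds."""
    FILE = 1
    MODULE = 2
    NAMESPACE = 3
    PACKAGE = 4
    CLASS = 5
    METHOD = 6
    PROPERTY = 7
    FIELD = 8
    CONSTRUCTOR = 9
    ENUM = 10
    INTERFACE = 11
    FUNCTION = 12
    VARIABLE = 13
    CONSTANT = 14
    STRING = 15
    NUMBER = 16
    BOOLEAN = 17
    ARRAY = 18
    OBJECT = 19
    KEY = 20
    NULL = 21
    ENUMMEMBER = 22
    STRUCT = 23
    EVENT = 24
    OPERATOR = 25
    TYPEPARAMETER = 26


@dataclass
class SymbolInfo:
    """Symbol information. Line and column numbers are 0-based, as in LSP."""
    name: str
    kind: SymbolKind
    file_path: str
    start_line: int
    start_column: int
    end_line: int
    end_column: int
    container_name: Optional[str] = None
    signature: Optional[str] = None
    docstring: Optional[str] = None
    dependencies: list[str] = field(default_factory=list)
    type_info: Optional[str] = None
    deprecated: bool = False


class LSSPSymbolExtractor:
    """
    LSP-based symbol extractor.
    Uses the Language Server Protocol to extract symbol information from code.
    """

    def __init__(self, ls_path: Optional[str] = None):
        """
        Args:
            ls_path: path to the language server executable
        """
        self.ls_path = ls_path
        self.process = None
        self.request_id = 0

    def extract_symbols(self, file_path: str, language: str) -> list[SymbolInfo]:
        """Extract every symbol from a source file."""
        self._start_server(language)
        try:
            self._open_document(file_path, language)
            response = self._send_request('textDocument/documentSymbol', {
                'textDocument': {'uri': self._path_to_uri(file_path)}
            })
            return self._parse_symbol_response(response or [], file_path)
        finally:
            self._stop_server()

    def find_definition(self, file_path: str, line: int, character: int) -> Optional[SymbolInfo]:
        """
        Find where a symbol is defined.
        Args:
            file_path: source file path
            line: cursor line (0-based)
            character: cursor column (0-based)
        """
        language = self._detect_language(file_path)
        self._start_server(language)
        try:
            self._open_document(file_path, language)
            response = self._send_request('textDocument/definition', {
                'textDocument': {'uri': self._path_to_uri(file_path)},
                'position': {'line': line, 'character': character}
            })
            return self._parse_location_response(response)
        finally:
            self._stop_server()

    def find_references(self, file_path: str, line: int, character: int) -> list[SymbolInfo]:
        """Find every reference to a symbol."""
        language = self._detect_language(file_path)
        self._start_server(language)
        try:
            self._open_document(file_path, language)
            response = self._send_request('textDocument/references', {
                'textDocument': {'uri': self._path_to_uri(file_path)},
                'position': {'line': line, 'character': character},
                'context': {'includeDeclaration': True}
            })
            return self._parse_locations_response(response)
        finally:
            self._stop_server()

    def _start_server(self, language: str):
        """Start the language server and perform the initialize handshake."""
        if self.process is not None:
            return
        self.process = subprocess.Popen(
            self._get_server_command(language),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # an unread stderr pipe could fill up and block
        )
        self._send_request('initialize', {
            'processId': os.getpid(),  # the client's process id, per the LSP spec
            'rootUri': self._path_to_uri('.'),
            'capabilities': {
                'textDocument': {
                    'documentSymbol': {
                        'hierarchicalDocumentSymbolSupport': True
                    }
                }
            }
        })
        self._send_notification('initialized', {})

    def _stop_server(self):
        """Stop the language server: `shutdown` request, then `exit` notification."""
        if self.process:
            try:
                self._send_request('shutdown')
                self._send_notification('exit')
                self.process.wait(timeout=5)
            except (OSError, ValueError, RuntimeError, subprocess.TimeoutExpired):
                self.process.terminate()
            self.process = None

    def _open_document(self, file_path: str, language: str):
        """Announce the document; most servers only answer for opened documents."""
        self._send_notification('textDocument/didOpen', {
            'textDocument': {
                'uri': self._path_to_uri(file_path),
                'languageId': language,
                'version': 1,
                'text': Path(file_path).read_text(encoding='utf-8'),
            }
        })

    def _write_message(self, message: dict):
        """Write one message using LSP framing: Content-Length header, blank line, body."""
        body = json.dumps(message).encode('utf-8')
        header = f'Content-Length: {len(body)}\r\n\r\n'.encode('ascii')
        self.process.stdin.write(header + body)
        self.process.stdin.flush()

    def _read_message(self) -> dict:
        """Read one framed message from the server."""
        content_length = 0
        while True:
            line = self.process.stdout.readline()
            if not line:
                raise RuntimeError('language server closed the connection')
            line = line.strip()
            if not line:
                break  # a blank line ends the header section
            name, _, value = line.partition(b':')
            if name.lower() == b'content-length':
                content_length = int(value)
        return json.loads(self.process.stdout.read(content_length))

    def _send_request(self, method: str, params: Optional[dict] = None):
        """Send an LSP request and return the `result` of the matching response."""
        self.request_id += 1
        request = {'jsonrpc': '2.0', 'id': self.request_id, 'method': method}
        if params is not None:
            request['params'] = params
        self._write_message(request)
        while True:
            message = self._read_message()
            # Skip notifications and server-to-client requests (left unanswered here)
            if message.get('id') == self.request_id and 'method' not in message:
                if 'error' in message:
                    raise RuntimeError(f"LSP error: {message['error']}")
                return message.get('result')

    def _send_notification(self, method: str, params: Optional[dict] = None):
        """Send an LSP notification (no id, no response)."""
        notification = {'jsonrpc': '2.0', 'method': method}
        if params is not None:
            notification['params'] = params
        self._write_message(notification)

    def _parse_symbol_response(self, response: list, file_path: str) -> list[SymbolInfo]:
        """Parse a documentSymbol response in either of its two formats."""
        symbols = []

        def parse_recursive(items, container=None):
            for item in items:
                if 'range' in item:
                    # Hierarchical DocumentSymbol format; `children` is optional
                    rng = item['range']
                    symbols.append(SymbolInfo(
                        name=item['name'],
                        kind=SymbolKind(item['kind']),
                        file_path=file_path,
                        start_line=rng['start']['line'],
                        start_column=rng['start']['character'],
                        end_line=rng['end']['line'],
                        end_column=rng['end']['character'],
                        container_name=container,
                        signature=item.get('detail'),  # SymbolInfo has no `detail` field
                    ))
                    parse_recursive(item.get('children') or [], item['name'])
                else:
                    # Flat SymbolInformation format
                    rng = item['location']['range']
                    symbols.append(SymbolInfo(
                        name=item['name'],
                        kind=SymbolKind(item['kind']),
                        file_path=file_path,
                        start_line=rng['start']['line'],
                        start_column=rng['start']['character'],
                        end_line=rng['end']['line'],
                        end_column=rng['end']['character'],
                        container_name=item.get('containerName'),
                    ))

        parse_recursive(response)
        return symbols

    def _parse_locations_response(self, response) -> list[SymbolInfo]:
        """Convert Location or LocationLink results into position-only SymbolInfo."""
        if not response:
            return []
        if isinstance(response, dict):
            response = [response]
        results = []
        for loc in response:
            uri = loc.get('uri') or loc.get('targetUri')
            rng = loc.get('range') or loc.get('targetSelectionRange')
            results.append(SymbolInfo(
                name='',               # a location carries no symbol name
                kind=SymbolKind.NULL,  # placeholder: the kind is unknown here
                file_path=unquote(urlparse(uri).path),
                start_line=rng['start']['line'],
                start_column=rng['start']['character'],
                end_line=rng['end']['line'],
                end_column=rng['end']['character'],
            ))
        return results

    def _parse_location_response(self, response) -> Optional[SymbolInfo]:
        """Return the first location of a definition response, if any."""
        locations = self._parse_locations_response(response)
        return locations[0] if locations else None

    def _get_server_command(self, language: str) -> list[str]:
        """Return the command that starts the language server."""
        if self.ls_path:
            return [self.ls_path]
        servers = {
            'python': ['python', '-m', 'pylsp'],
            'javascript': ['typescript-language-server', '--stdio'],
            'typescript': ['typescript-language-server', '--stdio'],
            'rust': ['rust-analyzer'],
            'go': ['gopls'],
            'java': ['jdtls'],
        }
        return servers.get(language, ['pylsp'])

    def _path_to_uri(self, path: str) -> str:
        """Convert a file path into an absolute file URI."""
        return Path(path).resolve().as_uri()

    def _detect_language(self, file_path: str) -> str:
        """Detect the language from the file extension."""
        ext_map = {
            '.py': 'python',
            '.js': 'javascript',
            '.ts': 'typescript',
            '.rs': 'rust',
            '.go': 'go',
            '.java': 'java',
        }
        _, ext = os.path.splitext(file_path)
        return ext_map.get(ext.lower(), 'python')
```

Extracted symbol information has to be stored efficiently so that it can be retrieved quickly. In a large codebase the number of symbols can reach hundreds of thousands, so the index structure needs careful design.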
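At that scale, a linear scan over every symbol name becomes the bottleneck for prefix queries such as autocomplete. One common option, shown here as a hedged sketch and not as the design any particular product uses, is to keep the names sorted and locate the start of the matching run with binary search. The class name `PrefixIndex` and the sample names are hypothetical.

```python
import bisect


class PrefixIndex:
    """Sorted-list index that answers prefix queries with a binary search."""

    def __init__(self, names):
        # Sorting once up front makes every name sharing a prefix contiguous
        self._names = sorted(set(names))

    def find_by_prefix(self, prefix: str) -> list[str]:
        # bisect_left finds the first name >= prefix in O(log n)
        lo = bisect.bisect_left(self._names, prefix)
        results = []
        for name in self._names[lo:]:
            if not name.startswith(prefix):
                break  # past the contiguous run of matches
            results.append(name)
        return results


index = PrefixIndex(["parse_args", "parse_file", "parser", "print_all", "load"])
print(index.find_by_prefix("parse"))  # ['parse_args', 'parse_file', 'parser']
```

The cost per query is O(log n + k) for k matches, at the price of re-sorting or inserting in order when symbols change.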
```python
import os
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SymbolIndex:
    """
    Symbol index.
    Supports several query modes:
    1. By name (exact or fuzzy)
    2. By kind
    3. By file or package
    4. By container
    """
    # name -> symbols (several symbols may share a name, e.g. overloads)
    name_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    # kind -> symbols
    kind_index: Dict[SymbolKind, List[SymbolInfo]] = field(default_factory=dict)
    # file path -> symbols
    file_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    # container (class or module) -> child symbols
    container_index: Dict[str, List[SymbolInfo]] = field(default_factory=dict)
    # every symbol, in insertion order
    all_symbols: List[SymbolInfo] = field(default_factory=list)

    def add_symbol(self, symbol: SymbolInfo):
        """Add a symbol to every index."""
        self.name_index.setdefault(symbol.name, []).append(symbol)
        self.kind_index.setdefault(symbol.kind, []).append(symbol)
        self.file_index.setdefault(symbol.file_path, []).append(symbol)
        if symbol.container_name:
            self.container_index.setdefault(symbol.container_name, []).append(symbol)
        self.all_symbols.append(symbol)

    def find_by_name(self, name: str, exact: bool = True) -> List[SymbolInfo]:
        """Find symbols by name."""
        if exact:
            return self.name_index.get(name, [])
        # Fuzzy match: case-insensitive substring
        results = []
        for sym_name, symbols in self.name_index.items():
            if name.lower() in sym_name.lower():
                results.extend(symbols)
        return results

    def find_by_kind(self, kind: SymbolKind) -> List[SymbolInfo]:
        """Find symbols by kind."""
        return self.kind_index.get(kind, [])

    def find_by_file(self, file_path: str) -> List[SymbolInfo]:
        """Find every symbol in a file."""
        return self.file_index.get(file_path, [])

    def find_in_container(self, container: str) -> List[SymbolInfo]:
        """Find the symbols inside a container (for example, all methods of a class)."""
        return self.container_index.get(container, [])

    def find_by_prefix(self, prefix: str) -> List[SymbolInfo]:
        """Find symbols whose name starts with the given prefix."""
        results = []
        for sym_name, symbols in self.name_index.items():
            if sym_name.startswith(prefix):
                results.extend(symbols)
        return sorted(results, key=lambda s: (s.file_path, s.start_line))


class SymbolIndexBuilder:
    """
    Symbol index builder.
    Supports both incremental and batch construction.
    """

    def __init__(self, extractor: LSSPSymbolExtractor):
        self.extractor = extractor
        self.index = SymbolIndex()

    def build_from_files(self, file_paths: list[str]) -> SymbolIndex:
        """Build the index from a list of source files."""
        for file_path in file_paths:
            self.add_file(file_path)
        return self.index

    def add_file(self, file_path: str):
        """Incrementally add one file to the index."""
        language = self._detect_language(file_path)
        for symbol in self.extractor.extract_symbols(file_path, language):
            self.index.add_symbol(symbol)

    def remove_file(self, file_path: str):
        """Remove a file's symbols from the index."""
        # Copy the list: find_by_file returns the live list stored in file_index,
        # and removing from it while iterating would skip elements
        to_remove = list(self.index.find_by_file(file_path))
        for symbol in to_remove:
            self.index.name_index[symbol.name].remove(symbol)
            self.index.kind_index[symbol.kind].remove(symbol)
            if symbol.container_name and symbol.container_name in self.index.container_index:
                self.index.container_index[symbol.container_name].remove(symbol)
            self.index.all_symbols.remove(symbol)
        self.index.file_index.pop(file_path, None)

    def _detect_language(self, file_path: str) -> str:
        """Detect the language from the file extension."""
        _, ext = os.path.splitext(file_path)
        ext_map = {
            '.py': 'python',
            '.js': 'javascript',
            '.ts': 'typescript',
            '.rs': 'rust',
            '.go': 'go',
            '.java': 'java',
        }
        return ext_map.get(ext.lower(), 'python')
```

A symbol index should hold more than symbol positions: it should also carry semantic information such as types and dependencies, which is essential for an AI to understand the code.
```python
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass
class TypeInfo:
    """Type information."""
    name: str
    module: Optional[str] = None
    generic_params: Optional[List['TypeInfo']] = None
    base_types: Optional[List[str]] = None
    methods: Optional[Dict[str, 'TypeInfo']] = None
    properties: Optional[Dict[str, 'TypeInfo']] = None


class TypeInferenceEngine:
    """
    Type inference engine.
    Infers type information by static analysis.
    Covers:
    1. Variable declaration types
    2. Function return types
    3. Expression types
    4. Generic parameters
    """

    def __init__(self, symbol_index: SymbolIndex):
        self.symbol_index = symbol_index
        self.type_cache: Dict[str, Optional[TypeInfo]] = {}

    def infer_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """Infer the type of a symbol, using a cache keyed by its position."""
        # Include the column so two symbols on the same line do not collide
        cache_key = f"{symbol.file_path}:{symbol.start_line}:{symbol.start_column}"
        if cache_key in self.type_cache:
            return self.type_cache[cache_key]
        if symbol.type_info:
            # Type information is already attached to the symbol
            inferred = self._parse_type_string(symbol.type_info)
        else:
            # Infer from the surrounding code
            inferred = self._infer_from_context(symbol)
        self.type_cache[cache_key] = inferred
        return inferred

    def _infer_from_context(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """Infer a type from the code context."""
        if symbol.kind == SymbolKind.VARIABLE:
            return self._infer_variable_type(symbol)
        elif symbol.kind == SymbolKind.FUNCTION:
            return self._infer_function_return_type(symbol)
        elif symbol.kind == SymbolKind.PROPERTY:
            return self._infer_property_type(symbol)
        return None

    def _infer_variable_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """Infer a variable's type from an annotation on its declaration line."""
        with open(symbol.file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # SymbolInfo lines come from LSP and are 0-based
        if symbol.start_line >= len(lines):
            return None
        line = lines[symbol.start_line]
        # Python annotation with a value, e.g. `x: int = 5`
        if ': ' in line and '=' in line:
            type_part = line.split(':')[1].split('=')[0].strip()
            return self._parse_type_string(type_part)
        # JavaScript/TypeScript annotation without a value, e.g. `let x: number;`
        if ': ' in line and ('const' in line or 'let' in line or 'var' in line):
            type_part = line.split(':')[1].split(';')[0].strip()
            return self._parse_type_string(type_part)
        return None

    def _infer_function_return_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """Infer a function's return type from the `->` annotation in its signature."""
        if symbol.signature:
            match = re.search(r'->\s*(\w+)', symbol.signature)
            if match:
                return TypeInfo(name=match.group(1))
        return None

    def _infer_property_type(self, symbol: SymbolInfo) -> Optional[TypeInfo]:
        """Infer a property's type (not implemented; always returns None)."""
        # A property is usually a class member. A full implementation would look up
        # the container via self.symbol_index and parse the property definition.
        return None

    def _parse_type_string(self, type_str: str) -> TypeInfo:
        """Parse a type string; handles flat generics such as Dict[str, int]."""
        generic_match = re.match(r'(\w+)\[(\w+(?:,\s*\w+)*)\]', type_str)
        if generic_match:
            base_type = generic_match.group(1)
            params = generic_match.group(2).split(',')
            return TypeInfo(
                name=base_type,
                generic_params=[TypeInfo(name=p.strip()) for p in params],
            )
        # Simple type
        return TypeInfo(name=type_str)

    def build_type_hierarchy(self) -> Dict[str, Set[str]]:
        """
        Build the type inheritance hierarchy.
        Returns:
            {base type: {subtypes}}
        """
        hierarchy: Dict[str, Set[str]] = {}
        for symbol in self.symbol_index.all_symbols:
            if symbol.kind in (SymbolKind.CLASS, SymbolKind.INTERFACE):
                type_info = self.infer_type(symbol)
                # Note: none of the inference paths above populate base_types, so
                # the result stays empty until a class-level analysis fills it in
                if type_info and type_info.base_types:
                    for base in type_info.base_types:
                        hierarchy.setdefault(base, set()).add(type_info.name)
        return hierarchy
```

While the user is typing, the editor draws on the Symbol index.
The Symbol index also gives the AI a structured view of the code:
```python
from dataclasses import dataclass


@dataclass
class EnhancedChunk:
    """A code chunk enriched with symbol information."""
    content: str
    start_line: int
    end_line: int
    file_path: str
    symbols: list[SymbolInfo]
    symbol_summary: str
    type_info: str


class SymbolEnhancedContextBuilder:
    """
    Context builder backed by the symbol index.
    Adds symbol information to retrieval results to help the AI understand the code.
    """

    def __init__(self, symbol_index: SymbolIndex, type_inference: TypeInferenceEngine):
        self.symbol_index = symbol_index
        self.type_inference = type_inference

    def build_enhanced_context(
        self,
        chunks: list[Chunk],
        query: str,
    ) -> list[EnhancedChunk]:
        """
        Build the enhanced context.
        Attaches symbol information to each chunk so the AI can see the code's structure.
        """
        enhanced_chunks = []
        for chunk in chunks:
            # Find every symbol inside the chunk's line range
            symbols_in_range = self._find_symbols_in_range(chunk)
            enhanced_chunks.append(EnhancedChunk(
                content=chunk.content,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                file_path=chunk.file_path,
                symbols=symbols_in_range,
                symbol_summary=self._build_symbol_summary(symbols_in_range),
                type_info=self._build_type_context(symbols_in_range),
            ))
        return enhanced_chunks

    def _find_symbols_in_range(self, chunk: Chunk) -> list[SymbolInfo]:
        """Find the symbols that start inside a chunk."""
        symbols = []
        for symbol in self.symbol_index.all_symbols:
            # Chunk lines are 1-based while LSP symbol lines are 0-based
            if (symbol.file_path == chunk.file_path and
                    chunk.start_line <= symbol.start_line + 1 <= chunk.end_line):
                symbols.append(symbol)
        return symbols

    def _build_symbol_summary(self, symbols: list[SymbolInfo]) -> str:
        """Build a one-line summary of the symbols, grouped by kind."""
        if not symbols:
            return "No symbol definitions"
        by_kind = {}
        for sym in symbols:
            by_kind.setdefault(sym.kind, []).append(sym.name)
        parts = []
        for kind, names in sorted(by_kind.items(), key=lambda x: x[0].value):
            parts.append(f"{kind.name}: {', '.join(names)}")
        return "; ".join(parts)

    def _build_type_context(self, symbols: list[SymbolInfo]) -> str:
        """Build the type context, one `name: type` line per typed symbol."""
        type_contexts = []
        for symbol in symbols:
            type_info = self.type_inference.infer_type(symbol)
            if type_info:
                if type_info.generic_params:
                    type_str = f"{type_info.name}[{', '.join(t.name for t in type_info.generic_params)}]"
                else:
                    type_str = type_info.name
                type_contexts.append(f"{symbol.name}: {type_str}")
        return "\n".join(type_contexts)
```

What this section covers: how code graphs are built, how call graphs, dependency graphs, and data-flow graphs are implemented, and how they support deeper reasoning about relationships in code.
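Code is not just static text; it is a semantic network woven from many kinds of relationships. The core idea of the Graph strategy is to model those relationships as graph structures, such as call graphs, dependency graphs, and data-flow graphs.

This structured representation of relationships lets the AI reason more deeply, for example about the question "what will be affected if I change this function?".

The call graph is the most important of these graph structures. It represents the calling relationships between functions.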
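As a concrete illustration of static construction, the sketch below derives a caller-to-callee mapping for a single Python module with the standard `ast` module. It is a deliberately rough approximation: the function name `build_call_graph` is hypothetical, calls are resolved by bare name only, and calls inside nested functions are attributed to both the inner and the outer function.

```python
import ast
from collections import defaultdict


def build_call_graph(source: str) -> dict[str, set[str]]:
    """Map each function name to the names it calls (name-based approximation)."""
    tree = ast.parse(source)
    graph = defaultdict(set)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Every Call node inside the function body is a call site
            for child in ast.walk(node):
                if isinstance(child, ast.Call):
                    func = child.func
                    if isinstance(func, ast.Name):          # plain call: f(x)
                        graph[node.name].add(func.id)
                    elif isinstance(func, ast.Attribute):   # method call: obj.f(x)
                        graph[node.name].add(func.attr)
    return dict(graph)


source = '''
def load(path):
    return open(path).read()

def parse(text):
    return text.split()

def main(path):
    return parse(load(path))
'''
graph = build_call_graph(source)
print(sorted(graph["main"]))  # ['load', 'parse']
```

Resolving `obj.f` to a specific class, or following calls across files, needs the type and symbol information discussed in the Symbol section; this sketch stops short of that.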

```python
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass
class CallGraphNode:
    """Call graph node."""
    name: str
    file_path: str
    start_line: int
    end_line: int
    function_type: str  # 'function', 'method', 'lambda', 'class'
    is_external: bool = False


@dataclass
class CallGraphEdge:
    """Call graph edge."""
    caller: str  # caller node id
    callee: str  # callee node id
    call_site_line: int
    is_virtual: bool = False  # virtual call (target resolved at runtime)


class CallGraph:
    """
    Call graph.
    Designed to hold edges from:
    1. Static construction (AST analysis)
    2. Dynamic construction (runtime tracing)
    and to answer call-chain queries.
    """

    def __init__(self):
        self.nodes: Dict[str, CallGraphNode] = {}
        self.edges: List[CallGraphEdge] = []
        # Adjacency indexes in both directions for fast lookup
        self.callee_to_callers: Dict[str, Set[str]] = {}
        self.caller_to_callees: Dict[str, Set[str]] = {}

    def add_node(self, node: CallGraphNode) -> str:
        """Add a node and return its id."""
        node_id = self._make_node_id(node.file_path, node.name)
        self.nodes[node_id] = node
        return node_id

    def add_edge(self, caller_id: str, callee_id: str, call_site_line: int, is_virtual: bool = False):
        """Add a call edge."""
        self.edges.append(CallGraphEdge(
            caller=caller_id,
            callee=callee_id,
            call_site_line=call_site_line,
            is_virtual=is_virtual,
        ))
        # Keep both adjacency indexes in sync
        self.callee_to_callers.setdefault(callee_id, set()).add(caller_id)
        self.caller_to_callees.setdefault(caller_id, set()).add(callee_id)

    def get_callers(self, node_id: str) -> Set[str]:
        """Return every function that calls the given function (reverse query)."""
        return self.callee_to_callers.get(node_id, set())

    def get_callees(self, node_id: str) -> Set[str]:
        """Return every function the given function calls (forward query)."""
        return self.caller_to_callees.get(node_id, set())

    def find_call_chain(self, start: str, end: str) -> Optional[List[str]]:
        """
        Find a call chain from start to end.
        Uses BFS, so the returned chain is a shortest one.
        """
        if start == end:
            return [start]
        queue = deque([(start, [start])])
        visited = {start}
        while queue:
            current, path = queue.popleft()
            for callee in self.get_callees(current):
                if callee == end:
                    return path + [callee]
                if callee not in visited:
                    visited.add(callee)
                    queue.append((callee, path + [callee]))
        return None

    def find_all_callers(self, node_id: str, max_depth: int = 10) -> Set[str]:
        """
        Recursively collect every caller (impact analysis).
        Args:
            node_id: id of the target function
            max_depth: maximum recursion depth
        """
        all_callers: Set[str] = set()

        def dfs(current: str, depth: int):
            if depth >= max_depth:
                return
            for caller in self.get_callers(current):
                if caller not in all_callers:
                    all_callers.add(caller)
                    dfs(caller, depth + 1)

        dfs(node_id, 0)
        return all_callers

    def find_all_callees(self, node_id: str, max_depth: int = 10) -> Set[str]:
        """Recursively collect every callee (dependency analysis)."""
        all_callees: Set[str] = set()

        def dfs(current: str, depth: int):
            if depth >= max_depth:
                return
            for callee in self.get_callees(current):
                if callee not in all_callees:
                    all_callees.add(callee)
                    dfs(callee, depth + 1)

        dfs(node_id, 0)
        return all_callees

    def _make_node_id(self, file_path: str, name: str) -> str:
        """Build a node id. Same-named methods of different classes in one file collide."""
        return f"{file_path}:{name}"
```

A dependency graph represents the import relationships between modules and packages. It is the basis for analyzing code organization and architecture.
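The edges of such a graph can be read straight from import statements. The following is a minimal sketch for Python sources using the standard `ast` module; the function name `extract_imports` is hypothetical, and the `import_type` labels are chosen to match the `'import'` and `'from_import'` values used in this article's edge type. Dynamic imports (for example via `importlib`) are not detected.

```python
import ast


def extract_imports(source: str) -> list[tuple[str, int, str]]:
    """Return (module, line, import_type) for each import statement."""
    edges = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                edges.append((alias.name, node.lineno, "import"))
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x"; node.level counts leading dots
            module = "." * node.level + (node.module or "")
            edges.append((module, node.lineno, "from_import"))
    return sorted(edges, key=lambda e: e[1])


source = (
    "import os\n"
    "import json as j\n"
    "from collections import deque\n"
    "from . import helpers\n"
    "from ..core import engine\n"
)
for module, line, kind in extract_imports(source):
    print(line, kind, module)
```

Turning the relative forms (`.`, `..core`) into concrete file paths requires knowing the importing file's package location, which is left out here.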
import os
from typing import Dict, Set, List, Optional
from dataclasses import dataclass, field
@dataclass
class DependencyNode:
"""依赖图节点"""
name: str
path: str
node_type: str # 'module', 'package', 'file'
is_external: bool = False
@dataclass
class DependencyEdge:
"""依赖图边"""
source: str
target: str
import_line: int
import_type: str # 'import', 'from_import', 'dynamic'
class DependencyGraph:
"""
依赖图
支持:
1. 模块依赖分析
2. 循环依赖检测
3. 依赖路径查找
4. 影响范围分析
"""
def __init__(self):
self.nodes: Dict[str, DependencyNode] = {}
self.edges: List[DependencyEdge] = []
self.source_to_targets: Dict[str, Set[str]] = {}
self.target_to_sources: Dict[str, Set[str]] = {}
def add_node(self, node: DependencyNode) -> str:
"""添加节点"""
self.nodes[node.path] = node
return node.path
def add_edge(self, edge: DependencyEdge):
"""添加依赖边"""
self.edges.append(edge)
if edge.source not in self.source_to_targets:
self.source_to_targets[edge.source] = set()
self.source_to_targets[edge.source].add(edge.target)
if edge.target not in self.target_to_sources:
self.target_to_sources[edge.target] = set()
self.target_to_sources[edge.target].add(edge.source)
def find_cycle(self) -> Optional[List[str]]:
"""
检测循环依赖
Returns:
如果存在循环依赖,返回循环路径
"""
visited = set()
rec_stack = set()
path = []
def dfs(node: str) -> Optional[List[str]]:
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in self.source_to_targets.get(node, set()):
if neighbor not in visited:
result = dfs(neighbor)
if result:
return result
elif neighbor in rec_stack:
# 找到循环
cycle_start = path.index(neighbor)
return path[cycle_start:] + [neighbor]
path.pop()
rec_stack.remove(node)
return None
for node in self.nodes:
if node not in visited:
result = dfs(node)
if result:
return result
return None
def get_transitive_closure(self, node_id: str) -> Set[str]:
"""
获取节点的传递闭包(所有直接或间接依赖)
"""
closure = set()
to_visit = {node_id}
while to_visit:
current = to_visit.pop()
for neighbor in self.source_to_targets.get(current, set()):
if neighbor not in closure:
closure.add(neighbor)
to_visit.add(neighbor)
return closure
def get_reverse_transitive_closure(self, node_id: str) -> Set[str]:
"""
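Return the reverse transitive closure: every node that directly or indirectly depends on this node, i.e. everything a change to it may affect.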
"""
closure = set()
to_visit = {node_id}
while to_visit:
current = to_visit.pop()
for neighbor in self.target_to_sources.get(current, set()):
if neighbor not in closure:
closure.add(neighbor)
to_visit.add(neighbor)
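return closure

When a function is modified, the call graph quickly identifies every function the change can affect, namely its direct and transitive callers: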
class ImpactAnalyzer:
"""
代码影响分析器
基于图结构分析代码修改的影响范围
"""
def __init__(self, call_graph: CallGraph, dependency_graph: DependencyGraph):
self.call_graph = call_graph
self.dependency_graph = dependency_graph
def analyze_impact(self, file_path: str, function_name: str) -> dict:
"""
分析修改指定函数的影响
Returns:
{
'direct_callers': [...], # 直接调用者
'all_callers': [...], # 所有递归调用者
'affected_files': [...], # 受影响的文件
'affected_dependencies': [...], # modules that directly or indirectly depend on this file
}
"""
func_id = f"{file_path}:{function_name}"
# 获取所有调用者
all_callers = self.call_graph.find_all_callers(func_id)
# 转换为文件列表
affected_files = set()
for caller_id in all_callers:
if ':' in caller_id:
caller_file = caller_id.rsplit(':', 1)[0]
affected_files.add(caller_file)
# 获取依赖该文件的模块
reverse_deps = self.dependency_graph.get_reverse_transitive_closure(file_path)
return {
'direct_callers': list(self.call_graph.get_callers(func_id)),
'all_callers': list(all_callers),
'affected_files': list(affected_files),
'affected_dependencies': list(reverse_deps)
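}

Using graph-derived signals (how many callers a function has and how many modules depend on its file), an AI IDE can assign priorities to retrieval results: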
class GraphContextScorer:
"""
基于图的上下文评分器
根据代码在图中的位置计算重要性分数
"""
def __init__(self, call_graph: CallGraph, dependency_graph: DependencyGraph):
self.call_graph = call_graph
self.dependency_graph = dependency_graph
def score(self, file_path: str, function_name: str = None) -> float:
"""
计算代码的重要性分数
分数考量因素:
1. 被调用/引用次数(越多越重要)
2. 在依赖图中的层级(越底层越核心)
3. 是否为公共 API
"""
score = 1.0
if function_name:
node_id = f"{file_path}:{function_name}"
# 调用次数因子
caller_count = len(self.call_graph.get_callers(node_id))
score *= (1 + caller_count * 0.1)
# 导出函数因子
if not function_name.startswith('_'):
score *= 1.5
# 文件级别的评分
direct_deps = self.dependency_graph.get_transitive_closure(file_path)
dependents = self.dependency_graph.get_reverse_transitive_closure(file_path)
# 被越多模块依赖越重要
score *= (1 + len(dependents) * 0.05)
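return min(score, 10.0) # cap the score at 10

Core technical value of this section: master the principles and implementation of vector search and BM25, understand the architecture of hybrid retrieval, and implement high-quality semantic matching.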
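The RAG (Retrieval-Augmented Generation) strategy combines information retrieval with language-model generation and is the core component for semantic matching in a Context Engine.
Compared with traditional keyword retrieval, RAG can match a query against differently named implementations of the same concept, such as login, authenticate, and signIn.

Vector search maps text into a high-dimensional vector space and performs semantic matching by computing the similarity between vectors.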
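To make "semantic matching through vector similarity" concrete, here is a minimal sketch using only the standard library. The three-dimensional vectors are made-up stand-ins for real embeddings, and `cosine` is a hypothetical helper written for this example.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


# Toy "embeddings" (invented numbers, not produced by a real model)
index: Dict[str, List[float]] = {
    "def login(user, pwd)": [0.9, 0.1, 0.0],
    "def sign_in(credentials)": [0.8, 0.2, 0.1],
    "def render_chart(data)": [0.0, 0.1, 0.9],
}
query_vec = [1.0, 0.0, 0.0]  # pretend embedding of "user authentication"

# Rank stored snippets by similarity to the query, most similar first
ranked = sorted(index, key=lambda k: cosine(query_vec, index[k]), reverse=True)
```

With these toy numbers the two authentication snippets rank ahead of the charting one even though none of them contains the word "authentication", which is the behaviour keyword search cannot provide.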
import numpy as np
from typing import List, Optional
from dataclasses import dataclass, field
import hashlib
@dataclass
class TextVector:
"""文本向量"""
text: str
vector: np.ndarray
metadata: dict = field(default_factory=dict)
@property
def norm(self) -> float:
"""向量的 L2 范数"""
return np.linalg.norm(self.vector)
def cosine_similarity(self, other: 'TextVector') -> float:
"""计算余弦相似度"""
dot_product = np.dot(self.vector, other.vector)
return dot_product / (self.norm * other.norm + 1e-8)
class CodeVectorizer:
"""
代码向量化器
支持多种嵌入模型:
1. 基于 TF-IDF 的传统方法
2. 基于 Sentence Transformers 的现代方法
3. 专门针对代码的 CodeBERT、GraphCodeBERT
"""
def __init__(self, model_name: str = 'codebert'):
self.model_name = model_name
self.model = None
self._load_model()
def _load_model(self):
"""加载向量化模型"""
if self.model_name == 'tfidf':
self.model = TFIDFVectorizer()
elif self.model_name == 'codebert':
self.model = CodeBERTModel()
elif self.model_name == 'sentence-transformers':
self.model = SentenceTransformerModel()
else:
raise ValueError(f"Unknown model: {self.model_name}")
def vectorize(self, code: str, language: str = None) -> TextVector:
"""将代码向量化"""
vector = self.model.encode(code, language)
return TextVector(text=code, vector=vector)
def batch_vectorize(self, codes: List[str], language: str = None) -> List[TextVector]:
"""批量向量化"""
return [self.vectorize(code, language) for code in codes]
class TFIDFVectorizer:
"""
基于 TF-IDF 的向量化器
传统方法,适用于快速原型和简单场景
"""
def __init__(self, max_features: int = 4096):
self.max_features = max_features
self.vocabulary = {}
self.idf = {}
self.doc_count = 0
def fit(self, documents: List[str]):
"""构建词汇表和 IDF"""
from collections import Counter
self.doc_count = len(documents)
# 词频统计
df = Counter()
for doc in documents:
tokens = self._tokenize(doc)
df.update(set(tokens))
# 构建词汇表(取 top max_features)
for word, freq in df.most_common(self.max_features):
self.vocabulary[word] = len(self.vocabulary)
# 计算 IDF
for word, doc_freq in df.items():
if word in self.vocabulary:
self.idf[word] = np.log(self.doc_count / (doc_freq + 1)) + 1
def encode(self, text: str, language: str = None) -> np.ndarray:
"""将文本编码为向量"""
tokens = self._tokenize(text)
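# TF (Counter is only imported locally inside fit(), so import it here as well)
from collections import Counter
tf = Counter(tokens)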
# 构建向量
vector = np.zeros(len(self.vocabulary))
for word, freq in tf.items():
if word in self.vocabulary:
idx = self.vocabulary[word]
vector[idx] = freq * self.idf.get(word, 1.0)
# 归一化
norm = np.linalg.norm(vector)
if norm > 0:
vector = vector / norm
return vector
def _tokenize(self, text: str) -> List[str]:
"""分词"""
import re
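# Simple tokenization: lowercase, then extract identifier-like tokens (letters, digits, underscores; a token may not start with a digit)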
tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*', text.lower())
return tokens

from typing import List, Tuple, Optional
import heapq
class VectorIndex:
"""
向量索引
支持:
1. 暴力搜索(精确)
2. IVF-PQ 索引(近似)
3. HNSW 索引(近似)
"""
def __init__(self, dimension: int, index_type: str = 'flat'):
self.dimension = dimension
self.index_type = index_type
self.vectors: List[TextVector] = []
if index_type == 'flat':
self._search = self._flat_search
elif index_type == 'ivf_pq':
self._index = None
self._search = self._ivf_pq_search
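elif index_type == 'hnsw':
self._hnsw = HNSWIndex(dimension)
self._search = self._hnsw_search
else:
# Fail early instead of leaving self._search undefined
raise ValueError(f"Unknown index type: {index_type}")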
def add(self, vector: TextVector):
"""添加向量到索引"""
self.vectors.append(vector)
if self.index_type == 'hnsw':
self._hnsw.add(vector)
def search(self, query: TextVector, k: int = 10) -> List[Tuple[TextVector, float]]:
"""
搜索最相似的 k 个向量
Returns:
[(vector, similarity), ...]
"""
return self._search(query, k)
def _flat_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
"""暴力搜索"""
similarities = []
for vec in self.vectors:
sim = query.cosine_similarity(vec)
similarities.append((vec, sim))
# 返回 top-k
return heapq.nlargest(k, similarities, key=lambda x: x[1])
def _ivf_pq_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
"""IVF-PQ 近似搜索(简化实现)"""
# 实际实现应使用 FAISS 库
# 这里返回暴力搜索作为简化
return self._flat_search(query, k)
def _hnsw_search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
"""HNSW 近似搜索"""
return self._hnsw.search(query, k)
class HNSWIndex:
"""
Hierarchical Navigable Small World (HNSW) 索引
一种高效的近似最近邻搜索算法
"""
def __init__(self, dimension: int, m: int = 16, ef: int = 200):
self.dimension = dimension
self.m = m # 每个节点的最大连接数
self.ef = ef # 搜索时的动态列表大小
self.layers: List[List[TextVector]] = []
self.graph: List[dict[int, list[int]]] = [] # 每层的邻接表
def add(self, vector: TextVector):
"""添加向量到索引"""
# 简化实现
if not self.layers:
self.layers.append([])
self.graph.append({})
self.layers[0].append(vector)
def search(self, query: TextVector, k: int) -> List[Tuple[TextVector, float]]:
"""搜索"""
# 简化实现:使用第一层的暴力搜索
if not self.layers:
return []
layer_vectors = self.layers[0]
similarities = [(v, query.cosine_similarity(v)) for v in layer_vectors]
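return heapq.nlargest(k, similarities, key=lambda x: x[1])

BM25 (Best Matching 25) is a probabilistic retrieval model based on term frequency and is widely used in information retrieval. Its core idea: the more often a query term appears in a document, the more relevant the document is, but once the term frequency saturates, each additional occurrence adds less.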
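The saturation effect can be checked numerically. The sketch below evaluates only the term-frequency factor of BM25, tf·(k1+1) / (tf + k1·(1 − b + b·|D|/avgdl)), for a document of average length; the IDF factor is left out and `tf_component` is a name chosen for this example.

```python
def tf_component(tf: int, k1: float = 1.5, b: float = 0.75,
                 doc_len: float = 100.0, avgdl: float = 100.0) -> float:
    """Term-frequency factor of BM25 (the IDF factor is omitted)."""
    length_norm = 1 - b + b * doc_len / avgdl
    return tf * (k1 + 1) / (tf + k1 * length_norm)


# Factor for tf = 1, 2, ..., 10 and the gain contributed by each extra occurrence
values = [tf_component(tf) for tf in range(1, 11)]
gains = [later - earlier for earlier, later in zip(values, values[1:])]
```

Each extra occurrence adds less than the previous one, and the factor never reaches its upper bound of k1 + 1, which is why a file that repeats a keyword fifty times does not automatically outrank a concise, relevant one.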
import math
from typing import List, Tuple, Set
from collections import Counter
class BM25:
"""
BM25 检索算法
公式:
Score(D, Q) = Σ IDF(qi) * (tf(qi, D) * (k1 + 1)) / (tf(qi, D) + k1 * (1 - b + b * |D|/avgdl))
其中:
- tf(qi, D): 词 qi 在文档 D 中的词频
- |D|: 文档长度
- avgdl: 平均文档长度
- k1: 词频饱和参数(通常 1.2-2.0)
- b: 文档长度归一化参数(通常 0.75)
- IDF(qi): 逆文档频率
"""
def __init__(self, k1: float = 1.5, b: float = 0.75):
self.k1 = k1
self.b = b
self.documents: List[str] = []
self.tokenized_docs: List[List[str]] = []
self.avgdl = 0.0
self.doc_freq: dict = {} # 词 -> 包含该词的文档数
self.N = 0 # 总文档数
self.idf: dict = {}
def fit(self, documents: List[str]):
"""
构建 BM25 索引
Args:
documents: 文档列表
"""
self.documents = documents
self.N = len(documents)
# 分词
self.tokenized_docs = [self._tokenize(doc) for doc in documents]
# 计算平均文档长度
total_len = sum(len(doc) for doc in self.tokenized_docs)
self.avgdl = total_len / self.N if self.N > 0 else 0
# 统计文档频率
self.doc_freq = Counter()
for doc_tokens in self.tokenized_docs:
for token in set(doc_tokens):
self.doc_freq[token] += 1
# 计算 IDF
self._compute_idf()
def _tokenize(self, text: str) -> List[str]:
"""分词"""
import re
# 保留编程语言标识符
tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*', text.lower())
return tokens
def _compute_idf(self):
"""计算 IDF 值"""
for term, df in self.doc_freq.items():
# IDF with +1 inside the log so the value stays non-negative: log((N - df + 0.5) / (df + 0.5) + 1)
self.idf[term] = math.log((self.N - df + 0.5) / (df + 0.5) + 1)
def score(self, query: str, doc_idx: int) -> float:
"""
计算查询对单个文档的 BM25 分数
Args:
query: 查询字符串
doc_idx: 文档索引
Returns:
BM25 分数
"""
query_tokens = self._tokenize(query)
doc_tokens = self.tokenized_docs[doc_idx]
doc_len = len(doc_tokens)
score = 0.0
tf = Counter(doc_tokens)
for token in query_tokens:
if token not in self.idf:
continue
term_freq = tf[token]
idf = self.idf[token]
# BM25 公式
numerator = term_freq * (self.k1 + 1)
denominator = term_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
score += idf * (numerator / denominator)
return score
def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
"""
搜索最相关的文档
Args:
query: 查询字符串
top_k: 返回的最多结果数
Returns:
[(doc_idx, score), ...]
"""
scores = [(i, self.score(query, i)) for i in range(self.N)]
# 过滤零分结果并排序
scores = [(idx, score) for idx, score in scores if score > 0]
scores.sort(key=lambda x: x[1], reverse=True)
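return scores[:top_k]

Vector search is good at semantic matching but insensitive to exact keywords; BM25 is good at exact matching but cannot capture semantics. Hybrid retrieval fuses the two so that each covers the other's weakness, aiming for results better than either method achieves alone.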
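A minimal sketch of the fusion step, assuming each retriever returns a dict from chunk id to raw score. It min-max normalizes each score set and takes a weighted sum, using the same 0.6/0.4 default weights as the `HybridRetriever` that follows; the chunk ids and scores are invented for illustration.

```python
from typing import Dict, List, Tuple


def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    """Scale scores into [0, 1]; all-equal inputs map to 1.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {k: 1.0 for k in scores}
    return {k: (v - lo) / (hi - lo) for k, v in scores.items()}


def fuse(vector_scores: Dict[str, float], bm25_scores: Dict[str, float],
         vector_weight: float = 0.6, bm25_weight: float = 0.4
         ) -> List[Tuple[str, float]]:
    """Weighted sum of normalized scores; a missing score counts as 0."""
    v, b = min_max(vector_scores), min_max(bm25_scores)
    fused = {
        cid: vector_weight * v.get(cid, 0.0) + bm25_weight * b.get(cid, 0.0)
        for cid in set(v) | set(b)
    }
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)


# Invented scores: "auth.py" is strong semantically, "login.py" lexically
vector_hits = {"auth.py": 0.82, "session.py": 0.75, "chart.py": 0.40}
bm25_hits = {"login.py": 7.1, "auth.py": 6.3, "session.py": 1.2}
ranking = fuse(vector_hits, bm25_hits)
```

Normalization matters here because cosine similarities and BM25 scores live on different scales; without it the larger BM25 numbers would dominate the sum regardless of the weights. Note that min-max scaling assigns 0 to the weakest hit in each list, so a chunk returned by only one retriever is penalized twice.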
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass, field
import numpy as np
@dataclass
class SearchResult:
"""搜索结果"""
chunk_id: str
content: str
file_path: str
start_line: int
end_line: int
vector_score: float = 0.0
bm25_score: float = 0.0
combined_score: float = 0.0
rank: int = 0
metadata: dict = field(default_factory=dict)
class HybridRetriever:
"""
混合检索器
融合向量检索和 BM25 检索的结果
"""
def __init__(
self,
vector_weight: float = 0.6,
bm25_weight: float = 0.4,
rerank: bool = True
):
"""
Args:
vector_weight: 向量检索权重
bm25_weight: BM25 权重
rerank: 是否进行重排序
"""
self.vector_weight = vector_weight
self.bm25_weight = bm25_weight
self.rerank = rerank
# 各子检索器
self.vector_index: VectorIndex = None
self.bm25: BM25 = None
# 文档存储
self.documents: Dict[str, dict] = {}
def index(self, chunks: List[dict]):
"""
构建索引
Args:
chunks: [{'id': ..., 'content': ..., 'file_path': ..., 'metadata': ...}, ...]
"""
# 存储文档
for chunk in chunks:
self.documents[chunk['id']] = chunk
# 构建向量索引
vectorizer = CodeVectorizer('sentence-transformers')
vectors = [vectorizer.vectorize(chunk['content']) for chunk in chunks]
self.vector_index = VectorIndex(dimension=len(vectors[0].vector) if vectors else 384)
for i, vec in enumerate(vectors):
vec.metadata = {'chunk_id': chunks[i]['id']}
self.vector_index.add(vec)
# 构建 BM25 索引
self.bm25 = BM25()
self.bm25.fit([chunk['content'] for chunk in chunks])
def search(
self,
query: str,
top_k: int = 20,
max_results: int = 10
) -> List[SearchResult]:
"""
混合搜索
Args:
query: 查询字符串
top_k: 每种检索方式返回的结果数
max_results: 最终返回的结果数
"""
# 并行执行两种检索
vector_results = self._vector_search(query, top_k)
bm25_results = self._bm25_search(query, top_k)
# 分数归一化
vector_results = self._normalize_scores(vector_results)
bm25_results = self._normalize_scores(bm25_results)
# 合并结果
combined = self._combine_results(vector_results, bm25_results)
# 重排序
if self.rerank:
combined = self._rerank(combined, query)
# Keep only the top max_results results
combined = combined[:max_results]
# 设置最终排名
for i, result in enumerate(combined):
result.rank = i + 1
return combined
def _vector_search(self, query: str, top_k: int) -> Dict[str, float]:
"""向量检索"""
if not self.vector_index:
return {}
vectorizer = CodeVectorizer('sentence-transformers')
query_vector = vectorizer.vectorize(query)
results = self.vector_index.search(query_vector, top_k)
return {r[0].metadata.get('chunk_id', ''): r[1] for r in results}
def _bm25_search(self, query: str, top_k: int) -> Dict[str, float]:
"""BM25 检索"""
if not self.bm25:
return {}
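scores = self.bm25.search(query, top_k)
# BM25 returns document positions, but self.documents is keyed by chunk id; map positions back to ids.
# Assumes index() was called once with unique ids, so dict insertion order matches document order.
chunk_ids = list(self.documents)
return {chunk_ids[idx]: score for idx, score in scores}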
def _normalize_scores(self, scores: Dict[str, float]) -> Dict[str, float]:
"""Min-Max 归一化分数"""
if not scores:
return {}
min_s = min(scores.values())
max_s = max(scores.values())
if max_s == min_s:
return {k: 1.0 for k in scores}
return {k: (v - min_s) / (max_s - min_s) for k, v in scores.items()}
def _combine_results(
self,
vector_scores: Dict[str, float],
bm25_scores: Dict[str, float]
) -> List[SearchResult]:
"""合并检索结果"""
all_ids = set(vector_scores.keys()) | set(bm25_scores.keys())
results = []
for chunk_id in all_ids:
if chunk_id not in self.documents:
continue
doc = self.documents[chunk_id]
v_score = vector_scores.get(chunk_id, 0.0)
b_score = bm25_scores.get(chunk_id, 0.0)
# 加权求和
combined = self.vector_weight * v_score + self.bm25_weight * b_score
result = SearchResult(
chunk_id=chunk_id,
content=doc['content'],
file_path=doc.get('file_path', ''),
start_line=doc.get('start_line', 0),
end_line=doc.get('end_line', 0),
vector_score=v_score,
bm25_score=b_score,
combined_score=combined,
metadata=doc.get('metadata', {})
)
results.append(result)
# 按综合分数排序
results.sort(key=lambda x: x.combined_score, reverse=True)
return results
def _rerank(self, results: List[SearchResult], query: str) -> List[SearchResult]:
"""
使用交叉编码器重排序
简化实现:基于关键词匹配的重排序
实际应使用 Cross-Encoder 模型
"""
query_terms = set(query.lower().split())
for result in results:
content_terms = set(result.content.lower().split())
# 计算查询词在结果中的覆盖率
coverage = len(query_terms & content_terms) / len(query_terms) if query_terms else 0
# 轻微调整分数:覆盖率高的结果提升
result.combined_score = result.combined_score * (1 + 0.1 * coverage)
results.sort(key=lambda x: x.combined_score, reverse=True)
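return results

The retrieved results must then be assembled into the final context passed to the AI model. Assembly has to respect the token budget, keep code-block boundaries intact where possible, and decide how much space each result receives according to its position: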
class RAGContextAssembler:
"""
RAG 上下文组装器
将检索结果组装成适合 AI 模型的上下文
"""
def __init__(
self,
max_tokens: int = 8000,
preserve_edges: bool = True,
position_strategy: str = 'balanced'
):
"""
Args:
max_tokens: 最大 token 数
preserve_edges: 是否保留代码块的完整边界
position_strategy: 位置策略
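- 'balanced': the first and last results get 35% each, the middle results share 30%
- 'front': earlier results get larger shares
- 'back': later results get larger shares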
- 'uniform': 均匀分布
"""
self.max_tokens = max_tokens
self.preserve_edges = preserve_edges
self.position_strategy = position_strategy
def assemble(
self,
results: List[SearchResult],
query: str,
system_prompt: str = None
) -> str:
"""
组装上下文
Args:
results: 排序后的检索结果
query: 用户查询
system_prompt: 系统提示
Returns:
组装后的上下文字符串
"""
# 估算 token
def estimate_tokens(text: str) -> int:
# 简化估算:约 4 字符/token
return len(text) // 4
# 计算系统提示的 token
system_tokens = estimate_tokens(system_prompt) if system_prompt else 0
query_tokens = estimate_tokens(f"Query: {query}\n")
# 可用于上下文的 token
available_tokens = self.max_tokens - system_tokens - query_tokens - 200 # 保留一些余量
# 按位置策略分配 token
token_allocation = self._allocate_tokens(len(results), available_tokens)
# 选择和裁剪内容
context_parts = []
current_tokens = 0
for i, (result, allocated) in enumerate(zip(results, token_allocation)):
content_tokens = estimate_tokens(result.content)
if content_tokens <= allocated:
# 内容足够,直接添加
context_parts.append(self._format_result(result, i + 1))
current_tokens += content_tokens
else:
# 需要裁剪
truncated = self._truncate_content(result, allocated)
context_parts.append(self._format_result_with_truncation(result, truncated, i + 1))
current_tokens += estimate_tokens(truncated)
# 组装最终上下文
if system_prompt:
final_context = f"{system_prompt}\n\n"
else:
final_context = ""
final_context += f"Query: {query}\n\n"
final_context += "Relevant Code:\n"
final_context += "\n\n".join(context_parts)
return final_context
def _allocate_tokens(self, num_results: int, total_tokens: int) -> List[int]:
"""根据位置策略分配 token"""
if num_results == 0:
return []
# 计算每部分的比例
if self.position_strategy == 'balanced':
# 首尾各 35%,中间 30%
if num_results == 1:
ratios = [1.0]
elif num_results == 2:
ratios = [0.5, 0.5]
elif num_results >= 3:
mid_count = num_results - 2
front = 0.35
back = 0.35
middle = 0.30
ratios = []
ratios.append(front)
for _ in range(mid_count):
ratios.append(middle / mid_count)
ratios.append(back)
elif self.position_strategy == 'front':
ratios = [1.0 / num_results] * num_results
# 前面的权重更高
for i in range(num_results):
ratios[i] *= (num_results - i) / (num_results - i + 1)
elif self.position_strategy == 'back':
ratios = [1.0 / num_results] * num_results
for i in range(num_results):
ratios[i] *= (i + 1) / (i + 2)
else: # uniform
ratios = [1.0 / num_results] * num_results
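# Normalize so the ratios sum to 1; the 'front' and 'back' weights above do not,
# and the rounding fix below would otherwise hand the entire shortfall to the last result
ratio_sum = sum(ratios)
ratios = [r / ratio_sum for r in ratios]
# Convert ratios to token counts
tokens = [int(total_tokens * r) for r in ratios]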
# 调整以确保总和正确
diff = total_tokens - sum(tokens)
if diff != 0 and tokens:
tokens[-1] += diff
return tokens
def _format_result(self, result: SearchResult, rank: int) -> str:
"""格式化单个结果"""
return f"""--- Result #{rank} (Score: {result.combined_score:.3f}) ---
File: {result.file_path}:{result.start_line}-{result.end_line}
{result.content}
---"""
def _format_result_with_truncation(
self,
result: SearchResult,
truncated_content: str,
rank: int
) -> str:
"""格式化截断的结果"""
return f"""--- Result #{rank} (Score: {result.combined_score:.3f}, TRUNCATED) ---
File: {result.file_path}:{result.start_line}-{result.end_line}
{truncated_content}
[... {len(result.content) - len(truncated_content)} characters truncated ...]
---"""
def _truncate_content(self, result: SearchResult, max_tokens: int) -> str:
"""截断内容以适应 token 限制"""
max_chars = max_tokens * 4 # 约 4 字符/token
if len(result.content) <= max_chars:
return result.content
# 优先保留重要部分:开头和结尾
keep_from_start = max_chars // 2
keep_from_end = max_chars // 2
start_part = result.content[:keep_from_start]
end_part = result.content[-keep_from_end:]
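return start_part + "\n\n[... content truncated ...]\n\n" + end_part

Core technical value of this section: understand the challenges and algorithm design of context selection, master multi-stage filtering and greedy selection, and manage context efficiently for large codebases.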
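In a codebase with millions of lines, the candidate context can still far exceed the model's capacity even after chunking and retrieval. A context selection algorithm has to balance recall (how much of the relevant content is covered) against precision (how much irrelevant noise is kept out).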
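For reference, both metrics can be computed as below when the truly relevant chunks are known, for example from a labeled evaluation set. The chunk ids and the two selections are invented for this sketch.

```python
from typing import Set, Tuple


def precision_recall(selected: Set[str], relevant: Set[str]) -> Tuple[float, float]:
    """Return (precision, recall) of `selected` against ground-truth `relevant`."""
    hits = len(selected & relevant)
    precision = hits / len(selected) if selected else 0.0
    recall = hits / len(relevant) if relevant else 0.0
    return precision, recall


relevant = {"c1", "c2", "c3", "c4"}
narrow = {"c1", "c2"}  # small, clean selection
broad = {"c1", "c2", "c3", "c4", "c9", "c10", "c11", "c12"}  # everything plausible

p_narrow, r_narrow = precision_recall(narrow, relevant)
p_broad, r_broad = precision_recall(broad, relevant)
```

The narrow selection wastes no tokens but misses half of what the model needs; the broad one covers everything but spends half its budget on noise. Selection algorithms try to move both numbers up within a fixed token budget.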
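Codebase size: 1,000,000 lines
Chunk size: 100 lines per chunk
Total chunks: 10,000
Model context: 100,000 tokens ≈ 25,000 lines (assuming about 4 tokens per line) ≈ 250 chunks
Selection ratio: 250 / 10,000 = 2.5%
Problem: pick the 250 most important chunks out of 10,000

This is a typical top-k selection problem, with the twist that relevance alone is not enough: the selector below also weighs diversity, proximity to the current edit position, and recency.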
from typing import List, Set, Dict, Tuple, Optional
from dataclasses import dataclass, field
import heapq
@dataclass
class ScoredChunk:
"""带分数的代码块"""
chunk_id: str
content: str
file_path: str
start_line: int
end_line: int
score: float = 0.0
features: dict = field(default_factory=dict)
class MultiStageSelector:
"""
多阶段上下文选择器
阶段:
1. 粗筛:快速过滤明显不相关的块
2. 精排:计算每个块的详细分数
3. 贪心选择:在多样性约束下选择最优组合
"""
def __init__(
self,
max_chunks: int = 100,
diversity_weight: float = 0.2,
recency_weight: float = 0.1,
proximity_weight: float = 0.3,
relevance_weight: float = 0.4
):
self.max_chunks = max_chunks
self.diversity_weight = diversity_weight
self.recency_weight = recency_weight
self.proximity_weight = proximity_weight
self.relevance_weight = relevance_weight
# 状态
self.current_file = None
self.current_line = 0
self.modification_times: Dict[str, float] = {} # chunk_id -> timestamp
def select(
self,
candidates: List[ScoredChunk],
query: str,
current_file: str = None,
current_line: int = 0
) -> List[ScoredChunk]:
"""
多阶段选择
Args:
candidates: 候选块列表
query: 查询
current_file: 当前编辑文件
current_line: 当前编辑行
Returns:
选中的块列表
"""
self.current_file = current_file
self.current_line = current_line
# 阶段 1:粗筛 - 基于简单规则快速过滤
filtered = self._coarse_filter(candidates, query)
# 阶段 2:精排 - 计算综合分数
scored = self._fine_rank(filtered, query)
# 阶段 3:贪心选择 - 考虑多样性
selected = self._greedy_select(scored)
return selected
def _coarse_filter(self, candidates: List[ScoredChunk], query: str) -> List[ScoredChunk]:
"""
粗筛阶段
快速过滤策略:
1. 文件类型过滤
2. 关键词黑名单过滤
3. 最小相关性阈值过滤
"""
# 黑名单关键词
blacklist = {'test', 'mock', 'fixture', '__pycache__', '.min.js'}
# 扩展查询关键词
query_terms = set(query.lower().split())
filtered = []
for chunk in candidates:
# 文件类型检查
if any(term in chunk.file_path.lower() for term in blacklist):
continue
# 基础相关性检查:至少有一个查询词在内容或路径中
content_lower = chunk.content.lower()
path_lower = chunk.file_path.lower()
has_match = any(
term in content_lower or term in path_lower
for term in query_terms
)
if has_match or not query_terms:
filtered.append(chunk)
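# If too many candidates remain, keep those with the highest existing score
# (deterministic, unlike random sampling, which may drop the best matches)
if len(filtered) > self.max_chunks * 3:
filtered.sort(key=lambda c: c.score, reverse=True)
filtered = filtered[:self.max_chunks * 3]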
return filtered
def _fine_rank(self, candidates: List[ScoredChunk], query: str) -> List[ScoredChunk]:
"""
精排阶段
计算每个块的多维度分数:
1. 相关性分数(来自 RAG)
2. 邻近性分数(与当前编辑位置的距离)
3. 时效性分数(最近修改)
4. 多样性分数(与已选块的区别)
"""
for chunk in candidates:
features = self._compute_features(chunk, query)
chunk.features = features
# 综合分数
chunk.score = (
self.relevance_weight * features['relevance'] +
self.proximity_weight * features['proximity'] +
self.recency_weight * features['recency']
)
# 按分数排序
candidates.sort(key=lambda c: c.score, reverse=True)
return candidates
def _compute_features(self, chunk: ScoredChunk, query: str) -> dict:
"""计算多维特征"""
features = {}
# 1. 相关性分数(这里简化,实际来自 RAG 检索)
query_terms = set(query.lower().split())
content_terms = set(chunk.content.lower().split())
# Jaccard 相似度
intersection = len(query_terms & content_terms)
union = len(query_terms | content_terms)
features['relevance'] = intersection / union if union > 0 else 0
# 2. 邻近性分数
if self.current_file and chunk.file_path == self.current_file:
# 在同一文件中,距离越近分数越高
distance = abs(chunk.start_line - self.current_line)
features['proximity'] = 1.0 / (1.0 + distance * 0.1)
else:
features['proximity'] = 0.0
# 3. 时效性分数
mod_time = self.modification_times.get(chunk.chunk_id, 0)
if mod_time > 0:
import time
age_days = (time.time() - mod_time) / 86400
features['recency'] = 1.0 / (1.0 + age_days)
else:
features['recency'] = 0.5 # 默认中间值
return features
def _greedy_select(self, candidates: List[ScoredChunk]) -> List[ScoredChunk]:
"""
贪心选择
在考虑多样性的情况下选择最优组合
"""
if len(candidates) <= self.max_chunks:
return candidates
selected = []
remaining = candidates.copy()
while len(selected) < self.max_chunks and remaining:
# 选择当前最优
best = remaining.pop(0)
selected.append(best)
# 惩罚与已选块相似的内容(多样性维护)
remaining = self._penalize_similar(remaining, selected)
# 重新排序
remaining.sort(key=lambda c: c.score, reverse=True)
return selected
def _penalize_similar(
self,
candidates: List[ScoredChunk],
selected: List[ScoredChunk]
) -> List[ScoredChunk]:
"""惩罚与已选块相似的内容"""
selected_signatures = [self._get_signature(c) for c in selected]
penalized = []
for candidate in candidates:
sig = self._get_signature(candidate)
# 计算与最近选择块的相似度
max_similarity = 0.0
for selected_sig in selected_signatures[-5:]: # 只看最近 5 个
sim = self._signature_similarity(sig, selected_sig)
max_similarity = max(max_similarity, sim)
# 应用惩罚
candidate.score *= (1 - self.diversity_weight * max_similarity)
penalized.append(candidate)
return penalized
def _get_signature(self, chunk: ScoredChunk) -> Set[str]:
"""获取块的签名(顶层词汇集合)"""
# 取内容的前 20 个实词作为签名
import re
tokens = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]{2,}', chunk.content.lower())
# 去除停用词
stopwords = {'function', 'class', 'def', 'return', 'if', 'else', 'for', 'while', 'import', 'from', 'const', 'let', 'var', 'public', 'private', 'static'}
tokens = [t for t in tokens if t not in stopwords]
return set(tokens[:20])
def _signature_similarity(self, sig1: Set[str], sig2: Set[str]) -> float:
"""计算两个签名的相似度"""
if not sig1 or not sig2:
return 0.0
intersection = len(sig1 & sig2)
union = len(sig1 | sig2)
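return intersection / union if union > 0 else 0.0

In practice an AI IDE serves many kinds of context requests: code completion, code generation, code explanation, bug fixing, and so on. Each kind needs different context, so the context budget must be allocated dynamically.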
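One detail worth getting right when splitting a budget is rounding: truncating each share with `int()` can leave a few tokens unassigned. The sketch below is one possible approach, using largest-remainder rounding so the parts always sum to the total; `split_budget` and the weights shown are assumptions for this example.

```python
from typing import Dict


def split_budget(total: int, weights: Dict[str, float]) -> Dict[str, int]:
    """Split `total` tokens proportionally to `weights`.

    Uses largest-remainder rounding so the parts sum exactly to `total`.
    """
    weight_sum = sum(weights.values())
    if total <= 0 or weight_sum <= 0:
        return {name: 0 for name in weights}
    exact = {name: total * w / weight_sum for name, w in weights.items()}
    parts = {name: int(value) for name, value in exact.items()}
    leftover = total - sum(parts.values())
    # Hand the leftover tokens to the largest fractional remainders
    by_remainder = sorted(exact, key=lambda n: exact[n] - parts[n], reverse=True)
    for name in by_remainder[:leftover]:
        parts[name] += 1
    return parts


# Hypothetical weights for a completion-style request
completion = split_budget(2000, {"current_function": 0.4, "current_file": 0.3,
                                 "imports": 0.2, "related": 0.1})
tiny = split_budget(10, {"a": 1, "b": 1, "c": 1})
```

With small budgets the difference is visible: three equal weights over 10 tokens yield shares of 4, 3, and 3 rather than 3, 3, and 3.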
from enum import Enum
from typing import Dict, Callable
class QueryType(Enum):
"""查询类型"""
COMPLETION = "completion" # 代码补全
GENERATION = "generation" # 代码生成
EXPLANATION = "explanation" # 代码解释
REFACTORING = "refactoring" # 重构
BUG_FIX = "bug_fix" # Bug 修复
DEBUGGING = "debugging" # 调试
GENERAL = "general" # 通用
class ContextBudgetAllocator:
"""
上下文预算分配器
根据查询类型动态分配上下文预算
"""
# 不同查询类型的预算分配策略
BUDGET_STRATEGIES = {
QueryType.COMPLETION: {
'max_context': 2000,
'focus': 'current_file',
'include_imports': True,
'include_related': True,
'include_tests': False,
},
QueryType.GENERATION: {
'max_context': 8000,
'focus': 'related_files',
'include_imports': True,
'include_related': True,
'include_tests': False,
},
QueryType.EXPLANATION: {
'max_context': 5000,
'focus': 'target_code',
'include_imports': True,
'include_related': False,
'include_tests': False,
},
QueryType.REFACTORING: {
'max_context': 10000,
'focus': 'target_and_callers',
'include_imports': True,
'include_related': True,
'include_tests': True,
},
QueryType.BUG_FIX: {
'max_context': 8000,
'focus': 'error_and_related',
'include_imports': True,
'include_related': True,
'include_tests': True,
},
QueryType.DEBUGGING: {
'max_context': 6000,
'focus': 'error_context',
'include_imports': False,
'include_related': True,
'include_tests': False,
},
QueryType.GENERAL: {
'max_context': 5000,
'focus': 'balanced',
'include_imports': True,
'include_related': True,
'include_tests': False,
},
}
def allocate(
self,
query_type: QueryType,
query: str,
available_budget: int
) -> Dict:
"""
分配上下文预算
Returns:
{
'max_context': 最大上下文量,
'chunk_budgets': {'type': budget},
'focus_areas': [...]
}
"""
strategy = self.BUDGET_STRATEGIES.get(query_type, self.BUDGET_STRATEGIES[QueryType.GENERAL])
# 根据可用预算调整
max_context = min(strategy['max_context'], available_budget)
# 计算各部分预算
chunk_budgets = self._calculate_chunk_budgets(strategy, max_context)
return {
'query_type': query_type,
'max_context': max_context,
'chunk_budgets': chunk_budgets,
'focus_areas': [strategy['focus']], # keep the focus label whole; splitting on '_' would yield fragments such as 'and'
'include_imports': strategy['include_imports'],
'include_related': strategy['include_related'],
'include_tests': strategy['include_tests'],
}
def _calculate_chunk_budgets(
self,
strategy: Dict,
total_budget: int
) -> Dict[str, int]:
"""计算各部分的预算分配"""
focus = strategy['focus']
if focus == 'current_file':
return {
'current_function': int(total_budget * 0.4),
'current_file': int(total_budget * 0.3),
'imports': int(total_budget * 0.2),
'related': int(total_budget * 0.1),
}
elif focus == 'target_and_callers':
return {
'target_code': int(total_budget * 0.35),
'callers': int(total_budget * 0.25),
'callees': int(total_budget * 0.15),
'dependencies': int(total_budget * 0.15),
'tests': int(total_budget * 0.1),
}
elif focus == 'error_context':
return {
'error_location': int(total_budget * 0.3),
'error_trace': int(total_budget * 0.25),
'related_functions': int(total_budget * 0.25),
'data_flow': int(total_budget * 0.2),
}
else: # balanced
return {
'target_code': int(total_budget * 0.3),
'related_code': int(total_budget * 0.3),
'imports': int(total_budget * 0.2),
'tests': int(total_budget * 0.2),
}

Core technical value of this section: understand the cost structure of an AI IDE, master token budget management strategies, and balance quality against cost.
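In an AI IDE, every call to the model consumes tokens. Consumption comes from two parts: input tokens (the prompt and the assembled context) and output tokens (the generated response).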

According to mainstream model pricing in 2026[^5]:
| Model | Input ($/1M tokens) | Output ($/1M tokens) |
|---|---|---|
| GPT-4o | $2.50 | $10.00 |
| Claude 3.5 Sonnet | $3.00 | $15.00 |
| Gemini 1.5 Pro | $1.25 | $5.00 |
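For a daily active developer (assuming 200 AI calls per day averaging 50K input tokens each), the monthly input cost comes to:
200 calls/day × 30 days × 50K tokens × $3/1M = $900 per developer per month

For a 100-person development team that is roughly $90,000 per month. Token optimization is therefore not only a technical problem but also an economic one.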
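The estimate can be reproduced, and re-run with other assumptions, using a small helper. `monthly_input_cost` is a hypothetical name; the sketch counts input tokens only and ignores output tokens, caching discounts, and tiered pricing.

```python
def monthly_input_cost(calls_per_day: int, days: int, tokens_per_call: int,
                       usd_per_million_tokens: float) -> float:
    """Input-token cost in USD for one developer over `days` days."""
    total_tokens = calls_per_day * days * tokens_per_call
    return total_tokens / 1_000_000 * usd_per_million_tokens


# The assumptions stated above: 200 calls/day, 30 days, 50K input tokens, $3 per 1M
per_developer = monthly_input_cost(200, 30, 50_000, 3.0)
team_of_100 = 100 * per_developer

# Same usage pattern if context selection trims each call to 10K input tokens
trimmed = monthly_input_cost(200, 30, 10_000, 3.0)
```

Because the cost is linear in tokens per call, every reduction in assembled context size translates directly into the same relative saving.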
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple
from enum import Enum
import time
class CostLevel(Enum):
"""成本等级"""
LOW = "low" # 简单查询,本地模型
MEDIUM = "medium" # 标准查询,中等模型
HIGH = "high" # 复杂查询,顶级模型
@dataclass
class TokenBudget:
"""Token 预算"""
total: int
used: int = 0
reserved: int = 0 # 预留给关键上下文的 tokens
@property
def available(self) -> int:
return self.total - self.used - self.reserved
@property
def usage_ratio(self) -> float:
return (self.used + self.reserved) / self.total if self.total > 0 else 0
class TokenBudgetManager:
"""
Token 预算管理器
职责:
1. 跟踪 Token 消耗
2. 动态调整预算分配
3. 实施成本控制策略
"""
def __init__(self, daily_budget_tokens: int = 10_000_000):
"""
Args:
daily_budget_tokens: 每日 Token 预算上限
"""
self.daily_budget = daily_budget_tokens
self.daily_used = 0
self.last_reset = time.time()
# 各功能类型的预算
self.feature_budgets: Dict[str, TokenBudget] = {
'completion': TokenBudget(total=int(daily_budget_tokens * 0.3)),
'generation': TokenBudget(total=int(daily_budget_tokens * 0.25)),
'explanation': TokenBudget(total=int(daily_budget_tokens * 0.15)),
'refactoring': TokenBudget(total=int(daily_budget_tokens * 0.15)),
'other': TokenBudget(total=int(daily_budget_tokens * 0.15)),
}
# 成本等级配置
self.cost_thresholds = {
CostLevel.LOW: 1000, # < 1K tokens
CostLevel.MEDIUM: 5000, # 1K - 5K tokens
CostLevel.HIGH: 5000, # > 5K tokens
}
def allocate(
self,
feature: str,
requested: int,
priority: int = 1
) -> Tuple[int, CostLevel]:
"""
分配 Token 预算
Args:
feature: 功能类型
requested: 请求的 Token 数
priority: 优先级(1-5,越高越优先)
Returns:
(实际分配数, 成本等级)
"""
# 重置每日计数
self._check_daily_reset()
# 确定成本等级
cost_level = self._determine_cost_level(requested)
# 获取功能预算
budget = self.feature_budgets.get(feature, self.feature_budgets['other'])
# 检查预算可用性
if budget.available >= requested:
budget.used += requested
self.daily_used += requested
return requested, cost_level
# 预算不足,尝试回收低优先级预算
reclaimed = self._reclaim_budget(priority)
if budget.available + reclaimed >= requested:
# 优先使用功能自身预算
available = budget.available
budget.used += available
remaining = requested - available
budget.used += remaining
self.daily_used += requested
return requested, cost_level
# Final fallback: grant whatever is left in this feature's budget (may be less than requested)
allocated = max(0, min(budget.available, requested))
budget.used += allocated
self.daily_used += allocated
return allocated, cost_level
def _check_daily_reset(self):
"""检查是否需要重置每日计数"""
current_time = time.time()
if current_time - self.last_reset > 86400: # 24 hours
self.daily_used = 0
self.last_reset = current_time
# 重置各功能预算
for budget in self.feature_budgets.values():
budget.used = 0
def _determine_cost_level(self, tokens: int) -> CostLevel:
"""确定成本等级"""
if tokens < self.cost_thresholds[CostLevel.LOW]:
return CostLevel.LOW
elif tokens < self.cost_thresholds[CostLevel.MEDIUM]:
return CostLevel.MEDIUM
else:
return CostLevel.HIGH
def _reclaim_budget(self, priority: int) -> int:
"""回收低优先级预算"""
reclaimed = 0
# 按优先级排序(低优先级的先回收)
sorted_features = sorted(
self.feature_budgets.items(),
key=lambda x: self._get_feature_priority(x[0])
)
for feature, budget in sorted_features:
if self._get_feature_priority(feature) < priority and budget.reserved > 0:
# 回收预留但未使用的预算
reclaimed += budget.reserved
budget.reserved = 0
return reclaimed
def _get_feature_priority(self, feature: str) -> int:
"""获取功能优先级"""
priorities = {
'completion': 5, # 代码补全高优先级
'generation': 4,
'refactoring': 3,
'explanation': 2,
'other': 1,
}
return priorities.get(feature, 1)
    def get_cost_report(self) -> Dict:
        """Return a cost report."""
        self._check_daily_reset()
        return {
            'daily_budget': self.daily_budget,
            'daily_used': self.daily_used,
            'daily_usage_ratio': self.daily_used / self.daily_budget if self.daily_budget > 0 else 0,
            'feature_costs': {
                feature: {
                    'budget': budget.total,
                    'used': budget.used,
                    'usage_ratio': budget.usage_ratio
                }
                for feature, budget in self.feature_budgets.items()
            }
        }

When the token budget is tight, the context has to be compressed. Adaptive compression adjusts the compression ratio dynamically according to how important each piece of content is.
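As a minimal sketch of the core idea, assuming each segment already carries an importance weight, a token budget can be split in proportion to those weights. The helper name `allocate_by_weight` and the sample weights are illustrative assumptions and are not part of the article's implementation.

```python
from typing import List


def allocate_by_weight(weights: List[float], target_tokens: int) -> List[int]:
    """Split target_tokens across segments in proportion to their weights.

    Hypothetical helper for illustration. Each share is floored, so the
    total may fall slightly below target_tokens.
    """
    total = sum(weights)
    if total <= 0:
        return [0] * len(weights)
    return [int(target_tokens * w / total) for w in weights]


# Three segments: one highly relevant, two of average importance.
budgets = allocate_by_weight([2.0, 1.0, 1.0], target_tokens=400)
print(budgets)  # [200, 100, 100]
```

A segment whose share is smaller than its own length is the one that needs truncating; segments that already fit can be passed through unchanged.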
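class AdaptiveContextCompressor:
    """
    Adaptive context compressor.

    Strategy:
    1. Keep highly relevant content and compress it lightly
    2. Keep the beginning and the end; the middle can be compressed
    3. Keep structural markers (such as function signatures); implementation details can be compressed
    4. Keep recent content and compress older content
    """

    def __init__(
        self,
        min_compression_ratio: float = 0.5,  # Minimum compression ratio (keep 50%)
        max_compression_ratio: float = 0.9   # Maximum compression ratio (keep 10%)
    ):
        self.min_compression_ratio = min_compression_ratio
        self.max_compression_ratio = max_compression_ratio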
    def compress(
        self,
        content: str,
        target_tokens: int,
        relevance_scores: Dict[int, float] = None
    ) -> str:
        """
        Compress content adaptively.

        Args:
            content: Original content
            target_tokens: Target number of tokens
            relevance_scores: Relevance score for each segment, keyed by segment index

        Returns:
            The compressed content
        """
        # Estimate the current token count (roughly 4 characters per token)
        current_tokens = len(content) // 4
        if current_tokens <= target_tokens:
            return content
        # Ratio of target to current size, clamped to the configured bounds.
        # NOTE: the clamped value is not consumed below; the per-segment
        # budgets are derived from target_tokens directly.
        compression_ratio = target_tokens / current_tokens
        compression_ratio = max(
            self.min_compression_ratio,
            min(self.max_compression_ratio, compression_ratio)
        )
        # Analyze the content structure
        segments = self._split_into_segments(content)
        if not segments:
            return self._aggressive_compress(content, target_tokens)
        # Compute a weight for each segment
        segment_weights = self._calculate_weights(
            segments,
            relevance_scores
        )
        # Distribute the target tokens in proportion to the weights
        total_weight = sum(segment_weights)
        allocated_tokens = {
            i: int(target_tokens * (w / total_weight))
            for i, w in enumerate(segment_weights)
        }
        # Compress each segment
        compressed_segments = []
        for i, segment in enumerate(segments):
            seg_tokens = allocated_tokens.get(i, 0)
            if seg_tokens < 20:  # Budget too small: drop the segment
                continue
            if len(segment) // 4 > seg_tokens:
                # Over budget: compress
                compressed = self._compress_segment(segment, seg_tokens)
                compressed_segments.append(compressed)
            else:
                compressed_segments.append(segment)
        # Segments were split on blank lines, so rejoin them the same way
        return '\n\n'.join(compressed_segments)
    def _split_into_segments(self, content: str) -> List[str]:
        """Split the content into segments."""
        # Split on blank lines
        segments = content.split('\n\n')
        return [s.strip() for s in segments if s.strip()]

    def _calculate_weights(
        self,
        segments: List[str],
        relevance_scores: Dict[int, float]
    ) -> List[float]:
        """Compute a weight for each segment."""
        weights = []
        for i, segment in enumerate(segments):
            base_weight = 1.0
            # The first and last segments get a higher weight (position effect)
            if i == 0:
                base_weight *= 1.5
            elif i == len(segments) - 1:
                base_weight *= 1.3
            # Relevance score (defaults to 0.5 when none is supplied)
            relevance = relevance_scores.get(i, 0.5) if relevance_scores else 0.5
            base_weight *= (0.5 + relevance)
            # Length adjustment: very short or very long segments are down-weighted
            line_count = len(segment.split('\n'))
            if line_count < 3:
                base_weight *= 0.7
            elif line_count > 50:
                base_weight *= 0.8
            weights.append(base_weight)
        return weights
    def _compress_segment(self, segment: str, target_tokens: int) -> str:
        """Compress a single segment by keeping its head and tail."""
        lines = segment.split('\n')
        if len(lines) <= 3:
            return segment
        # Scale the line count by the ratio of target characters to actual characters
        target_lines = max(3, int(len(lines) * (target_tokens * 4 / len(segment))))
        if target_lines >= len(lines):
            return segment
        # Keep the beginning and the end
        keep_from_start = target_lines // 2
        keep_from_end = target_lines - keep_from_start
        compressed = lines[:keep_from_start]
        compressed.append(f"... [{len(lines) - target_lines} lines truncated] ...")
        compressed.extend(lines[-keep_from_end:])
        return '\n'.join(compressed)
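    def _aggressive_compress(self, content: str, target_tokens: int) -> str:
        """Aggressive compression (used when the content cannot be split into segments)."""
        target_chars = target_tokens * 4
        if len(content) <= target_chars:
            return content
        # Keep the beginning and the end
        keep = target_chars // 2
        if keep <= 0:
            # content[-0:] would return the whole string, so handle this case explicitly
            return "[... content truncated ...]"
        return content[:keep] + "\n\n[... content truncated ...]\n\n" + content[-keep:]

What this section offers: a complete code example showing how to combine the four strategies (Chunk, Symbol, Graph, and RAG) into a production-grade Context Engine.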

import asyncio
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple, Any
from enum import Enum
import time


class QueryIntent(Enum):
    """Query intent."""
    COMPLETION = "completion"
    GENERATION = "generation"
    EXPLANATION = "explanation"
    REFACTORING = "refactoring"
    BUG_FIX = "bug_fix"
    NAVIGATION = "navigation"
    SEARCH = "search"


@dataclass
class Query:
    """User query."""
    text: str
    intent: QueryIntent
    file_path: Optional[str] = None
    cursor_position: Optional[int] = None
    language: Optional[str] = None
    constraints: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ContextEngineConfig:
    """Context Engine configuration."""
    # Vector retrieval
    vector_model: str = "multi-qa-MiniLM-L6-cos-v1"
    vector_weight: float = 0.5
    # BM25
    bm25_weight: float = 0.3
    # Symbol
    symbol_weight: float = 0.15
    # Graph
    graph_weight: float = 0.05
    # Budget
    max_context_tokens: int = 8000
    max_chunks: int = 50
    # Performance
    num_workers: int = 4
    cache_enabled: bool = True
    cache_ttl: int = 300  # seconds
class ContextEngine:
    """
    Hybrid-retrieval Context Engine.

    Combines the four strategies:
    1. Chunk: basic chunking
    2. Symbol: symbol index
    3. Graph: graph structures
    4. RAG: hybrid vector + BM25 retrieval
    """

    def __init__(self, config: ContextEngineConfig):
        self.config = config
        # Components are populated by initialize()
        self.vector_store: Optional[VectorIndex] = None
        self.bm25_index: Optional[BM25] = None
        self.symbol_index: Optional[SymbolIndex] = None
        self.call_graph: Optional[CallGraph] = None
        self.dependency_graph: Optional[DependencyGraph] = None
        self.retriever: Optional[HybridRetriever] = None
        # Cache
        self._cache: Dict[str, Tuple[Any, float]] = {}
        # State
        self._initialized = False
    async def initialize(self, project_root: str):
        """
        Initialize the Context Engine.

        Args:
            project_root: Project root directory
        """
        if self._initialized:
            return
        print(f"Initializing Context Engine for {project_root}")
        # 1. Load or build the Chunk index
        chunks = await self._build_chunk_index(project_root)
        # 2. Build the Symbol index
        self.symbol_index = await self._build_symbol_index(project_root)
        # 3. Build the Graph index
        self.call_graph, self.dependency_graph = await self._build_graph_index(project_root)
        # 4. Build the RAG index (vector + BM25)
        await self._build_rag_index(chunks)
        self._initialized = True
        print(f"Context Engine initialized with {len(chunks)} chunks")
    async def _build_chunk_index(self, project_root: str) -> List[Chunk]:
        """Build the Chunk index."""
        from pathlib import Path
        # Collect all source files
        code_files = []
        for ext in ['.py', '.js', '.ts', '.go', '.rs', '.java', '.cpp', '.c', '.h']:
            code_files.extend(Path(project_root).rglob(f'*{ext}'))
        # Use the structure-aware chunker
        chunker = StructuralChunker()
        all_chunks = []
        for file_path in code_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    code = f.read()
                chunks = chunker.chunk_with_importance(
                    code,
                    str(file_path)
                )
                all_chunks.extend(chunks)
            except Exception as e:
                # Skip files that cannot be read or parsed
                print(f"Error processing {file_path}: {e}")
        return all_chunks
    async def _build_symbol_index(self, project_root: str) -> SymbolIndex:
        """Build the Symbol index."""
        from pathlib import Path
        extractor = LSSPSymbolExtractor()
        builder = SymbolIndexBuilder(extractor)
        # Collect all Python files (only .py files are indexed for symbols here)
        code_files = list(Path(project_root).rglob('*.py'))
        return builder.build_from_files([str(f) for f in code_files])
    async def _build_graph_index(
        self,
        project_root: str
    ) -> Tuple[CallGraph, DependencyGraph]:
        """Build the Graph index."""
        # Build the call graph
        call_graph_builder = StaticCallGraphBuilder()
        call_graph = call_graph_builder.build_from_project(project_root)
        # Build the dependency graph
        dep_graph_builder = DependencyGraphBuilder(project_root)
        dependency_graph = dep_graph_builder.build()
        return call_graph, dependency_graph
    async def _build_rag_index(self, chunks: List[Chunk]):
        """Build the RAG index (vector + BM25)."""
        # Prepare the records to index
        chunk_data = [
            {
                'id': f"{c.file_path}:{c.start_line}",
                'content': c.content,
                'file_path': c.file_path,
                'start_line': c.start_line,
                'end_line': c.end_line,
                'metadata': c.metadata
            }
            for c in chunks
        ]
        # Build the hybrid retriever
        self.retriever = HybridRetriever(
            vector_weight=self.config.vector_weight,
            bm25_weight=self.config.bm25_weight
        )
        self.retriever.index(chunk_data)
    async def query(self, query: Query) -> str:
        """
        Process a query and return the context.

        Args:
            query: User query

        Returns:
            The assembled context string
        """
        # 1. Run the retrieval strategies in parallel
        chunk_results = await self._parallel_search(query)
        # 2. Fuse scores and rerank
        scored_chunks = self._fuse_and_rerank(chunk_results, query)
        # 3. Multi-stage context selection
        selector = MultiStageSelector(
            max_chunks=self.config.max_chunks
        )
        selected = selector.select(
            scored_chunks,
            query.text,
            query.file_path,
            query.cursor_position or 0
        )
        # 4. Assemble the context
        assembler = RAGContextAssembler(
            max_tokens=self.config.max_context_tokens
        )
        search_results = [
            SearchResult(
                chunk_id=c.chunk_id,
                content=c.content,
                file_path=c.file_path,
                start_line=c.start_line,
                end_line=c.end_line,
                score=c.score
            )
            for c in selected
        ]
        context = assembler.assemble(search_results, query.text)
        return context
    async def _parallel_search(self, query: Query) -> List[ScoredChunk]:
        """Run the retrieval strategies in parallel."""
        tasks = []
        # RAG retrieval
        tasks.append(self._rag_search(query))
        # Symbol retrieval (when applicable)
        if query.text:
            tasks.append(self._symbol_search(query))
        # Graph retrieval (when applicable)
        if query.file_path:
            tasks.append(self._graph_search(query))
        # Execute concurrently
        results = await asyncio.gather(*tasks)
        # Merge the results, keyed by chunk ID
        merged = {}
        for result_list in results:
            for chunk in result_list:
                if chunk.chunk_id not in merged:
                    merged[chunk.chunk_id] = chunk
                else:
                    # The same chunk was found by several strategies: add the scores
                    merged[chunk.chunk_id].score += chunk.score
        return list(merged.values())
    async def _rag_search(self, query: Query) -> List[ScoredChunk]:
        """RAG retrieval."""
        search_results = self.retriever.search(
            query.text,
            top_k=self.config.max_chunks
        )
        return [
            ScoredChunk(
                chunk_id=r.chunk_id,
                content=r.content,
                file_path=r.file_path,
                start_line=r.start_line,
                end_line=r.end_line,
                score=r.combined_score
            )
            for r in search_results
        ]
    async def _symbol_search(self, query: Query) -> List[ScoredChunk]:
        """Symbol retrieval."""
        if not self.symbol_index:
            return []
        # Look up symbols whose names start with the query text
        symbols = self.symbol_index.find_by_prefix(query.text)
        chunks = []
        for symbol in symbols[:10]:  # Cap the number of results
            chunks.append(ScoredChunk(
                chunk_id=f"symbol:{symbol.name}",
                content=f"{symbol.name}: {symbol.kind.name}",
                file_path=symbol.file_path,
                start_line=symbol.start_line,
                end_line=symbol.end_line,
                score=0.5 * self.config.symbol_weight
            ))
        return chunks
    async def _graph_search(self, query: Query) -> List[ScoredChunk]:
        """Graph retrieval."""
        if not self.call_graph or not query.file_path:
            return []
        chunks = []
        # Find callers and callees up to two hops away; the file path is
        # used as the node ID here
        func_id = query.file_path
        callers = self.call_graph.find_all_callers(func_id, max_depth=2)
        callees = self.call_graph.find_all_callees(func_id, max_depth=2)
        # Add call-chain context in both directions (the original computed
        # the callees but never used them)
        for label, node_ids in (("Caller", callers), ("Callee", callees)):
            for node_id in list(node_ids)[:5]:
                if ':' in node_id:
                    file_path, func_name = node_id.rsplit(':', 1)
                    chunks.append(ScoredChunk(
                        chunk_id=f"graph:{node_id}",
                        content=f"{label}: {func_name}",
                        file_path=file_path,
                        start_line=0,
                        end_line=0,
                        score=0.3 * self.config.graph_weight
                    ))
        return chunks
    def _fuse_and_rerank(
        self,
        chunks: List[ScoredChunk],
        query: Query
    ) -> List[ScoredChunk]:
        """Fuse scores and rerank."""
        # Simple strategy: sort by score
        # A more sophisticated fusion strategy (such as Reciprocal Rank Fusion) can be used in practice
        chunks.sort(key=lambda c: c.score, reverse=True)
        return chunks
    def get_callers(self, file_path: str, function_name: str = None) -> List[str]:
        """Return the callers of a function."""
        if not self.call_graph:
            return []
        func_id = f"{file_path}:{function_name}" if function_name else file_path
        return list(self.call_graph.find_all_callers(func_id))

    def get_callees(self, file_path: str, function_name: str = None) -> List[str]:
        """Return the functions called by a function."""
        if not self.call_graph:
            return []
        func_id = f"{file_path}:{function_name}" if function_name else file_path
        return list(self.call_graph.find_all_callees(func_id))


async def main():
    """Example of using the Context Engine."""
    # 1. Create the configuration
    config = ContextEngineConfig(
        vector_weight=0.5,
        bm25_weight=0.3,
        symbol_weight=0.15,
        graph_weight=0.05,
        max_context_tokens=8000
    )
    # 2. Initialize the Context Engine
    engine = ContextEngine(config)
    await engine.initialize("/path/to/project")
    # 3. Build a query
    query = Query(
        text="How does the user authentication work?",
        intent=QueryIntent.EXPLANATION,
        file_path="/path/to/project/auth.py",
        cursor_position=100
    )
    # 4. Retrieve the context
    context = await engine.query(query)
    print(f"Retrieved context ({len(context)} chars):")
    print(context)
    # 5. Impact analysis example
    callers = engine.get_callers("/path/to/project/auth.py", "authenticate")
    print(f"Functions that call authenticate: {callers}")


if __name__ == "__main__":
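    asyncio.run(main())

| Strategy | Core capability | Strengths | Limitations | Typical use |
|---|---|---|---|---|
| Chunk | Code splitting | Basic, general-purpose | Splits semantic units | Preprocessing for every scenario |
| Symbol | Structural indexing | Precise, type-aware | Requires LSP support | Code navigation, completion |
| Graph | Relationship modeling | Semantic links, impact analysis | Expensive to build | Refactoring, debugging |
| RAG | Semantic retrieval | Understands natural language | Compute-intensive | Q&A, code generation |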
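The `_fuse_and_rerank` method in the listing above sorts by summed raw scores and names Reciprocal Rank Fusion (RRF) as a more sophisticated alternative. The following is a minimal sketch of RRF, assuming each strategy returns an ordered list of chunk IDs. The function name, the sample IDs, and the constant `k = 60` (a commonly used default) are illustrative assumptions and not part of the article's engine.

```python
from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse several ranked lists of chunk IDs into one ranked list.

    Each list contributes 1 / (k + rank) per item (rank starts at 1), so only
    positions matter, not the raw scores of the individual retrievers.
    """
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] += 1.0 / (k + rank)
    # Sort by fused score (descending); break ties by ID for determinism.
    return sorted(scores, key=lambda cid: (-scores[cid], cid))


vector_hits = ["auth.py:10", "user.py:42", "db.py:7"]
bm25_hits = ["user.py:42", "auth.py:10", "session.py:3"]
fused = reciprocal_rank_fusion([vector_hits, bm25_hits])
print(fused)  # ['auth.py:10', 'user.py:42', 'db.py:7', 'session.py:3']
```

Because RRF uses ranks rather than scores, it avoids having to normalize vector similarities and BM25 scores onto a common scale before combining them.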
Based on the analysis in this article, the following recommendations apply when building a Context Engine for an AI IDE:
References:
Appendix:
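| Database | Pros | Cons | Suitable scale |
|---|---|---|---|
| FAISS | Efficient, backed by Facebook | Vectors only | < 100M |
| Milvus | Cloud-native, supports hybrid retrieval | Complex to deploy | > 100M |
| Chroma | Lightweight, easy to use | Limited features | < 10M |
| Qdrant | High performance, supports filtering | Relatively new | > 100M |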
# Recommended parameters
bm25 = BM25(k1=1.5, b=0.75)
# Tuned for short queries
bm25_short = BM25(k1=2.0, b=0.5)
# Tuned for long documents
bm25_long = BM25(k1=1.2, b=0.9)

| Strategy combination | Recall@10 | Latency (ms) | Cost ($/1K queries) |
|---|---|---|---|
| BM25 Only | 0.65 | 15 | 0.05 |
| Vector Only | 0.72 | 45 | 0.12 |
| Hybrid (0.5/0.5) | 0.81 | 55 | 0.15 |
| Hybrid + Graph | 0.85 | 70 | 0.18 |
Test environment: 100K code chunks, with an average chunk size of 100 lines.
Keywords: Context Engine, Chunk, Symbol, Graph, RAG, AI IDE, context retrieval, hybrid retrieval, code indexing, token optimization
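To make the role of the `k1` and `b` values in the BM25 configurations above more concrete, here is a minimal sketch of Okapi BM25 scoring over pre-tokenized documents. It is a standalone function written for illustration, not the article's `BM25` class; the function name, the smoothed IDF variant, and the toy documents are assumptions. `k1` controls how quickly repeated terms saturate, and `b` controls how strongly long documents are penalized.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: List[str], docs: List[List[str]],
                k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Score every tokenized document against a tokenized query with BM25."""
    if not docs:
        return []
    n_docs = len(docs)
    # Fall back to 1.0 so that an all-empty corpus does not divide by zero.
    avg_len = sum(len(d) for d in docs) / n_docs or 1.0
    # Document frequency: in how many documents each term occurs.
    df = Counter(term for d in docs for term in set(d))
    scores = []
    for doc in docs:
        tf = Counter(doc)
        # b = 0 disables length normalization; b = 1 applies it fully.
        length_norm = 1 - b + b * len(doc) / avg_len
        score = 0.0
        for term in query:
            if term not in tf:
                continue
            # Smoothed IDF variant (the "+ 1" keeps it non-negative).
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            score += idf * tf[term] * (k1 + 1) / (tf[term] + k1 * length_norm)
        scores.append(score)
    return scores


docs = [
    "def authenticate user password".split(),
    "def render template context extra helper lines here".split(),
    "class user session authenticate authenticate".split(),
]
query = "authenticate user".split()
print(bm25_scores(query, docs, k1=1.5, b=0.75))
print(bm25_scores(query, docs, k1=1.2, b=0.9))
```

Running the same query with different `k1` and `b` values on a sample of real chunks is a simple way to check whether a given setting suits a particular codebase.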