前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Marker 源码解析(二)

Marker 源码解析(二)

作者头像
ApacheCN_飞龙
发布2024-03-09 08:46:42
920
发布2024-03-09 08:46:42
举报
文章被收录于专栏:信数据得永生信数据得永生

.\marker\marker\models.py

代码语言:javascript
复制
# 从 marker.cleaners.equations 模块中导入 load_texify_model 函数
from marker.cleaners.equations import load_texify_model
# 从 marker.ordering 模块中导入 load_ordering_model 函数
from marker.ordering import load_ordering_model
# 从 marker.postprocessors.editor 模块中导入 load_editing_model 函数
from marker.postprocessors.editor import load_editing_model
# 从 marker.segmentation 模块中导入 load_layout_model 函数
from marker.segmentation import load_layout_model

# 定义一个函数用于加载所有模型
def load_all_models():
    # 调用 load_editing_model 函数,加载编辑模型
    edit = load_editing_model()
    # 调用 load_ordering_model 函数,加载排序模型
    order = load_ordering_model()
    # 调用 load_layout_model 函数,加载布局模型
    layout = load_layout_model()
    # 调用 load_texify_model 函数,加载 TeXify 模型
    texify = load_texify_model()
    # 将加载的模型按顺序存储在列表中
    model_lst = [texify, layout, order, edit]
    # 返回模型列表
    return model_lst

.\marker\marker\ocr\page.py

代码语言:javascript
复制
import io  # 导入io模块
from typing import List, Optional  # 导入类型提示相关模块

import fitz as pymupdf  # 导入fitz模块并重命名为pymupdf
import ocrmypdf  # 导入ocrmypdf模块
from spellchecker import SpellChecker  # 从spellchecker模块导入SpellChecker类

from marker.ocr.utils import detect_bad_ocr  # 从marker.ocr.utils模块导入detect_bad_ocr函数
from marker.schema import Block  # 从marker.schema模块导入Block类
from marker.settings import settings  # 从marker.settings模块导入settings变量

ocrmypdf.configure_logging(verbosity=ocrmypdf.Verbosity.quiet)  # 配置ocrmypdf的日志记录级别为quiet

# 对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    # 如果OCR_ENGINE设置为"tesseract",则调用ocr_entire_page_tess函数
    if settings.OCR_ENGINE == "tesseract":
        return ocr_entire_page_tess(page, lang, spellchecker)
    # 如果OCR_ENGINE设置为"ocrmypdf",则调用ocr_entire_page_ocrmp函数
    elif settings.OCR_ENGINE == "ocrmypdf":
        return ocr_entire_page_ocrmp(page, lang, spellchecker)
    else:
        raise ValueError(f"Unknown OCR engine {settings.OCR_ENGINE}")  # 抛出数值错误异常,显示未知的OCR引擎

# 使用tesseract对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_tess(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    try:
        # 获取页面的完整OCR文本页
        full_tp = page.get_textpage_ocr(flags=settings.TEXT_FLAGS, dpi=settings.OCR_DPI, full=True, language=lang)
        # 获取页面的文本块列表
        blocks = page.get_text("dict", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)["blocks"]
        # 获取页面的完整文本
        full_text = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)

        # 如果完整文本长度为0,则返回空列表
        if len(full_text) == 0:
            return []

        # 检查OCR是否成功。如果失败,返回空列表
        # 例如,如果有一张扫描的空白页上有一些淡淡的文本印记,OCR可能会失败
        if detect_bad_ocr(full_text, spellchecker):
            return []
    except RuntimeError:
        return []
    return blocks  # 返回文本块列表

# 使用ocrmypdf对整个页面进行OCR识别,返回Block对象列表
def ocr_entire_page_ocrmp(page, lang: str, spellchecker: Optional[SpellChecker] = None) -> List[Block]:
    # 使用ocrmypdf获取整个页面的OCR文本
    src = page.parent  # 页面所属文档
    blank_doc = pymupdf.open()  # 创建临时的1页文档
    blank_doc.insert_pdf(src, from_page=page.number, to_page=page.number, annots=False, links=False)  # 插入PDF页面
    pdfbytes = blank_doc.tobytes()  # 获取文档字节流
    inbytes = io.BytesIO(pdfbytes)  # 转换为BytesIO对象
    # 创建一个字节流对象,用于存储 ocrmypdf 处理后的结果 PDF
    outbytes = io.BytesIO()  # let ocrmypdf store its result pdf here
    # 使用 ocrmypdf 进行 OCR 处理
    ocrmypdf.ocr(
        inbytes,
        outbytes,
        language=lang,
        output_type="pdf",
        redo_ocr=None if settings.OCR_ALL_PAGES else True,
        force_ocr=True if settings.OCR_ALL_PAGES else None,
        progress_bar=False,
        optimize=False,
        fast_web_view=1e6,
        skip_big=15, # skip images larger than 15 megapixels
        tesseract_timeout=settings.TESSERACT_TIMEOUT,
        tesseract_non_ocr_timeout=settings.TESSERACT_TIMEOUT,
    )
    # 以 fitz PDF 格式打开 OCR 处理后的输出
    ocr_pdf = pymupdf.open("pdf", outbytes.getvalue())  # read output as fitz PDF
    # 获取 OCR 处理后的文本块信息
    blocks = ocr_pdf[0].get_text("dict", sort=True, flags=settings.TEXT_FLAGS)["blocks"]
    # 获取 OCR 处理后的完整文本
    full_text = ocr_pdf[0].get_text("text", sort=True, flags=settings.TEXT_FLAGS)

    # 确保原始 PDF/EPUB/MOBI 的边界框和 OCR 处理后的 PDF 的边界框相同
    assert page.bound() == ocr_pdf[0].bound()

    # 如果完整文本为空,则返回空列表
    if len(full_text) == 0:
        return []

    # 如果检测到 OCR 处理不良,则返回空列表
    if detect_bad_ocr(full_text, spellchecker):
        return []

    # 返回文本块信息
    return blocks

.\marker\marker\ocr\utils.py

代码语言:javascript
复制
# 导入必要的模块和类
from typing import Optional
from nltk import wordpunct_tokenize
from spellchecker import SpellChecker
from marker.settings import settings
import re

# 检测 OCR 文本质量是否差,返回布尔值
def detect_bad_ocr(text, spellchecker: Optional[SpellChecker], misspell_threshold=.7, space_threshold=.6, newline_threshold=.5, alphanum_threshold=.4):
    # 如果文本长度为0,则假定 OCR 失败
    if len(text) == 0:
        return True

    # 使用 wordpunct_tokenize 函数将文本分词
    words = wordpunct_tokenize(text)
    # 过滤掉空白字符
    words = [w for w in words if w.strip()]
    # 提取文本中的字母数字字符
    alpha_words = [word for word in words if word.isalnum()]

    # 如果提供了拼写检查器
    if spellchecker:
        # 检查文本中的拼写错误
        misspelled = spellchecker.unknown(alpha_words)
        # 如果拼写错误数量超过阈值,则返回 True
        if len(misspelled) > len(alpha_words) * misspell_threshold:
            return True

    # 计算文本中空格的数量
    spaces = len(re.findall(r'\s+', text))
    # 计算文本中字母字符的数量
    alpha_chars = len(re.sub(r'\s+', '', text))
    # 如果空格占比超过阈值,则返回 True
    if spaces / (alpha_chars + spaces) > space_threshold:
        return True

    # 计算文本中换行符的数量
    newlines = len(re.findall(r'\n+', text))
    # 计算文本中非换行符的数量
    non_newlines = len(re.sub(r'\n+', '', text))
    # 如果换行符占比超过阈值,则返回 True
    if newlines / (newlines + non_newlines) > newline_threshold:
        return True

    # 如果文本中字母数字字符比例低于阈值,则返回 True
    if alphanum_ratio(text) < alphanum_threshold: # Garbled text
        return True

    # 计算文本中无效字符的数量
    invalid_chars = len([c for c in text if c in settings.INVALID_CHARS])
    # 如果无效字符数量超过阈值,则返回 True
    if invalid_chars > max(3.0, len(text) * .02):
        return True

    # 默认情况下返回 False
    return False

# 将字体标志拆解为可读的形式
def font_flags_decomposer(flags):
    l = []
    # 检查字体标志中是否包含上标
    if flags & 2 ** 0:
        l.append("superscript")
    # 检查字体标志中是否包含斜体
    if flags & 2 ** 1:
        l.append("italic")
    # 检查字体标志中是否包含衬线
    if flags & 2 ** 2:
        l.append("serifed")
    else:
        l.append("sans")
    # 检查字体标志中是否包含等宽字体
    if flags & 2 ** 3:
        l.append("monospaced")
    else:
        l.append("proportional")
    # 检查字体标志中是否包含粗体
    if flags & 2 ** 4:
        l.append("bold")
    # 返回拆解后的字体标志字符串
    return "_".join(l)

# 计算文本中字母数字字符的比例
def alphanum_ratio(text):
    # 去除文本中的空格和换行符
    text = text.replace(" ", "")
    text = text.replace("\n", "")
    # 统计文本中的字母数字字符数量
    alphanumeric_count = sum([1 for c in text if c.isalnum()])

    # 如果文本长度为0,则返回1
    if len(text) == 0:
        return 1

    # 计算字母数字字符比例
    ratio = alphanumeric_count / len(text)
    # 返回变量 ratio 的值
    return ratio

.\marker\marker\ordering.py

代码语言:javascript
复制
# 导入必要的模块
from copy import deepcopy
from typing import List
import torch
import sys, os
from marker.extract_text import convert_single_page
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor
from PIL import Image
import io
from marker.schema import Page
from marker.settings import settings

# 从设置中加载 LayoutLMv3Processor 模型
processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME)

# 加载 LayoutLMv3ForSequenceClassification 模型
def load_ordering_model():
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        settings.ORDERER_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()
    return model

# 获取推理数据
def get_inference_data(page, page_blocks: Page):
    # 深拷贝页面块的边界框
    bboxes = deepcopy([block.bbox for block in page_blocks.blocks])
    # 初始化单词列表
    words = ["."] * len(bboxes)

    # 获取页面的像素图像
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    # 将像素图像转换为 PNG 格式
    png = pix.pil_tobytes(format="PNG")
    # 将 PNG 数据转换为 RGB 图像
    rgb_image = Image.open(io.BytesIO(png)).convert("RGB")

    # 获取页面块的边界框和宽高
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height

    # 调整边界框的值
    for box in bboxes:
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]

        # 将边界框的值转换为相对于页面宽高的比例
        box[0] = int(box[0] / pwidth * 1000)
        box[1] = int(box[1] / pheight * 1000)
        box[2] = int(box[2] / pwidth * 1000)
        box[3] = int(box[3] / pheight * 1000)

    return rgb_image, bboxes, words

# 批量推理
def batch_inference(rgb_images, bboxes, words, model):
    # 对 RGB 图像、单词和边界框进行编码
    encoding = processor(
        rgb_images,
        text=words,
        boxes=bboxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )

    # 将像素值转换为模型的数据类型
    encoding["pixel_values"] = encoding["pixel_values"].to(model.dtype)
    # 进入推断模式,不进行梯度计算
    with torch.inference_mode():
        # 将指定的键对应的值移动到模型所在设备上
        for k in ["bbox", "input_ids", "pixel_values", "attention_mask"]:
            encoding[k] = encoding[k].to(model.device)
        # 使用模型进行推理,获取输出
        outputs = model(**encoding)
        # 获取模型输出的预测结果
        logits = outputs.logits

    # 获取预测结果中概率最大的类别索引,并转换为列表
    predictions = logits.argmax(-1).squeeze().tolist()
    # 如果预测结果是整数,则转换为列表
    if isinstance(predictions, int):
        predictions = [predictions]
    # 将预测结果转换为类别标签
    predictions = [model.config.id2label[p] for p in predictions]
    # 返回预测结果
    return predictions
# 为文档中的每个块添加列数计数
def add_column_counts(doc, doc_blocks, model, batch_size):
    # 按照批量大小遍历文档块
    for i in range(0, len(doc_blocks), batch_size):
        # 创建当前批量的索引范围
        batch = range(i, min(i + batch_size, len(doc_blocks)))
        # 初始化空列表用于存储 RGB 图像、边界框和单词
        rgb_images = []
        bboxes = []
        words = []
        # 遍历当前批量的页码
        for pnum in batch:
            # 获取推理数据:RGB 图像、页边界框和页单词
            page = doc[pnum]
            rgb_image, page_bboxes, page_words = get_inference_data(page, doc_blocks[pnum])
            rgb_images.append(rgb_image)
            bboxes.append(page_bboxes)
            words.append(page_words)

        # 进行批量推理,获取预测结果
        predictions = batch_inference(rgb_images, bboxes, words, model)
        # 将预测结果与页码对应,更新文档块的列数计数
        for pnum, prediction in zip(batch, predictions):
            doc_blocks[pnum].column_count = prediction

# 对文档块进行排序
def order_blocks(doc, doc_blocks: List[Page], model, batch_size=settings.ORDERER_BATCH_SIZE):
    # 添加列数计数
    add_column_counts(doc, doc_blocks, model, batch_size)

    # 遍历文档块中的每一页
    for page_blocks in doc_blocks:
        # 如果该页的列数大于1
        if page_blocks.column_count > 1:
            # 根据位置重新排序块
            split_pos = page_blocks.x_start + page_blocks.width / 2
            left_blocks = []
            right_blocks = []
            # 遍历该页的每个块
            for block in page_blocks.blocks:
                # 根据位置将块分为左右两部分
                if block.x_start <= split_pos:
                    left_blocks.append(block)
                else:
                    right_blocks.append(block)
            # 更新该页的块顺序
            page_blocks.blocks = left_blocks + right_blocks
    # 返回排序后的文档块
    return doc_blocks

.\marker\marker\postprocessors\editor.py

代码语言:javascript
复制
# 导入必要的库
from collections import defaultdict, Counter
from itertools import chain
from typing import Optional

# 导入 transformers 库中的 AutoTokenizer 类
from transformers import AutoTokenizer

# 导入 settings 模块中的 settings 变量
from marker.settings import settings

# 导入 torch 库
import torch
import torch.nn.functional as F

# 导入 marker.postprocessors.t5 模块中的 T5ForTokenClassification 类和 byt5_tokenize 函数
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize

# 定义加载编辑模型的函数
def load_editing_model():
    # 如果未启用编辑模型,则返回 None
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    # 从预训练模型中加载 T5ForTokenClassification 模型
    model = T5ForTokenClassification.from_pretrained(
            settings.EDITOR_MODEL_NAME,
            torch_dtype=settings.MODEL_DTYPE,
        ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()

    # 配置模型的标签映射
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model

# 定义编辑全文的函数
def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size: int = settings.EDITOR_BATCH_SIZE):
    # 如果模型为空,则直接返回原始文本和空字典
    if not model:
        return text, {}

    # 对文本进行 tokenization
    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]

    # 准备 token_masks 列表
    token_masks = []
    # 遍历输入的 input_ids,按照 batch_size 进行分批处理
    for i in range(0, len(input_ids), batch_size):
        # 从 tokenized 中获取当前 batch 的 input_ids
        batch_input_ids = tokenized["input_ids"][i: i + batch_size]
        # 将 batch_input_ids 转换为 torch 张量,并指定设备为 model 的设备
        batch_input_ids = torch.tensor(batch_input_ids, device=model.device)
        # 从 tokenized 中获取当前 batch 的 attention_mask
        batch_attention_mask = tokenized["attention_mask"][i: i + batch_size]
        # 将 batch_attention_mask 转换为 torch 张量,并指定设备为 model 的设备
        batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)
        
        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行预测
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        # 将预测结果 logits 移动到 CPU 上
        logits = predictions.logits.cpu()

        # 如果最大概率小于阈值,则假设为不良预测
        # 我们希望保守一点,不要对文本进行过多编辑
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]
        labels = labels.squeeze().tolist()
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)

    # 文本中的字符列表
    flat_input_ids = list(chain.from_iterable(input_ids)

    # 去除特殊标记 0,1。保留未知标记,尽管它不应该被使用
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    # 确保 token_masks 的长度与文本编码后的长度相等
    assert len(token_masks) == len(list(text.encode("utf-8")))

    # 统计编辑次数的字典
    edit_stats = defaultdict(int)
    # 输出文本列表
    out_text = []
    # 起始位置
    start = 0
    # 遍历文本中的每个字符及其索引
    for i, char in enumerate(text):
        # 获取当前字符对应的 token 长度
        char_token_length = char_token_lengths[i]
        # 获取当前字符对应的 token 的 mask
        masks = token_masks[start: start + char_token_length]
        # 将 mask 转换为标签
        labels = [model.config.id2label[mask] for mask in masks]
        # 如果所有标签都是 "delete",则执行删除操作
        if all(l == "delete" for l in labels):
            # 如果删除的是空格,则保留,否则忽略
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        # 如果标签为 "newline-1",则添加换行符
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        # 如果标签为 "space-1",则添加空格
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        # 如果标签为其他情况,则保留字符
        else:
            out_text.append(char)
            edit_stats["equal"] += 1

        # 更新下一个字符的起始位置
        start += char_token_length

    # 将处理后的文本列表转换为字符串
    out_text = "".join(out_text)
    # 返回处理后的文本及编辑统计信息
    return out_text, edit_stats

.\marker\marker\postprocessors\t5.py

代码语言:javascript
复制
# 从 transformers 库中导入 T5Config 和 T5PreTrainedModel 类
from transformers import T5Config, T5PreTrainedModel
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 copy 库中导入 deepcopy 函数
from copy import deepcopy
# 从 typing 库中导入 Optional, Tuple, Union, List 类型
from typing import Optional, Tuple, Union, List
# 从 itertools 库中导入 chain 函数
from itertools import chain

# 从 transformers.modeling_outputs 模块中导入 TokenClassifierOutput 类
from transformers.modeling_outputs import TokenClassifierOutput
# 从 transformers.models.t5.modeling_t5 模块中导入 T5Stack 类
from transformers.models.t5.modeling_t5 import T5Stack
# 从 transformers.utils.model_parallel_utils 模块中导入 get_device_map, assert_device_map 函数
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

# 定义一个函数,用于将文本进行字节编码并分词
def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0):
    # 初始化一个空列表,用于存储字节编码
    byte_codes = []
    # 遍历文本中的每个字符
    for char in text:
        # 将每个字符进行 UTF-8 编码,并加上 3 以考虑特殊标记
        byte_codes.append([byte + 3 for byte in char.encode('utf-8')])

    # 将字节编码展开成一个列表
    tokens = list(chain.from_iterable(byte_codes))
    # 记录每个字符对应的 token 长度
    char_token_lengths = [len(b) for b in byte_codes]

    # 初始化批量 token 和注意力掩码列表
    batched_tokens = []
    attention_mask = []
    # 按照最大长度将 token 进行分批
    for i in range(0, len(tokens), max_length):
        batched_tokens.append(tokens[i:i + max_length])
        attention_mask.append([1] * len(batched_tokens[-1])

    # 对最后一个批次进行填充
    if len(batched_tokens[-1]) < max_length:
        batched_tokens[-1] += [pad_token_id] * (max_length - len(batched_tokens[-1]))
        attention_mask[-1] += [0] * (max_length - len(attention_mask[-1]))

    # 返回包含分词结果的字典
    return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths}

# 定义一个 T5ForTokenClassification 类,继承自 T5PreTrainedModel 类
class T5ForTokenClassification(T5PreTrainedModel):
    # 定义一个列表,用于指定加载时忽略的键
    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
    # 初始化函数,接受一个T5Config对象作为参数
    def __init__(self, config: T5Config):
        # 调用父类的初始化函数
        super().__init__(config)
        # 设置模型维度为配置中的d_model值
        self.model_dim = config.d_model

        # 创建一个共享的嵌入层,词汇表大小为config.vocab_size,维度为config.d_model
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # 复制配置对象,用于创建编码器
        encoder_config = deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.use_cache = False
        # 创建T5Stack编码器
        self.encoder = T5Stack(encoder_config, self.shared)

        # 设置分类器的dropout值
        classifier_dropout = (
            config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # 创建一个线性层,输入维度为config.d_model,输出维度为config.num_labels
        self.classifier = nn.Linear(config.d_model, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

        # 模型并行化
        self.model_parallel = False
        self.device_map = None


    # 并行化函数,接受一个设备映射device_map作为参数
    def parallelize(self, device_map=None):
        # 如果未提供device_map,则根据编码器块的数量和GPU数量生成一个默认的device_map
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # 检查设备映射的有效性
        assert_device_map(self.device_map, len(self.encoder.block))
        # 将编码器并行化
        self.encoder.parallelize(self.device_map)
        # 将分类器移动到编码器的第一个设备上
        self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    # 反并行化函数
    def deparallelize(self):
        # 取消编码器的并行化
        self.encoder.deparallelize()
        # 将编码器和分类器移动到CPU上
        self.encoder = self.encoder.to("cpu")
        self.classifier = self.classifier.to("cpu")
        self.model_parallel = False
        self.device_map = None
        # 释放GPU缓存
        torch.cuda.empty_cache()

    # 获取输入嵌入层函数
    def get_input_embeddings(self):
        return self.shared

    # 设置输入嵌入层函数,接受一个新的嵌入层new_embeddings作为参数
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        # 设置编码器的输入嵌入层为新的嵌入层
        self.encoder.set_input_embeddings(new_embeddings)

    # 获取编码器函数
    def get_encoder(self):
        return self.encoder
    # 对模型中的特定头部进行修剪
    def _prune_heads(self, heads_to_prune):
        # 遍历需要修剪的层和头部
        for layer, heads in heads_to_prune.items():
            # 调用 SelfAttention 模块的 prune_heads 方法进行修剪
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    # 前向传播函数
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
        # 如果 return_dict 为 None,则使用配置中的 use_return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用编码器进行前向传播
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取序列输出
        sequence_output = outputs[0]

        # 对序列输出进行 dropout
        sequence_output = self.dropout(sequence_output)
        # 将序列输出传入分类器得到 logits
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None

        # 如果不使用 return_dict,则返回输出结果
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 使用 TokenClassifierOutput 类返回结果
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

.\marker\marker\schema.py

代码语言:javascript
复制
# 导入 Counter 类,用于计数
# 导入 List、Optional、Tuple 类型,用于类型提示
from collections import Counter
from typing import List, Optional, Tuple

# 导入 BaseModel、field_validator 类,用于定义数据模型和字段验证
# 导入 ftfy 模块,用于修复文本中的 Unicode 错误
from pydantic import BaseModel, field_validator
import ftfy

# 导入 boxes_intersect_pct、multiple_boxes_intersect 函数,用于计算两个框的交集比例和多个框的交集情况
# 导入 settings 模块,用于获取配置信息
from marker.bbox import boxes_intersect_pct, multiple_boxes_intersect
from marker.settings import settings

# 定义函数 find_span_type,用于查找给定 span 在页面块中的类型
def find_span_type(span, page_blocks):
    # 默认块类型为 "Text"
    block_type = "Text"
    # 遍历页面块列表
    for block in page_blocks:
        # 如果 span 的边界框与页面块的边界框有交集
        if boxes_intersect_pct(span.bbox, block.bbox):
            # 更新块类型为页面块的类型
            block_type = block.block_type
            break
    # 返回块类型
    return block_type

# 定义类 BboxElement,继承自 BaseModel 类,表示具有边界框的元素
class BboxElement(BaseModel):
    bbox: List[float]

    # 验证 bbox 字段是否包含 4 个元素
    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        if len(v) != 4:
            raise ValueError('bbox must have 4 elements')
        return v

    # 计算元素的高度、宽度、起始 x 坐标、起始 y 坐标、面积
    @property
    def height(self):
        return self.bbox[3] - self.bbox[1]

    @property
    def width(self):
        return self.bbox[2] - self.bbox[0]

    @property
    def x_start(self):
        return self.bbox[0]

    @property
    def y_start(self):
        return self.bbox[1]

    @property
    def area(self):
        return self.width * self.height

# 定义类 BlockType,继承自 BboxElement 类,表示具有块类型的元素
class BlockType(BboxElement):
    block_type: str

# 定义类 Span,继承自 BboxElement 类,表示具有文本内容的元素
class Span(BboxElement):
    text: str
    span_id: str
    font: str
    color: int
    ascender: Optional[float] = None
    descender: Optional[float] = None
    block_type: Optional[str] = None
    selected: bool = True

    # 修复文本中的 Unicode 错误
    @field_validator('text')
    @classmethod
    def fix_unicode(cls, text: str) -> str:
        return ftfy.fix_text(text)

# 定义类 Line,继承自 BboxElement 类,表示具有多个 Span 的行元素
class Line(BboxElement):
    spans: List[Span]

    # 获取行的预备文本,即所有 Span 的文本拼接而成
    @property
    def prelim_text(self):
        return "".join([s.text for s in self.spans])

    # 获取行的起始 x 坐标
    @property
    def start(self):
        return self.spans[0].bbox[0]

# 定义类 Block,继承自 BboxElement 类,表示具有多个 Line 的块元素
class Block(BboxElement):
    lines: List[Line]
    pnum: int

    # 获取块的预备文本,即所有 Line 的预备文本拼接而成
    @property
    def prelim_text(self):
        return "\n".join([l.prelim_text for l in self.lines])
    # 检查文本块是否包含公式,通过检查每个 span 的 block_type 是否为 "Formula" 来确定
    def contains_equation(self, equation_boxes=None):
        # 生成一个包含每个 span 的 block_type 是否为 "Formula" 的条件列表
        conditions = [s.block_type == "Formula" for l in self.lines for s in l.spans]
        # 如果提供了 equation_boxes 参数,则添加一个条件,检查文本块的边界框是否与给定框相交
        if equation_boxes:
            conditions += [multiple_boxes_intersect(self.bbox, equation_boxes)]
        # 返回条件列表中是否有任何一个条件为 True
        return any(conditions)

    # 过滤掉包含在 bad_span_ids 中的 span
    def filter_spans(self, bad_span_ids):
        new_lines = []
        for line in self.lines:
            new_spans = []
            for span in line.spans:
                # 如果 span 的 span_id 不在 bad_span_ids 中,则保留该 span
                if not span.span_id in bad_span_ids:
                    new_spans.append(span)
            # 更新 line 的 spans 属性为过滤后的 new_spans
            line.spans = new_spans
            # 如果 line 中仍有 spans,则将其添加到 new_lines 中
            if len(new_spans) > 0:
                new_lines.append(line)
        # 更新 self.lines 为过滤后的 new_lines
        self.lines = new_lines

    # 过滤掉包含在 settings.BAD_SPAN_TYPES 中的 span 的 block_type
    def filter_bad_span_types(self):
        new_lines = []
        for line in self.lines:
            new_spans = []
            for span in line.spans:
                # 如果 span 的 block_type 不在 BAD_SPAN_TYPES 中,则保留该 span
                if span.block_type not in settings.BAD_SPAN_TYPES:
                    new_spans.append(span)
            # 更新 line 的 spans 属性为过滤后的 new_spans
            line.spans = new_spans
            # 如果 line 中仍有 spans,则将其添加到 new_lines 中
            if len(new_spans) > 0:
                new_lines.append(line)
        # 更新 self.lines 为过滤后的 new_lines
        self.lines = new_lines

    # 返回文本块中出现频率最高的 block_type
    def most_common_block_type(self):
        # 统计每个 span 的 block_type 出现的次数
        counter = Counter([s.block_type for l in self.lines for s in l.spans])
        # 返回出现次数最多的 block_type
        return counter.most_common(1)[0][0]

    # 设置文本块中所有 span 的 block_type 为给定的 block_type
    def set_block_type(self, block_type):
        for line in self.lines:
            for span in line.spans:
                # 将 span 的 block_type 设置为给定的 block_type
                span.block_type = block_type
# 定义一个名为 Page 的类,继承自 BboxElement 类
class Page(BboxElement):
    # 类属性:blocks 为 Block 对象列表,pnum 为整数,column_count 和 rotation 可选整数,默认为 None
    blocks: List[Block]
    pnum: int
    column_count: Optional[int] = None
    rotation: Optional[int] = None # 页面的旋转角度

    # 获取页面中非空行的方法
    def get_nonblank_lines(self):
        # 获取页面中所有行
        lines = self.get_all_lines()
        # 过滤出非空行
        nonblank_lines = [l for l in lines if l.prelim_text.strip()]
        return nonblank_lines

    # 获取页面中所有行的方法
    def get_all_lines(self):
        # 获取页面中所有行的列表
        lines = [l for b in self.blocks for l in b.lines]
        return lines

    # 获取页面中非空跨度的方法,返回 Span 对象列表
    def get_nonblank_spans(self) -> List[Span]:
        # 获取页面中所有行
        lines = [l for b in self.blocks for l in b.lines]
        # 过滤出非空跨度
        spans = [s for l in lines for s in l.spans if s.text.strip()]
        return spans

    # 添加块类型到行的方法
    def add_block_types(self, page_block_types):
        # 如果检测到的块类型数量与页面行数不匹配,则打印警告信息
        if len(page_block_types) != len(self.get_all_lines()):
            print(f"Warning: Number of detected lines {len(page_block_types)} does not match number of lines {len(self.get_all_lines())}")

        i = 0
        for block in self.blocks:
            for line in block.lines:
                if i < len(page_block_types):
                    line_block_type = page_block_types[i].block_type
                else:
                    line_block_type = "Text"
                i += 1
                for span in line.spans:
                    span.block_type = line_block_type

    # 获取页面中字体统计信息的方法
    def get_font_stats(self):
        # 获取页面中非空跨度的字体信息
        fonts = [s.font for s in self.get_nonblank_spans()]
        # 统计字体出现次数
        font_counts = Counter(fonts)
        return font_counts

    # 获取页面中行高统计信息的方法
    def get_line_height_stats(self):
        # 获取页面中非空行的行高信息
        heights = [l.bbox[3] - l.bbox[1] for l in self.get_nonblank_lines()]
        # 统计行高出现次数
        height_counts = Counter(heights)
        return height_counts

    # 获取页面中行起始位置统计信息的方法
    def get_line_start_stats(self):
        # 获取页面中非空行的起始位置信息
        starts = [l.bbox[0] for l in self.get_nonblank_lines()]
        # 统计起始位置出现次数
        start_counts = Counter(starts)
        return start_counts
    # 获取文本块中非空行的起始位置列表
    def get_min_line_start(self):
        # 通过列表推导式获取非空行的起始位置,并且该行为文本类型
        starts = [l.bbox[0] for l in self.get_nonblank_lines() if l.spans[0].block_type == "Text"]
        # 如果没有找到非空行,则抛出索引错误
        if len(starts) == 0:
            raise IndexError("No lines found")
        # 返回起始位置列表中的最小值
        return min(starts)

    # 获取文本块中每个文本块的 prelim_text 属性,并用换行符连接成字符串
    @property
    def prelim_text(self):
        return "\n".join([b.prelim_text for b in self.blocks])
# 定义一个继承自BboxElement的MergedLine类,包含文本和字体列表属性
class MergedLine(BboxElement):
    text: str
    fonts: List[str]

    # 返回该行中出现频率最高的字体
    def most_common_font(self):
        # 统计字体列表中各个字体出现的次数
        counter = Counter(self.fonts)
        # 返回出现频率最高的字体
        return counter.most_common(1)[0][0]


# 定义一个继承自BboxElement的MergedBlock类,包含行列表、段落号和块类型列表属性
class MergedBlock(BboxElement):
    lines: List[MergedLine]
    pnum: int
    block_types: List[str]

    # 返回该块中出现频率最高的块类型
    def most_common_block_type(self):
        # 统计块类型列表中各个类型出现的次数
        counter = Counter(self.block_types)
        # 返回出现频率最高的块类型
        return counter.most_common(1)[0][0]


# 定义一个继承自BaseModel的FullyMergedBlock类,包含文本和块类型属性
class FullyMergedBlock(BaseModel):
    text: str
    block_type: str

.\marker\marker\segmentation.py

代码语言:javascript
复制
# 导入所需的库
from concurrent.futures import ThreadPoolExecutor
from typing import List

from transformers import LayoutLMv3ForTokenClassification

# 导入自定义的模块
from marker.bbox import unnormalize_box
from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box
import io
from PIL import Image
from transformers import LayoutLMv3Processor
import numpy as np
from marker.settings import settings
from marker.schema import Page, BlockType
import torch
from math import isclose

# 设置图像最大像素值,避免部分图像被截断
Image.MAX_IMAGE_PIXELS = None

# 从预训练模型加载 LayoutLMv3Processor
processor = LayoutLMv3Processor.from_pretrained(settings.LAYOUT_MODEL_NAME, apply_ocr=False)

# 定义需要分块的键和不需要分块的键
CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
NO_CHUNK_KEYS = ["pixel_values"]

# 加载 LayoutLMv3ForTokenClassification 模型
def load_layout_model():
    # 从预训练模型加载 LayoutLMv3ForTokenClassification 模型
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        settings.LAYOUT_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)

    # 设置模型的标签映射
    model.config.id2label = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title"
    }

    model.config.label2id = {v: k for k, v in model.config.id2label.items()}
    return model

# 检测文档块类型
def detect_document_block_types(doc, blocks: List[Page], layoutlm_model, batch_size=settings.LAYOUT_BATCH_SIZE):
    # 获取特征编码、元数据和样本长度
    encodings, metadata, sample_lengths = get_features(doc, blocks)
    # 预测块类型
    predictions = predict_block_types(encodings, layoutlm_model, batch_size)
    # 将预测结果与框匹配
    block_types = match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model)
    # 断言块类型数量与块数量相等
    assert len(block_types) == len(blocks)
    return block_types

# 获取临时框
def get_provisional_boxes(pred, box, is_subword, start_idx=0):
    # 从预测结果中获取临时框
    prov_predictions = [pred_ for idx, pred_ in enumerate(pred) if not is_subword[idx]][start_idx:]
    # 从列表中筛选出不是子词的元素,并从指定索引开始切片,得到新的列表
    prov_boxes = [box_ for idx, box_ in enumerate(box) if not is_subword[idx]][start_idx:]
    # 返回处理后的预测结果和框
    return prov_predictions, prov_boxes
# 获取页面编码信息,输入参数为页面和页面块对象
def get_page_encoding(page, page_blocks: Page):
    # 如果页面块中的所有行数为0,则返回空列表
    if len(page_blocks.get_all_lines()) == 0:
        return [], []

    # 获取页面块的边界框、宽度和高度
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height

    # 获取页面块的像素图,并转换为 PNG 格式
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # 如果图像太大,则缩小以适应模型
    rgb_image = png_image.convert('RGB')
    rgb_width, rgb_height = rgb_image.size

    # 确保图像大小与 PDF 页面的比例正确
    assert isclose(rgb_width / pwidth, rgb_height / pheight, abs_tol=2e-2)

    # 获取页面块中的所有行
    lines = page_blocks.get_all_lines()

    boxes = []
    text = []
    for line in lines:
        box = line.bbox
        # 处理边界框溢出的情况
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]

        # 处理边界框宽度或高度为0或负值的情况
        if box[2] <= box[0]:
            print("Zero width box found, cannot convert properly")
            raise ValueError
        if box[3] <= box[1]:
            print("Zero height box found, cannot convert properly")
            raise ValueError
        boxes.append(box)
        text.append(line.prelim_text)

    # 将边界框归一化为模型(缩放为1000x1000)
    boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
    for box in boxes:
        # 验证所有边界框都是有效的
        assert(len(box) == 4)
        assert(max(box)) <= 1000
        assert(min(box)) >= 0
    # 使用 processor 处理 RGB 图像,传入文本、框、返回偏移映射等参数
    encoding = processor(
        rgb_image,
        text=text,
        boxes=boxes,
        return_offsets_mapping=True,
        truncation=True,
        return_tensors="pt",
        stride=settings.LAYOUT_CHUNK_OVERLAP,
        padding="max_length",
        max_length=settings.LAYOUT_MODEL_MAX,
        return_overflowing_tokens=True
    )
    # 从 encoding 中弹出 offset_mapping 和 overflow_to_sample_mapping
    offset_mapping = encoding.pop('offset_mapping')
    overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
    # 将 encoding 中的 bbox、input_ids、attention_mask、pixel_values 转换为列表
    bbox = list(encoding["bbox"])
    input_ids = list(encoding["input_ids"])
    attention_mask = list(encoding["attention_mask"])
    pixel_values = list(encoding["pixel_values"])

    # 断言各列表长度相等
    assert len(bbox) == len(input_ids) == len(attention_mask) == len(pixel_values) == len(offset_mapping)

    # 将各列表中的元素组成字典,放入 list_encoding 列表中
    list_encoding = []
    for i in range(len(bbox)):
        list_encoding.append({
            "bbox": bbox[i],
            "input_ids": input_ids[i],
            "attention_mask": attention_mask[i],
            "pixel_values": pixel_values[i],
            "offset_mapping": offset_mapping[i]
        })

    # 其他数据包括原始框、pwidth 和 pheight
    other_data = {
        "original_bbox": boxes,
        "pwidth": pwidth,
        "pheight": pheight,
    }
    # 返回 list_encoding 和 other_data
    return list_encoding, other_data
# 获取文档的特征信息
def get_features(doc, blocks):
    # 初始化编码、元数据和样本长度列表
    encodings = []
    metadata = []
    sample_lengths = []
    # 遍历每个块
    for i in range(len(blocks)):
        # 调用函数获取页面编码和其他数据
        encoding, other_data = get_page_encoding(doc[i], blocks[i])
        # 将页面编码添加到编码列表中
        encodings.extend(encoding)
        # 将其他数据添加到元数据列表中
        metadata.append(other_data)
        # 记录当前页面编码的长度
        sample_lengths.append(len(encoding))
    # 返回编码、元数据和样本长度
    return encodings, metadata, sample_lengths


# 预测块类型
def predict_block_types(encodings, layoutlm_model, batch_size):
    # 初始化所有预测结果列表
    all_predictions = []
    # 按批次处理编码
    for i in range(0, len(encodings), batch_size):
        # 计算当前批次的起始和结束索引
        batch_start = i
        batch_end = min(i + batch_size, len(encodings))
        # 获取当前批次的编码
        batch = encodings[batch_start:batch_end]

        # 构建模型输入
        model_in = {}
        for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
            model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)

        model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)

        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行推理
            outputs = layoutlm_model(**model_in)
            logits = outputs.logits

        # 获取预测结果
        predictions = logits.argmax(-1).squeeze().tolist()
        if len(predictions) == settings.LAYOUT_MODEL_MAX:
            predictions = [predictions]
        # 将预测结果添加到所有预测结果列表中
        all_predictions.extend(predictions)
    # 返回所有预测结果
    return all_predictions


# 将预测结果与框匹配
def match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model) -> List[List[BlockType]]:
    # 断言编码、预测结果和样本长度的长度相等
    assert len(encodings) == len(predictions) == sum(sample_lengths)
    assert len(metadata) == len(sample_lengths)

    # 初始化页面起始索引和页面块类型列表
    page_start = 0
    page_block_types = []
    # 返回页面块类型列表
    return page_block_types

.\marker\marker\settings.py

代码语言:javascript
复制
import os
from typing import Optional, List, Dict

from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import fitz as pymupdf
import torch

# 定义一个设置类,继承自BaseSettings
class Settings(BaseSettings):
    # General
    TORCH_DEVICE: Optional[str] = None

    # 计算属性,返回TORCH_DEVICE_MODEL
    @computed_field
    @property
    def TORCH_DEVICE_MODEL(self) -> str:
        # 如果TORCH_DEVICE不为None,则返回TORCH_DEVICE
        if self.TORCH_DEVICE is not None:
            return self.TORCH_DEVICE

        # 如果CUDA可用,则返回"cuda"
        if torch.cuda.is_available():
            return "cuda"

        # 如果MPS可用,则返回"mps"
        if torch.backends.mps.is_available():
            return "mps"

        # 否则返回"cpu"
        return "cpu"

    INFERENCE_RAM: int = 40 # 每个GPU的VRAM量(以GB为单位)。
    VRAM_PER_TASK: float = 2.5 # 每个任务分配的VRAM量(以GB为单位)。 峰值标记VRAM使用量约为3GB,但工作程序的平均值较低。
    DEFAULT_LANG: str = "English" # 我们假设文件所在的默认语言,应该是TESSERACT_LANGUAGES中的一个键

    SUPPORTED_FILETYPES: Dict = {
        "application/pdf": "pdf",
        "application/epub+zip": "epub",
        "application/x-mobipocket-ebook": "mobi",
        "application/vnd.ms-xpsdocument": "xps",
        "application/x-fictionbook+xml": "fb2"
    }

    # PyMuPDF
    TEXT_FLAGS: int = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES

    # OCR
    INVALID_CHARS: List[str] = [chr(0xfffd), "�"]
    OCR_DPI: int = 400
    TESSDATA_PREFIX: str = ""
    TESSERACT_LANGUAGES: Dict = {
        "English": "eng",
        "Spanish": "spa",
        "Portuguese": "por",
        "French": "fra",
        "German": "deu",
        "Russian": "rus",
        "Chinese": "chi_sim",
        "Japanese": "jpn",
        "Korean": "kor",
        "Hindi": "hin",
    }
    TESSERACT_TIMEOUT: int = 20 # 何时放弃OCR
    # 定义拼写检查语言对应的字典
    SPELLCHECK_LANGUAGES: Dict = {
        "English": "en",
        "Spanish": "es",
        "Portuguese": "pt",
        "French": "fr",
        "German": "de",
        "Russian": "ru",
        "Chinese": None,
        "Japanese": None,
        "Korean": None,
        "Hindi": None,
    }
    # 是否在每一页运行 OCR,即使可以提取文本
    OCR_ALL_PAGES: bool = False
    # 用于 OCR 的并行 CPU 工作线程数
    OCR_PARALLEL_WORKERS: int = 2
    # 使用的 OCR 引擎,可以是 "tesseract" 或 "ocrmypdf",ocrmypdf 质量更高但速度较慢
    OCR_ENGINE: str = "ocrmypdf"

    # Texify 模型相关参数
    TEXIFY_MODEL_MAX: int = 384 # Texify 的最大推理长度
    TEXIFY_TOKEN_BUFFER: int = 256 # Texify 的 token 缓冲区大小
    TEXIFY_DPI: int = 96 # 渲染图像的 DPI
    TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Texify 的批处理大小,CPU 上较低因为使用 float32
    TEXIFY_MODEL_NAME: str = "vikp/texify"

    # Layout 模型相关参数
    BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
    LAYOUT_MODEL_MAX: int = 512
    LAYOUT_CHUNK_OVERLAP: int = 64
    LAYOUT_DPI: int = 96
    LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter"
    LAYOUT_BATCH_SIZE: int = 8 # 最大 512 个 token 意味着较高的批处理大小

    # Ordering 模型相关参数
    ORDERER_BATCH_SIZE: int = 32 # 可以较高,因为最大 token 数为 128
    ORDERER_MODEL_NAME: str = "vikp/column_detector"

    # 最终编辑模型相关参数
    EDITOR_BATCH_SIZE: int = 4
    EDITOR_MAX_LENGTH: int = 1024
    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
    ENABLE_EDITOR_MODEL: bool = False # 编辑模型可能会产生误报
    EDITOR_CUTOFF_THRESH: float = 0.9 # 忽略概率低于此阈值的预测

    # Ray 相关参数
    RAY_CACHE_PATH: Optional[str] = None # 保存 Ray 缓存的路径
    RAY_CORES_PER_WORKER: int = 1 # 每个 worker 分配的 CPU 核心数

    # 调试相关参数
    DEBUG: bool = False # 启用调试日志
    # 调试数据文件夹路径,默认为 None
    DEBUG_DATA_FOLDER: Optional[str] = None
    # 调试级别,范围从 0 到 2,2 表示记录所有信息
    DEBUG_LEVEL: int = 0
    
    # 计算属性,返回是否使用 CUDA
    @computed_field
    @property
    def CUDA(self) -> bool:
        return "cuda" in self.TORCH_DEVICE
    
    # 计算属性,返回模型数据类型
    @computed_field
    @property
    def MODEL_DTYPE(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cuda":
            return torch.bfloat16
        else:
            return torch.float32
    
    # 计算属性,返回用于转换的数据类型
    @computed_field
    @property
    def TEXIFY_DTYPE(self) -> torch.dtype:
        return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16
    
    # 类配置
    class Config:
        # 从环境文件中查找 local.env 文件
        env_file = find_dotenv("local.env")
        # 额外配置,忽略错误
        extra = "ignore"
# 创建一个 Settings 对象实例
settings = Settings()

.\marker\scripts\verify_benchmark_scores.py

代码语言:javascript
复制
# 导入 json 模块和 argparse 模块
import json
import argparse

# 验证分数的函数,接收一个文件路径作为参数
def verify_scores(file_path):
    # 打开文件并加载 JSON 数据
    with open(file_path, 'r') as file:
        data = json.load(file)

    # 获取 multicolcnn.pdf 文件的分数
    multicolcnn_score = data["marker"]["files"]["multicolcnn.pdf"]["score"]
    # 获取 switch_trans.pdf 文件的分数
    switch_trans_score = data["marker"]["files"]["switch_trans.pdf"]["score"]

    # 如果其中一个分数小于等于 0.4,则抛出 ValueError 异常
    if multicolcnn_score <= 0.4 or switch_trans_score <= 0.4:
        raise ValueError("One or more scores are below the required threshold of 0.4")

# 如果当前脚本被直接执行
if __name__ == "__main__":
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(description="Verify benchmark scores")
    # 添加一个参数,指定文件路径,类型为字符串
    parser.add_argument("file_path", type=str, help="Path to the json file")
    # 解析命令行参数
    args = parser.parse_args()
    # 调用 verify_scores 函数,传入文件路径参数
    verify_scores(args.file_path)
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2024-03-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • .\marker\marker\models.py
  • .\marker\marker\ocr\page.py
  • .\marker\marker\ocr\utils.py
  • .\marker\marker\ordering.py
  • .\marker\marker\postprocessors\editor.py
  • .\marker\marker\postprocessors\t5.py
  • .\marker\marker\schema.py
  • .\marker\marker\segmentation.py
  • .\marker\marker\settings.py
  • .\marker\scripts\verify_benchmark_scores.py
相关产品与服务
AI 应用产品
文字识别(Optical Character Recognition,OCR)基于腾讯优图实验室的深度学习技术,将图片上的文字内容,智能识别成为可编辑的文本。OCR 支持身份证、名片等卡证类和票据类的印刷体识别,也支持运单等手写体识别,支持提供定制化服务,可以有效地代替人工录入信息。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档