首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >机器学习经典算法:伯努利朴素贝叶斯(Bernoulli Naive Bayes)原理、手动计算与Python/Java双代码实战

机器学习经典算法:伯努利朴素贝叶斯(Bernoulli Naive Bayes)原理、手动计算与Python/Java双代码实战

原创
作者头像
jack.yang
发布2026-03-29 14:09:09
发布2026-03-29 14:09:09
80
举报
文章被收录于专栏:大模型系列大模型系列

一句话答案:伯努利朴素贝叶斯专为二值特征设计,只关心词“是否出现”,不计频次。在短文本、关键词检测场景中,比多项式NB更精准!

如果你在搜索:

  • “伯努利朴素贝叶斯怎么算的?”
  • “Bernoulli NB 手动计算例子”
  • “垃圾邮件检测该用哪种朴素贝叶斯?”
  • “Python 和 Java 怎么实现伯努利朴素贝叶斯?”

那么,这篇文章就是为你写的——从0/1向量到分类决策,一步不跳


一、什么是伯努利朴素贝叶斯?它和多项式、高斯有何不同?

朴素贝叶斯家族三大成员,适用场景截然不同:

算法类型

特征类型

核心逻辑

最佳场景

多项式朴素贝叶斯

词频计数(1,2,3...)

“这个词出现了几次?”

长文本、新闻分类

伯努利朴素贝叶斯

二值特征(0/1)

“这个词有没有出现?”

短文本、关键词检测、垃圾邮件

高斯朴素贝叶斯

连续数值

“这个值符合什么分布?”

身高、温度等连续数据

🤔 为什么短文本要用伯努利?

  • 短文本(如短信、评论、标题)词频信息少,“免费”出现1次 vs 2次差别不大。
  • 关键词是否存在极具判别力:含“赢大奖”≈垃圾邮件,不含≈正常。
  • 伯努利NB聚焦存在性,避免被低频词干扰,更适合稀疏、短小的文本。

二、数学原理:伯努利分布 + 贝叶斯定理

我们要计算:

根据贝叶斯定理和独立性假设:

📐 词出现概率估计(带拉普拉斯平滑)


三、手工推演:一步步计算垃圾邮件分类(带完整数据)

📊 训练数据集(5封邮件,二值化处理)

邮件ID

内容(分词后)

二值向量(词汇表:["免费", "赢", "会议", "项目"])

类别

1

["免费", "赢"]

[1, 1, 0, 0]

垃圾 (1)

2

["免费"]

[1, 0, 0, 0]

垃圾 (1)

3

["会议", "安排"]

[0, 0, 1, 0]

正常 (0)

4

["项目", "进展"]

[0, 0, 0, 1]

正常 (0)

5

["免费", "会议"]

[1, 0, 1, 0]

垃圾 (1)

目标:预测新邮件 ["赢", "项目"] → 二值向量 [1, 1, 0, 1] 属于哪个类别?

🔢 步骤1:计算先验概率 (P(C))

  • 总邮件数 = 5
  • 垃圾邮件数 = 3 → (P(1) = 3/5 = 0.6)
  • 正常邮件数 = 2 → (P(0) = 2/5 = 0.4)

🔢 步骤2:计算每个词在各类中的出现概率(α=1)

垃圾邮件类 (1):共3封
  • “免费”:出现在邮件1,2,5 → 3封 → (P(免费=1|1) = (3+1)/(3+2) = 4/5 = 0.8)
  • “赢”:出现在邮件1 → 1封 → (P(赢=1|1) = (1+1)/5 = 0.4)
  • “会议”:出现在邮件5 → 1封 → (P(会议=1|1) = 2/5 = 0.4)
  • “项目”:0封 → (P(项目=1|1) = (0+1)/5 = 0.2)
正常邮件类 (0):共2封
  • “免费”:0封 → (P(免费=1|0) = (0+1)/(2+2) = 1/4 = 0.25)
  • “赢”:0封 → (P(赢=1|0) = 1/4 = 0.25)
  • “会议”:1封 → (P(会议=1|0) = (1+1)/4 = 0.5)
  • “项目”:1封 → (P(项目=1|0) = 2/4 = 0.5)

🔢 步骤3:计算新邮件 [1, 1, 0, 1] 的后验概率

新邮件特征:

  • “免费”=1, “赢”=1, “会议”=0, “项目”=1
垃圾邮件 (1):
正常邮件 (0):

结论:0.02304 > 0.00625 → 判定为 垃圾邮件 (1)

尽管“项目”是正常词,但“赢”的出现(在垃圾邮件中概率更高)起了决定性作用。


四、Python 实现(scikit-learn + 手写版)

✅ 方式1:使用 scikit-learn(推荐生产环境)

代码语言:javascript
复制
from sklearn.naive_bayes import BernoulliNB
from sklearn.feature_extraction.text import CountVectorizer

# 数据
texts = [
    "免费 赢",
    "免费",
    "会议 安排",
    "项目 进展",
    "免费 会议"
]
labels = [1, 1, 0, 0, 1]  # 1=垃圾, 0=正常

# 向量化(二值化)
vectorizer = CountVectorizer(binary=True)  # 关键:binary=True
X = vectorizer.fit_transform(texts)

# 训练
clf = BernoulliNB(alpha=1.0)
clf.fit(X, labels)

# 预测
new_text = vectorizer.transform(["赢 项目"])
pred = clf.predict(new_text)
print("预测结果:", "垃圾邮件" if pred[0] == 1 else "正常邮件")
# 输出: 垃圾邮件

✅ 方式2:手写核心逻辑

代码语言:javascript
复制
import numpy as np
from collections import defaultdict

class BernoulliNB:
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def fit(self, docs, labels):
        # 构建词汇表
        vocab = set(word for doc in docs for word in doc)
        self.vocab = {word: i for i, word in enumerate(vocab)}
        self.n_vocab = len(vocab)
        
        # 统计
        class_doc_count = defaultdict(int)
        word_doc_count = defaultdict(lambda: np.zeros(self.n_vocab))
        
        for doc, label in zip(docs, labels):
            class_doc_count[label] += 1
            unique_words = set(doc)
            for word in unique_words:
                if word in self.vocab:
                    word_doc_count[label][self.vocab[word]] += 1
        
        # 先验概率
        total_docs = len(labels)
        self.priors = {label: count / total_docs for label, count in class_doc_count.items()}
        
        # 条件概率(伯努利,带平滑)
        self.cond_prob = {}
        for label in class_doc_count:
            n_docs = class_doc_count[label]
            smoothed_count = word_doc_count[label] + self.alpha
            smoothed_total = n_docs + 2 * self.alpha
            self.cond_prob[label] = smoothed_count / smoothed_total
    
    def predict(self, docs):
        predictions = []
        for doc in docs:
            scores = {}
            unique_words = set(doc)
            for label in self.priors:
                log_score = np.log(self.priors[label])
                for word in self.vocab:
                    idx = self.vocab[word]
                    if word in unique_words:
                        log_score += np.log(self.cond_prob[label][idx])
                    else:
                        log_score += np.log(1 - self.cond_prob[label][idx])
                scores[label] = log_score
            predictions.append(max(scores, key=scores.get))
        return predictions

# 使用
docs = [["免费","赢"], ["免费"], ["会议","安排"], ["项目","进展"], ["免费","会议"]]
nb = BernoulliNB(alpha=1.0)
nb.fit(docs, [1,1,0,0,1])
print(nb.predict())  # 输出: [1]

五、Java 实现(纯手写,无第三方库)

代码语言:javascript
复制
import java.util.*;

public class BernoulliNaiveBayes {
    private Map<String, Integer> vocab = new HashMap<>();
    private Map<Integer, Double> priors = new HashMap<>();
    private Map<Integer, double[]> condProb = new HashMap<>();
    private int vocabSize;
    private double alpha;

    public BernoulliNaiveBayes(double alpha) {
        this.alpha = alpha;
    }

    public void fit(List<List<String>> docs, List<Integer> labels) {
        // 构建词汇表
        Set<String> wordSet = new HashSet<>();
        for (List<String> doc : docs) {
            wordSet.addAll(doc);
        }
        int idx = 0;
        for (String word : wordSet) {
            vocab.put(word, idx++);
        }
        vocabSize = vocab.size();

        // 统计每个类别的文档数
        Map<Integer, Integer> classDocCount = new HashMap<>();
        for (int label : labels) {
            classDocCount.put(label, classDocCount.getOrDefault(label, 0) + 1);
        }

        // 统计每个词在各类中出现的文档数
        Map<Integer, int[]> wordDocCount = new HashMap<>();
        for (int label : classDocCount.keySet()) {
            wordDocCount.put(label, new int[vocabSize]);
        }

        for (int i = 0; i < docs.size(); i++) {
            int label = labels.get(i);
            Set<String> uniqueWords = new HashSet<>(docs.get(i));
            for (String word : uniqueWords) {
                if (vocab.containsKey(word)) {
                    int wIdx = vocab.get(word);
                    wordDocCount.get(label)[wIdx]++;
                }
            }
        }

        // 先验概率
        int totalDocs = labels.size();
        for (int label : classDocCount.keySet()) {
            priors.put(label, (double) classDocCount.get(label) / totalDocs);
        }

        // 条件概率(伯努利平滑)
        for (int label : classDocCount.keySet()) {
            double[] prob = new double[vocabSize];
            int nDocs = classDocCount.get(label);
            for (int i = 0; i < vocabSize; i++) {
                prob[i] = (wordDocCount.get(label)[i] + alpha) / (nDocs + 2 * alpha);
            }
            condProb.put(label, prob);
        }
    }

    public int predict(List<String> doc) {
        Set<String> uniqueWords = new HashSet<>(doc);
        Map<Integer, Double> scores = new HashMap<>();
        for (int label : priors.keySet()) {
            double logScore = Math.log(priors.get(label));
            for (String word : vocab.keySet()) {
                int wIdx = vocab.get(word);
                if (uniqueWords.contains(word)) {
                    logScore += Math.log(condProb.get(label)[wIdx]);
                } else {
                    logScore += Math.log(1 - condProb.get(label)[wIdx]);
                }
            }
            scores.put(label, logScore);
        }
        return Collections.max(scores.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    // 测试
    public static void main(String[] args) {
        List<List<String>> docs = Arrays.asList(
            Arrays.asList("免费", "赢"),
            Arrays.asList("免费"),
            Arrays.asList("会议", "安排"),
            Arrays.asList("项目", "进展"),
            Arrays.asList("免费", "会议")
        );
        List<Integer> labels = Arrays.asList(1, 1, 0, 0, 1);

        BernoulliNaiveBayes nb = new BernoulliNaiveBayes(1.0);
        nb.fit(docs, labels);

        List<String> testDoc = Arrays.asList("赢", "项目");
        int prediction = nb.predict(testDoc);
        System.out.println("预测结果: " + (prediction == 1 ? "垃圾邮件" : "正常邮件"));
        // 输出: 垃圾邮件
    }
}

六、优缺点 & 适用场景总结

优点

缺点

✅ 专为短文本、关键词存在性优化

❌ 忽略词频信息(长文本可能损失信号)

✅ 对稀疏数据鲁棒

❌ 词汇表需预定义,OOV词无法处理

✅ 计算高效,内存占用小

❌ 假设词独立,忽略上下文

✅ 在垃圾邮件、恶意URL检测中表现优异

❌ 平滑参数需调优

🎯 最佳应用场景:

  • 垃圾邮件/钓鱼邮件检测
  • 恶意软件URL关键词分析
  • 短评论情感判断(如“差”“烂”是否存在)
  • 任何基于关键词存在性的二分类任务

七、朴素贝叶斯全家桶对比总结

算法

数据类型

关键参数

何时选用

高斯NB

连续数值

均值μ、标准差σ

身高、温度、成绩等连续特征

多项式NB

词频计数

词频、拉普拉斯平滑

新闻分类、长文本情感分析

伯努利NB

二值特征(0/1)

文档出现次数、平滑

垃圾邮件、短文本、关键词检测

💡 经验法则

  • 文本很长 → 用多项式NB
  • 文本很短 → 用伯努利NB
  • 不是文本 → 用高斯NB

✅ 结语

伯努利朴素贝叶斯用最简单的0/1逻辑,解决了最关键的关键词识别问题。它不关心“说了多少”,只关心“有没有说”——而这,往往是判断意图的第一道防线。

记住:在AI的世界里,有时存在即意义

现在,你已经能:

  • 手动计算伯努利朴素贝叶斯分类结果
  • 用Python或Java从零实现它
  • 根据任务特点选择最合适的朴素贝叶斯变体

关键词:机器学习、伯努利朴素贝叶斯、Bernoulli Naive Bayes、二值特征、垃圾邮件检测、短文本分类、拉普拉斯平滑、Python 伯努利朴素贝叶斯、Java 伯努利朴素贝叶斯、手动计算

相关链接

  • 📂 大模型技术专栏: 欢迎您到访 「大模型系列」。 在这个由参数驱动、以数据为燃料的新智能时代,大语言模型(LLM)已不再是实验室里的前沿概念,而是正在重塑搜索、办公、编程、教育、医疗乃至整个数字世界的底层引擎。从 GPT 到 Llama,从 Claude 到 Qwen,从推理到多模态,大模型正以前所未有的速度进化——它们既是工具,也是平台,更可能是下一代人机交互的“操作系统”。 本系列将带你:
    • 🔍 深入原理:从 Transformer 架构、注意力机制到训练范式(预训练、微调、RLHF);
    • ⚙️ 动手实践:本地部署、模型微调、RAG 构建、Agent 设计等实战指南;
    • 🧠 理解边界:幻觉、偏见、安全对齐、推理瓶颈与当前能力天花板;
    • 🌍 洞察趋势:开源 vs 闭源、端侧部署、MoE 架构、世界模型与 AGI 路径;
    • 💼 落地应用:如何在企业中安全、高效、低成本地集成大模型能力。

    无论你是想写代码调用 API 的开发者,设计 AI 产品的 PM,评估技术路线的管理者,还是单纯好奇智能本质的思考者,这里都有值得你驻足的内容。 不追 hype,只讲逻辑;不谈玄学,专注可复现的认知。 让我们一起,在这场百年一遇的智能革命中,看得更清,走得更稳 https://cloud.tencent.com/developer/column/107314

  • 👤 关于作者专注技术落地,深耕硬核干货 本文作者致力于大模型相关技术的生态建设与实战落地。不同于浅层的概念科普,作者坚持 “手算 + 代码” 的深度分享模式,主张通过手动推演理解算法本质,结合生产级代码验证理论可行性。 请关注我主页:https://cloud.tencent.com/developer/user/2276240

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、什么是伯努利朴素贝叶斯?它和多项式、高斯有何不同?
    • 🤔 为什么短文本要用伯努利?
  • 二、数学原理:伯努利分布 + 贝叶斯定理
    • 📐 词出现概率估计(带拉普拉斯平滑)
  • 三、手工推演:一步步计算垃圾邮件分类(带完整数据)
    • 📊 训练数据集(5封邮件,二值化处理)
    • 🔢 步骤1:计算先验概率 (P(C))
    • 🔢 步骤2:计算每个词在各类中的出现概率(α=1)
      • 垃圾邮件类 (1):共3封
      • 正常邮件类 (0):共2封
    • 🔢 步骤3:计算新邮件 [1, 1, 0, 1] 的后验概率
      • 垃圾邮件 (1):
      • 正常邮件 (0):
  • 四、Python 实现(scikit-learn + 手写版)
    • ✅ 方式1:使用 scikit-learn(推荐生产环境)
    • ✅ 方式2:手写核心逻辑
  • 五、Java 实现(纯手写,无第三方库)
  • 六、优缺点 & 适用场景总结
    • 🎯 最佳应用场景:
  • 七、朴素贝叶斯全家桶对比总结
  • ✅ 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档