BERT架构
Bidirectional Encoder Representations from Transformers
概述
BERT(Bidirectional Encoder Representations from Transformers)是Google于2018年提出的预训练语言模型, 它通过双向Transformer编码器学习语言表示,在多项NLP任务上取得了突破性进展。
核心创新:利用掩码语言模型(MLM)和下一句预测(NSP)两个预训练任务,实现真正的双向上下文理解。
模型架构
整体结构
BERT仅使用Transformer的编码器部分,是一个深层双向模型:
| 模型 | 层数 | 隐藏维度 | 注意力头 | 参数量 |
|---|---|---|---|---|
| BERT-Base | 12 | 768 | 12 | 110M |
| BERT-Large | 24 | 1024 | 16 | 340M |
输入表示
BERT的输入由三部分嵌入组成:
- Token Embeddings:WordPiece分词后的词嵌入
- Segment Embeddings:区分句子对中的两个句子(EA/EB)
- Position Embeddings:位置信息嵌入(可学习)
Input = Token Embeddings + Segment Embeddings + Position Embeddings
特殊标记
[CLS]:序列起始标记,其最终隐藏状态用于分类任务[SEP]:句子分隔标记,区分不同的句子[MASK]:掩码标记,用于MLM预训练
预训练任务
掩码语言模型 (MLM)
随机遮蔽输入序列中15%的token,然后预测被遮蔽的词:
- 80%概率替换为
[MASK] - 10%概率替换为随机词
- 10%概率保持不变
下一句预测 (NSP)
判断两个句子是否为连续句子,二分类任务:
- 正样本:实际连续的两个句子
- 负样本:随机配对的两个句子
代码示例
使用Transformers库
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# 文本编码
text = "Hello, this is a sample sentence for BERT."
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
# 获取BERT表示
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state # (batch, seq_len, hidden_dim)
pooler_output = outputs.pooler_output # (batch, hidden_dim) [CLS]表示
# 用于下游分类任务
classifier = BertForSequenceClassification.from_pretrained(
'bert-base-uncased',
num_labels=2
)微调示例
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
# 加载用于分类的BERT
model = BertForSequenceClassification.from_pretrained(
'bert-base-uncased',
num_labels=num_classes
)
# 训练参数
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=64,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
)
# 训练
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()应用场景
文本分类
情感分析、主题分类、意图识别等,使用[CLS]向量进行分类。
命名实体识别
序列标注任务,对每个token进行分类识别实体边界。
问答系统
预测答案的起始和结束位置,抽取式问答。
语义相似度
句子对相似度计算,用于信息检索和匹配。
与相关模型对比
| 模型 | 架构 | 方向 | 预训练任务 | 主要应用 |
|---|---|---|---|---|
| BERT | Encoder-only | 双向 | MLM + NSP | 理解任务 |
| GPT | Decoder-only | 单向 | CLM | 生成任务 |
| T5 | Enc-Dec | 双向+单向 | Span Corruption | 理解+生成 |
最佳实践
- 最大序列长度512,长文本需要截断或分片处理
- 微调时学习率通常较小(2e-5到5e-5)
- 使用AdamW优化器配合线性学习率衰减
- 分类任务使用[CLS]向量,序列标注使用所有token表示
- 考虑使用更大的预训练版本(如RoBERTa、DeBERTa)获得更好效果
参考资料
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al., 2018)
- Hugging Face Transformers Documentation
- Google Research BERT Repository
----