T5架构
Text-to-Text Transfer Transformer
概述
T5是Google于2019年提出的统一文本到文本框架,将所有NLP任务都转化为文本生成问题。 这种统一的任务表示方式使得同一个模型可以处理翻译、分类、问答、摘要等多种任务。
核心思想:通过任务前缀将所有NLP任务统一为文本到文本格式,实现模型和训练流程的统一。
文本到文本框架
任务统一表示
所有任务都转化为"输入文本 → 输出文本"的形式:
# 英德翻译
Input: "translate English to German: That is good."
Output: "Das ist gut."
# 文本分类
Input: "cola sentence: The course is jumping well."
Output: "acceptable"
# 摘要生成
Input: "summarize: [long article text]"
Output: "[summary text]"
# 问答
Input: "question: What is the capital of France? context: France is a country..."
Output: "Paris"
任务前缀示例
| 任务 | 前缀示例 |
|---|---|
| 翻译 | translate English to German: |
| 摘要 | summarize: |
| 问答 | question: ... context: |
| 分类 | cola sentence: / stsb sentence1: ... sentence2: |
| 回归 | stsb sentence1: ... sentence2: → "3.5" |
模型架构
标准Transformer编码器-解码器
T5使用原始Transformer的编码器-解码器架构,但做了一些简化:
- 移除了编码器-解码器注意力中的偏置
- 使用简化版的LayerNorm(无偏置,在外部应用)
- 使用相对位置编码而非绝对位置编码
| 模型 | 层数 | 隐藏维度 | FFN维度 | 注意力头 | 参数量 |
|---|---|---|---|---|---|
| T5-small | 6 | 512 | 1024 | 8 | 60M |
| T5-base | 12 | 768 | 3072 | 12 | 220M |
| T5-large | 24 | 1024 | 4096 | 16 | 770M |
| T5-3B | 24 | 1024 | 16384 | 32 | 3B |
| T5-11B | 24 | 1024 | 65536 | 128 | 11B |
预训练任务
Span Corruption
T5使用Span Corruption作为预训练任务,随机遮蔽连续的token序列:
- 随机遮蔽15%的token
- 遮蔽连续的span,平均长度为3
- 每个遮蔽span用唯一的哨兵token替代
- 目标是在输出中还原被遮蔽的内容
# 原始文本
Thank you for inviting me to your party last week.
# 遮蔽后输入
Thank you <X> me to your party <Y> week.
# 目标输出
<X> for inviting <Y> last <Z>
注意:目标输出也需要哨兵token来标记不同span的开始。
代码示例
基础使用
from transformers import T5Tokenizer, T5ForConditionalGeneration
# 加载模型和分词器
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')
# 翻译任务
input_text = "translate English to German: The house is wonderful."
input_ids = tokenizer(input_text, return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_length=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 输出: Das Haus ist wunderbar.
# 摘要任务
article = "summarize: New research shows that..."
inputs = tokenizer(article, return_tensors='pt', max_length=512, truncation=True)
outputs = model.generate(inputs.input_ids, max_length=150, num_beams=4)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))多任务微调
# 准备多任务数据
tasks = [
{"input": "translate English to German: Hello", "target": "Hallo"},
{"input": "classify sentiment: I love this!", "target": "positive"},
{"input": "summarize: Long article text...", "target": "Brief summary"},
]
# 统一训练
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir='./t5-multitask',
num_train_epochs=3,
per_device_train_batch_size=8,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()T5变体
| 变体 | 特点 | 适用场景 |
|---|---|---|
| T5 | 标准版本 | 通用NLP任务 |
| mT5 | 多语言版本,101种语言 | 多语言任务 |
| T5-v1_1 | 改进版,GELU激活,无预训练dropout | 更好性能 |
| Flan-T5 | 指令微调版本 | 零样本/少样本推理 |
性能对比
| 基准测试 | T5-base | T5-large | T5-11B |
|---|---|---|---|
| GLUE | 83.6 | 86.4 | 89.3 |
| SuperGLUE | - | - | 88.9 |
| SQuAD | 85.5 | 89.1 | 93.4 |
| WMT En-De | 27.0 | 28.4 | 29.8 |
最佳实践
- 使用任务前缀明确指定任务类型
- 多任务学习时混合不同任务的数据
- 生成时使用beam search获得更好结果
- 长序列任务考虑使用T5-large或更大版本
- 多语言任务使用mT5
- 零样本推理优先使用Flan-T5
参考资料
- Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer (Raffel et al., 2019)
- mT5: A massively multilingual pre-trained text-to-text transformer (Xue et al., 2020)
- Hugging Face T5 Documentation