T5架构

Text-to-Text Transfer Transformer

概述

T5是Google于2019年提出的统一文本到文本框架,将所有NLP任务都转化为文本生成问题。 这种统一的任务表示方式使得同一个模型可以处理翻译、分类、问答、摘要等多种任务。

核心思想:通过任务前缀将所有NLP任务统一为文本到文本格式,实现模型和训练流程的统一。

文本到文本框架

任务统一表示

所有任务都转化为"输入文本 → 输出文本"的形式:

# 英德翻译

Input: "translate English to German: That is good."

Output: "Das ist gut."


# 文本分类

Input: "cola sentence: The course is jumping well."

Output: "acceptable"


# 摘要生成

Input: "summarize: [long article text]"

Output: "[summary text]"


# 问答

Input: "question: What is the capital of France? context: France is a country..."

Output: "Paris"

任务前缀示例

任务前缀示例
翻译translate English to German:
摘要summarize:
问答question: ... context:
分类cola sentence: / stsb sentence1: ... sentence2:
回归stsb sentence1: ... sentence2: → "3.5"

模型架构

标准Transformer编码器-解码器

T5使用原始Transformer的编码器-解码器架构,但做了一些简化:

  • 移除了编码器-解码器注意力中的偏置
  • 使用简化版的LayerNorm(无偏置,在外部应用)
  • 使用相对位置编码而非绝对位置编码
模型层数隐藏维度FFN维度注意力头参数量
T5-small65121024860M
T5-base12768307212220M
T5-large241024409616770M
T5-3B24102416384323B
T5-11B2410246553612811B

预训练任务

Span Corruption

T5使用Span Corruption作为预训练任务,随机遮蔽连续的token序列:

  • 随机遮蔽15%的token
  • 遮蔽连续的span,平均长度为3
  • 每个遮蔽span用唯一的哨兵token替代
  • 目标是在输出中还原被遮蔽的内容

# 原始文本

Thank you for inviting me to your party last week.


# 遮蔽后输入

Thank you <X> me to your party <Y> week.


# 目标输出

<X> for inviting <Y> last <Z>

注意:目标输出也需要哨兵token来标记不同span的开始。

代码示例

基础使用

from transformers import T5Tokenizer, T5ForConditionalGeneration

# 加载模型和分词器
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

# 翻译任务
input_text = "translate English to German: The house is wonderful."
input_ids = tokenizer(input_text, return_tensors='pt').input_ids

outputs = model.generate(input_ids, max_length=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 输出: Das Haus ist wunderbar.

# 摘要任务
article = "summarize: New research shows that..."
inputs = tokenizer(article, return_tensors='pt', max_length=512, truncation=True)
outputs = model.generate(inputs.input_ids, max_length=150, num_beams=4)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

多任务微调

# 准备多任务数据
tasks = [
    {"input": "translate English to German: Hello", "target": "Hallo"},
    {"input": "classify sentiment: I love this!", "target": "positive"},
    {"input": "summarize: Long article text...", "target": "Brief summary"},
]

# 统一训练
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./t5-multitask',
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

T5变体

变体特点适用场景
T5标准版本通用NLP任务
mT5多语言版本,101种语言多语言任务
T5-v1_1改进版,GELU激活,无预训练dropout更好性能
Flan-T5指令微调版本零样本/少样本推理

性能对比

基准测试T5-baseT5-largeT5-11B
GLUE83.686.489.3
SuperGLUE--88.9
SQuAD85.589.193.4
WMT En-De27.028.429.8

最佳实践

  • 使用任务前缀明确指定任务类型
  • 多任务学习时混合不同任务的数据
  • 生成时使用beam search获得更好结果
  • 长序列任务考虑使用T5-large或更大版本
  • 多语言任务使用mT5
  • 零样本推理优先使用Flan-T5

参考资料

  • Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer (Raffel et al., 2019)
  • mT5: A massively multilingual pre-trained text-to-text transformer (Xue et al., 2020)
  • Hugging Face T5 Documentation
----