Python BERT 多任务学习

作者:追风剑情 发布于:2026-6-24 17:05 分类:AI

  多任务学习(multi-task learning)是指通过同时训练模型完成多个相关任务提高模型的泛化能力。例如,可以同时训练一个模型进行图像分类和文本生成,从而使模型能够更好地理解图像和文本之间的关系。下面的实例演示了在模型训练中使用预训练模型实现多任务学习的过程。在这个例子中,使用预训练的ResNet模型提取图像特征,使用预训练的BERT模型提取文本特征,然后将这些特征用于两个不同的任务:图像分类和文本分类。

示例

# ---------- 1. 开启离线模式(必须在所有 import 之前) ----------
import os
os.environ["TRANSFORMERS_OFFLINE"] = "1"   # 禁止联网下载

# ---------- 2. 定义模型路径常量 ----------
BERT_MODEL_PATH = "./my_bert_model"          # BERT 模型本地目录
RESNET_WEIGHT_PATH = "resnet50-0676ba61.pth" # ResNet 权重文件路径

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import BertTokenizer, BertModel

# ---------- 3. 数据集类 ----------
class MultimodalDataset(Dataset):
    def __init__(self, image_paths, texts, image_labels, text_labels, transform=None):
        self.image_paths = image_paths
        self.texts = texts
        self.image_labels = image_labels
        self.text_labels = text_labels
        self.transform = transform
        # 使用常量加载 BERT 分词器
        self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH)

    def __len__(self):
        return len(self.image_labels)

    def __getitem__(self, idx):
        # 图像处理
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # 文本处理
        text = self.texts[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        image_label = torch.tensor(self.image_labels[idx], dtype=torch.long)
        text_label = torch.tensor(self.text_labels[idx], dtype=torch.long)

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image_label': image_label,
            'text_label': text_label
        }

# ---------- 4. 模型类 ----------
class MultimodalModel(nn.Module):
    def __init__(self, num_classes_image=2, num_classes_text=2):
        super(MultimodalModel, self).__init__()
        # ----- 图像分支:从本地加载 ResNet50 权重(使用常量) -----
        resnet = models.resnet50(weights=None)
        state_dict = torch.load(RESNET_WEIGHT_PATH, map_location='cpu', weights_only=True)
        resnet.load_state_dict(state_dict, strict=False)
        self.image_model = nn.Sequential(*list(resnet.children())[:-1])
        self.image_classifier = nn.Linear(2048, num_classes_image)

        # ----- 文本分支:从本地目录加载 BERT 模型(使用常量) -----
        self.text_model = BertModel.from_pretrained(BERT_MODEL_PATH)
        self.text_classifier = nn.Linear(768, num_classes_text)

    def forward(self, image, input_ids, attention_mask):
        # 图像特征
        img_feat = self.image_model(image)
        img_feat = img_feat.view(img_feat.size(0), -1)
        image_output = self.image_classifier(img_feat)

        # 文本特征
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]   # [CLS] 向量
        text_output = self.text_classifier(text_feat)

        return image_output, text_output

# ---------- 5. 数据预处理 ----------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ---------- 6. 构造示例数据 ----------
image_paths = ['running_dog.png', 'sleeping_cat.png']   # 请替换为实际图片路径
texts = ['a dog running', 'a cat sleeping']
image_labels = [0, 1]   # 0: dog, 1: cat
text_labels = [0, 1]    # 0: 描述狗, 1: 描述猫

dataset = MultimodalDataset(image_paths, texts, image_labels, text_labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# ---------- 7. 初始化 ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultimodalModel(num_classes_image=2, num_classes_text=2).to(device)
criterion_image = nn.CrossEntropyLoss()
criterion_text = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ---------- 8. 训练 ----------
model.train()
for epoch in range(5):
    for batch in dataloader:
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        image_labels = batch['image_label'].to(device)
        text_labels = batch['text_label'].to(device)

        optimizer.zero_grad()
        image_outputs, text_outputs = model(images, input_ids, attention_mask)

        loss_image = criterion_image(image_outputs, image_labels)
        loss_text = criterion_text(text_outputs, text_labels)
        total_loss = loss_image + loss_text

        total_loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/5], Image Loss: {loss_image.item():.4f}, Text Loss: {loss_text.item():.4f}')

运行测试
11111.png

标签: AI

Powered by emlog  蜀ICP备18021003号-1   sitemap

川公网安备 51019002001593号