Python RAG 多模态特征融合与微调
作者:追风剑情 发布于:2026-5-20 17:29 分类:AI
本示例演示了如何构建一个图文多模态分类模型,从数据预处理、双编码器设计、特征融合到选择性微调的全流程,并掌握了 PyTorch 的基本训练技巧以及解决实际环境问题的方法。
# pip install torch torchvision pillow requests
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.models import ResNet50_Weights
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import requests
import json
# ------------------ 1. 定义 Ollama 文本嵌入函数 ------------------
def get_ollama_embedding(text, model="nomic-embed-text"):
"""通过 Ollama API 获取文本向量"""
url = "http://localhost:11434/api/embeddings"
payload = {
"model": model,
"prompt": text
}
response = requests.post(url, json=payload)
if response.status_code == 200:
return response.json()["embedding"]
else:
raise Exception(f"Ollama API error: {response.text}")
# ------------------ 2. 多模态数据集 ------------------
# 定义多模态数据集类,继承自 PyTorch 的 Dataset 基类
class MultimodalDataset(Dataset):
# 构造函数:初始化数据集对象
def __init__(self, image_paths, texts, labels, transform=None):
# 存储所有图像文件的路径列表,例如 ['cat.jpg', 'dog.jpg']
self.image_paths = image_paths
# 存储所有文本字符串的列表,例如 ['一只猫', '一只狗']
self.texts = texts
# 存储所有类别标签的列表(整数),例如 [0, 1]
self.labels = labels
# 图像预处理操作(如 Resize、ToTensor、Normalize),默认为 None
self.transform = transform
# 返回数据集中的样本总数,供 DataLoader 使用以确定迭代长度
def __len__(self):
return len(self.labels)
# 根据索引 idx 获取一个样本(图像张量、文本嵌入向量、标签张量)
def __getitem__(self, idx):
# ---------- 图像处理部分 ----------
# 根据图像路径打开图像文件,并转换为 RGB 彩色模式(3通道)
image = Image.open(self.image_paths[idx]).convert("RGB")
# 如果定义了图像预处理流水线(transform),则应用到图像上
if self.transform:
image = self.transform(image) # 此时 image 已转为预处理后的张量
# ---------- 文本处理部分(通过 Ollama 获取嵌入向量)----------
# 获取该样本对应的原始文本字符串
text = self.texts[idx]
# 调用外部函数 get_ollama_embedding,向本地 Ollama 服务发送请求,
# 让嵌入模型(如 nomic-embed-text)将文本转换成固定维度的浮点数列表
text_embedding = get_ollama_embedding(text) # 返回 list of float,例如 [0.123, -0.456, ...]
# 将列表转换为 PyTorch 张量,数据类型为 float32,便于后续计算和拼接
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
# ---------- 标签处理 ----------
# 获取该样本的类别标签(整数),转换为 PyTorch 长整型张量
label = torch.tensor(self.labels[idx], dtype=torch.long)
# 返回一个元组,包含:预处理后的图像张量、文本嵌入向量、标签张量
return image, text_embedding, label
# ------------------ 3. 多模态模型 ------------------
# 定义多模态模型类,继承自 PyTorch 的神经网络模块基类 nn.Module
class MultimodalModel(nn.Module):
# 构造函数:初始化模型的各个子模块
def __init__(self, image_feature_dim=512, text_embed_dim=768, num_classes=10):
# 调用父类 nn.Module 的构造函数,完成必要的内部初始化
super().__init__()
# ---------- 图像分支 ----------
# 加载预训练的 ResNet50 模型(在 ImageNet 上训练过的权重)
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# 将 ResNet 的最后一层(全连接分类层)去掉,只保留前面的卷积层和池化层
# resnet.children() 返回 ResNet 的所有子模块(如卷积层、批归一化、ReLU、全连接等)
# list(...)[:-1] 去掉最后一个全连接层(原本输出 1000 类)
# nn.Sequential 按顺序封装这些层,输出特征图形状为 (batch, 2048, 1, 1)
self.image_model = nn.Sequential(*list(resnet.children())[:-1])
# 定义一个全连接层,将 ResNet 提取的 2048 维特征降维到 image_feature_dim(默认 512)
# 输入维度 2048,输出维度 image_feature_dim
self.image_fc = nn.Linear(2048, image_feature_dim)
# ---------- 文本分支 ----------
# 定义一个全连接层,对输入的文本嵌入向量做线性变换
# 输入维度 text_embed_dim(默认 768,例如 nomic-embed-text 的维度),输出维度相同
# 这个层是可选的,主要用于让文本特征在后续拼接前进行适应性的变换
self.text_fc = nn.Linear(text_embed_dim, text_embed_dim)
# ---------- 融合分类层 ----------
# 定义一个全连接层,将拼接后的图像特征和文本特征映射到最终的类别数 num_classes
# 输入维度 = image_feature_dim + text_embed_dim
# 输出维度 = num_classes(例如 2 分类输出 2 个 logit)
self.classifier = nn.Linear(image_feature_dim + text_embed_dim, num_classes)
# 前向传播函数:定义数据如何从输入到输出
def forward(self, image, text_embed):
# image: 输入的图像张量,形状为 (batch_size, 3, H, W)
# text_embed: 输入的文本嵌入向量,形状为 (batch_size, text_embed_dim)
# 将图像输入图像模型(ResNet 特征提取器)
# 输出形状:(batch_size, 2048, 1, 1) —— 高度和宽度都是 1
img_feat = self.image_model(image)
# 将 img_feat 展平:保留 batch_size 维度,其余所有维度合并为一维
# img_feat.size(0) 是 batch_size,-1 表示自动计算剩余维度总数(2048*1*1 = 2048)
# 展平后形状:(batch_size, 2048)
img_feat = img_feat.view(img_feat.size(0), -1)
# 通过全连接层降维,得到低维图像特征
# 输出形状:(batch_size, image_feature_dim) 例如 (batch, 512)
img_feat = self.image_fc(img_feat)
# 将文本嵌入向量通过文本全连接层(线性变换)
# 输出形状:(batch_size, text_embed_dim) 例如 (batch, 768)
text_feat = self.text_fc(text_embed)
# 将图像特征和文本特征在特征维度上拼接(dim=1 表示第二维,即特征维)
# 拼接后的形状:(batch_size, image_feature_dim + text_embed_dim)
combined = torch.cat([img_feat, text_feat], dim=1)
# 通过分类层得到每个类别的原始分数(logits)
# 输出形状:(batch_size, num_classes)
out = self.classifier(combined)
# 返回分类结果(通常后续会传给损失函数,如交叉熵)
return out
# ------------------ 4. 训练准备 ------------------
# 假设有图像路径列表、文本列表、标签列表(示例用假数据)
image_paths = ["dog.png", "cat.png"]
texts = ["a dog running", "a cat sleeping"]
labels = [0, 1]
# 将原始图片转换成模型可以接受的标准化张量
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = MultimodalDataset(image_paths, texts, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 模型、优化器、损失函数
model = MultimodalModel(image_feature_dim=256, text_embed_dim=768, num_classes=2)
# 只微调图像分支和分类层,冻结文本分支的参数(这里文本分支没有可训练参数,因为 embedding 来自外部)
# 创建一个空列表,用于存放需要更新(训练)的模型参数
params_to_update = []
# 遍历模型中的所有参数(name 是参数名称字符串,param 是参数张量)
for name, param in model.named_parameters():
# 如果参数名中包含 'image_model' 或 'image_fc' 或 'classifier' 这三个标识之一
if 'image_model' in name or 'image_fc' in name or 'classifier' in name:
# 将该参数的 requires_grad 属性设为 True,表示在反向传播时会计算梯度并更新这个参数
param.requires_grad = True
# 把这个参数加入到待更新的参数列表 params_to_update 中
params_to_update.append(param)
else:
# 对于其他参数(比如文本模型的参数),设置 requires_grad = False
# 这样它们就不会被优化器更新,即参数被“冻结”
param.requires_grad = False
# 定义优化器:使用 Adam 算法,只优化 params_to_update 列表中的参数
# 学习率 lr 设为 0.0001(即 1e-4)
optimizer = optim.Adam(params_to_update, lr=1e-4)
# 定义损失函数:交叉熵损失(CrossEntropyLoss),常用于多分类任务
# 它会自动计算预测结果与真实标签之间的差异,并作为优化的目标
criterion = nn.CrossEntropyLoss()
# ------------------ 5. 训练循环 ------------------
# 将模型设置为训练模式(启用 Dropout、BatchNorm 等层的训练行为)
model.train()
# 外层循环:训练 5 个轮次(epoch)
for epoch in range(5):
# 内层循环:遍历数据加载器,每次返回一个 batch
# batch_idx: batch 的索引(从 0 开始)
# images: 当前 batch 的图像数据(形状如 [batch_size, 3, 224, 224])
# text_embeds: 当前 batch 的文本嵌入向量(形状如 [batch_size, text_dim])
# labels: 当前 batch 的真实标签(形状如 [batch_size])
for batch_idx, (images, text_embeds, labels) in enumerate(dataloader):
# 将优化器中已有的梯度清零,防止梯度累积(如果不清零,梯度会累加)
optimizer.zero_grad()
# 前向传播:将图像和文本嵌入输入模型,得到预测输出(logits)
outputs = model(images, text_embeds)
# 计算损失:比较预测值 outputs 与真实标签 labels,得到标量损失值
loss = criterion(outputs, labels)
# 反向传播:根据损失自动计算每个需要梯度的参数的梯度
loss.backward()
# 参数更新:优化器利用计算出的梯度更新模型参数(即执行一步梯度下降)
optimizer.step()
# 打印当前训练进度和损失值
# epoch+1:将 0-based 转为 1-based 显示
# batch_idx+1:同上
# loss.item():将损失张量转换为 Python 浮点数(以 .4f 格式保留 4 位小数)
print(f'Epoch [{epoch+1}/5], Batch [{batch_idx+1}], Loss: {loss.item():.4f}')
#-------------------- 6. 用训练好的模型进行预测 --------------
model.eval()
with torch.no_grad():
# 假设你想预测第一张图片(dog.png)
image, text_embed, label = dataset[0] # 获取预处理后的数据
# 增加 batch 维度(模型需要 batch 维度)
image = image.unsqueeze(0) # (3,224,224) -> (1,3,224,224)
text_embed = text_embed.unsqueeze(0) # (768,) -> (1,768)
output = model(image, text_embed)
pred = torch.argmax(output, dim=1).item()
print(f"预测类别: {pred},实际标签: {label}")
#------------------- 7. 保存训练好的模型 --------------------
torch.save(model.state_dict(), "multimodal_model.pth")
标签: AI
日历
最新文章
随机文章
热门文章
分类
存档
- 2026年5月(14)
- 2026年4月(7)
- 2026年3月(15)
- 2026年2月(3)
- 2026年1月(6)
- 2025年12月(1)
- 2025年11月(1)
- 2025年9月(3)
- 2025年7月(4)
- 2025年6月(5)
- 2025年5月(1)
- 2025年4月(5)
- 2025年3月(4)
- 2025年2月(3)
- 2025年1月(1)
- 2024年12月(5)
- 2024年11月(5)
- 2024年10月(5)
- 2024年9月(3)
- 2024年8月(3)
- 2024年7月(11)
- 2024年6月(3)
- 2024年5月(9)
- 2024年4月(10)
- 2024年3月(11)
- 2024年2月(24)
- 2024年1月(12)
- 2023年12月(3)
- 2023年11月(9)
- 2023年10月(7)
- 2023年9月(2)
- 2023年8月(7)
- 2023年7月(9)
- 2023年6月(6)
- 2023年5月(7)
- 2023年4月(11)
- 2023年3月(6)
- 2023年2月(11)
- 2023年1月(8)
- 2022年12月(2)
- 2022年11月(4)
- 2022年10月(10)
- 2022年9月(2)
- 2022年8月(13)
- 2022年7月(7)
- 2022年6月(11)
- 2022年5月(18)
- 2022年4月(29)
- 2022年3月(5)
- 2022年2月(6)
- 2022年1月(8)
- 2021年12月(5)
- 2021年11月(3)
- 2021年10月(4)
- 2021年9月(9)
- 2021年8月(14)
- 2021年7月(8)
- 2021年6月(5)
- 2021年5月(2)
- 2021年4月(3)
- 2021年3月(7)
- 2021年2月(2)
- 2021年1月(8)
- 2020年12月(7)
- 2020年11月(2)
- 2020年10月(6)
- 2020年9月(9)
- 2020年8月(10)
- 2020年7月(9)
- 2020年6月(18)
- 2020年5月(4)
- 2020年4月(25)
- 2020年3月(38)
- 2020年1月(21)
- 2019年12月(13)
- 2019年11月(29)
- 2019年10月(44)
- 2019年9月(17)
- 2019年8月(18)
- 2019年7月(25)
- 2019年6月(25)
- 2019年5月(17)
- 2019年4月(10)
- 2019年3月(36)
- 2019年2月(35)
- 2019年1月(28)
- 2018年12月(30)
- 2018年11月(22)
- 2018年10月(4)
- 2018年9月(7)
- 2018年8月(13)
- 2018年7月(13)
- 2018年6月(6)
- 2018年5月(5)
- 2018年4月(13)
- 2018年3月(5)
- 2018年2月(3)
- 2018年1月(8)
- 2017年12月(35)
- 2017年11月(17)
- 2017年10月(16)
- 2017年9月(17)
- 2017年8月(20)
- 2017年7月(34)
- 2017年6月(17)
- 2017年5月(15)
- 2017年4月(32)
- 2017年3月(8)
- 2017年2月(2)
- 2017年1月(5)
- 2016年12月(14)
- 2016年11月(26)
- 2016年10月(12)
- 2016年9月(25)
- 2016年8月(32)
- 2016年7月(14)
- 2016年6月(21)
- 2016年5月(17)
- 2016年4月(13)
- 2016年3月(8)
- 2016年2月(8)
- 2016年1月(18)
- 2015年12月(13)
- 2015年11月(15)
- 2015年10月(12)
- 2015年9月(18)
- 2015年8月(21)
- 2015年7月(35)
- 2015年6月(13)
- 2015年5月(9)
- 2015年4月(4)
- 2015年3月(5)
- 2015年2月(4)
- 2015年1月(13)
- 2014年12月(7)
- 2014年11月(5)
- 2014年10月(4)
- 2014年9月(8)
- 2014年8月(16)
- 2014年7月(26)
- 2014年6月(22)
- 2014年5月(28)
- 2014年4月(15)
友情链接
- Unity官网
- Unity圣典
- Unity在线手册
- Unity中文手册(圣典)
- Unity官方中文论坛
- Unity游戏蛮牛用户文档
- Unity下载存档
- Unity引擎源码下载
- Unity服务
- Unity Ads
- wiki.unity3d
- Visual Studio Code官网
- SenseAR开发文档
- MSDN
- C# 参考
- C# 编程指南
- .NET Framework类库
- .NET 文档
- .NET 开发
- WPF官方文档
- uLua
- xLua
- SharpZipLib
- Protobuf-net
- Protobuf.js
- OpenSSL
- OPEN CASCADE
- JSON
- MessagePack
- C在线工具
- 游戏蛮牛
- GreenVPN
- 聚合数据
- 热云
- 融云
- 腾讯云
- 腾讯开放平台
- 腾讯游戏服务
- 腾讯游戏开发者平台
- 腾讯课堂
- 微信开放平台
- 腾讯实时音视频
- 腾讯即时通信IM
- 微信公众平台技术文档
- 白鹭引擎官网
- 白鹭引擎开放平台
- 白鹭引擎开发文档
- FairyGUI编辑器
- PureMVC-TypeScript
- 讯飞开放平台
- 亲加通讯云
- Cygwin
- Mono开发者联盟
- Scut游戏服务器引擎
- KBEngine游戏服务器引擎
- Photon游戏服务器引擎
- 码云
- SharpSvn
- 腾讯bugly
- 4399原创平台
- 开源中国
- Firebase
- Firebase-Admob-Unity
- google-services-unity
- Firebase SDK for Unity
- Google-Firebase-SDK
- AppsFlyer SDK
- android-repository
- CQASO
- Facebook开发者平台
- gradle下载
- GradleBuildTool下载
- Android Developers
- Google中国开发者
- AndroidDevTools
- Android社区
- Android开发工具
- Google Play Games Services
- Google商店
- Google APIs for Android
- 金钱豹VPN
- TouchSense SDK
- MakeHuman
- Online RSA Key Converter
- Windows UWP应用
- Visual Studio For Unity
- Open CASCADE Technology
- 慕课网
- 阿里云服务器ECS
- 在线免费文字转语音系统
- AI Studio
- 网云穿
- 百度网盘开放平台
- 迅捷画图
- 菜鸟工具
- [CSDN] 程序员研修院
- 华为人脸识别
- 百度AR导航导览SDK
- 海康威视官网
- 海康开放平台
- 海康SDK下载
- git download
- Open CASCADE
- CascadeStudio
- OpenClaw中文社区
- three.js manual
- SVG官方文档
交流QQ群
-
Flash游戏设计: 86184192
Unity游戏设计: 171855449
游戏设计订阅号







