Python 大模型微调
作者:追风剑情 发布于:2026-6-10 16:13 分类:AI
一、什么情况下需要做模型微调?
当你有一个预训练模型(如 BERT、ResNet、Whisper 等),并且希望将其应用到一个与预训练任务不完全相同的下游任务时,就需要微调。
二、什么情况下需要做全参数微调?
当你数据充足且需要最高性能,或者任务与预训练任务差异很大时,选择全参数微调。通常需要数千到数万条标注数据,以防止过拟合。数据越少,过拟合风险越高。
三、什么情况下只需要做全连接层微调?
如果预训练模型的特征提取器已经能够很好地表征你的数据。或者数据非常少(几十到几百条)、希望节省资源、或快速验证特征质量时,优先选择仅微调全连接层(线性探测)。
示例:
import os
# 国内镜像
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
from sklearn.metrics import accuracy_score
import numpy as np
# ---------- 1. 准备小数据集(中文评论情感) ----------
# 训练数据集
train_texts = [
# 负面 10 条
"太难吃了,服务态度也很差。",
"等了40分钟还没上菜,太失望了。",
"价格贵得离谱,完全不值。",
"菜品不新鲜,吃完肚子不舒服。",
"环境嘈杂,根本无法安静吃饭。",
"点错了菜还不给换,差评。",
"分量太少,两个人不够吃。",
"餐具不干净,有油渍。",
"外卖包装破损,汤洒了一袋子。",
"口味一般,不会再来第二次。",
# 正面 10 条
"味道很棒,服务也很周到。",
"环境优雅,适合约会。",
"价格实惠,分量足。",
"上菜速度快,味道正宗。",
"服务员很热情,主动倒水。",
"菜品精致,拍照很好看。",
"交通方便,位置好找。",
"团购超值,性价比高。",
"孩子很喜欢,说下次还要来。",
"整体体验非常好,推荐给大家。",
]
train_labels = [0,0,0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1,1,1]
# 验证数据集
val_texts = [
"味道还不错,就是价格稍贵。", # 偏正面
"上菜太慢了,等了半小时。", # 负面
"服务员很贴心,送了小礼物。", # 正面
"菜品太咸了,没法吃。", # 负面
]
val_labels = [1, 0, 1, 0]
# ---------- 2. 加载预训练模型和分词器 ----------
model_name = "distilbert-base-multilingual-cased" # 支持中文的小模型
# 加载一个与预训练模型配套的分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 给模型添加一个分类头。
# num_labels: 标签类别数
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# ---------- 3. 自定义Dataset类 ----------
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_len=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
encoding = self.tokenizer(
text, # 要编码的原始文本字符串
truncation=True, # 是否截断超长序列
padding='max_length', # 填充方式:填充到固定长度
max_length=self.max_len, # 最大序列长度(截断和填充的目标长度)
return_tensors='pt' # 返回 PyTorch 张量格式
)
return {
# 每个 token 在词汇表中的数字 ID
'input_ids': encoding['input_ids'].flatten(),
# 掩码列表 1:真实token;0:填充token
'attention_mask': encoding['attention_mask'].flatten(),
# 转成张量
'label': torch.tensor(label, dtype=torch.long)
}
# 训练数据集
train_dataset = TextDataset(train_texts, train_labels, tokenizer)
# 验证数据集
val_dataset = TextDataset(val_texts, val_labels, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)
# ---------- 4. 设置优化器和训练参数 ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 将模型加载进CPU或GPU
model.to(device)
# 全参数训练
# 学习率通常比从头训练小10~100倍
optimizer = AdamW(model.parameters(), lr=2e-5)
# ---仅训练分类层(全链接层)---
# 冻结除分类头外的所有参数
#for name, param in model.named_parameters():
# if 'classifier' in name or 'pre_classifier' in name:
# param.requires_grad = True
# else:
# param.requires_grad = False
# 只将需要更新的参数传给优化器
#optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
# --- end ---
# 训练次数
epochs = 10
# ---------- 5. 训练循环(微调核心) ----------
for epoch in range(epochs):
# 启动训练模式
model.train()
total_loss = 0
for batch in train_loader:
# 数据与模型必须在同一设备中进行计算(例如:GPU)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
# 将模型所有可训练参数的梯度清零
optimizer.zero_grad()
# 执行前向传播
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
# 获取损失值
loss = outputs.loss
# 执行反向传播
loss.backward()
# 根据计算出的梯度更新模型参数
optimizer.step()
# 保存累计损失
total_loss += loss.item()
# 模型验证/评估
# 启动评估模式
model.eval()
# preds: 预测标签列表
# true_labels: 真实标签列表
preds, true_labels = [], []
# 临时禁用梯度计算
with torch.no_grad():
for batch in val_loader:
# 数据与模型必须在同一设备中进行计算(例如:GPU)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['label'].to(device)
# 前向传播
outputs = model(input_ids, attention_mask=attention_mask)
# 获取模型计算出的原始分数
logits = outputs.logits
# 返回分数最高的预测类别
pred = torch.argmax(logits, dim=1)
# pred.cpu().numpy() 将预测张量从 GPU 移到 CPU,并转换为 NumPy 数组
# preds.extend() 追加到总列表末尾
preds.extend(pred.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
acc = accuracy_score(true_labels, preds)
print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_loader):.4f} | Val Acc: {acc:.4f}")
# ---------- 6. 测试一条新评论 ----------
test_text = "这个产品太垃圾了,千万别买。"
inputs = tokenizer(test_text, return_tensors='pt', truncation=True, padding=True).to(device)
model.eval()
with torch.no_grad():
logits = model(**inputs).logits
pred = torch.argmax(logits, dim=1).item()
print(f"\n测试评论: {test_text}")
print(f"预测结果: {'正面' if pred == 1 else '负面'}")
标签: AI
日历
最新文章
随机文章
热门文章
分类
存档
- 2026年6月(5)
- 2026年5月(29)
- 2026年4月(7)
- 2026年3月(15)
- 2026年2月(3)
- 2026年1月(6)
- 2025年12月(1)
- 2025年11月(1)
- 2025年9月(3)
- 2025年7月(4)
- 2025年6月(5)
- 2025年5月(1)
- 2025年4月(5)
- 2025年3月(4)
- 2025年2月(3)
- 2025年1月(1)
- 2024年12月(5)
- 2024年11月(5)
- 2024年10月(5)
- 2024年9月(3)
- 2024年8月(3)
- 2024年7月(11)
- 2024年6月(3)
- 2024年5月(9)
- 2024年4月(10)
- 2024年3月(11)
- 2024年2月(24)
- 2024年1月(12)
- 2023年12月(3)
- 2023年11月(9)
- 2023年10月(7)
- 2023年9月(2)
- 2023年8月(7)
- 2023年7月(9)
- 2023年6月(6)
- 2023年5月(7)
- 2023年4月(11)
- 2023年3月(6)
- 2023年2月(11)
- 2023年1月(8)
- 2022年12月(2)
- 2022年11月(4)
- 2022年10月(10)
- 2022年9月(2)
- 2022年8月(13)
- 2022年7月(7)
- 2022年6月(11)
- 2022年5月(18)
- 2022年4月(29)
- 2022年3月(5)
- 2022年2月(6)
- 2022年1月(8)
- 2021年12月(5)
- 2021年11月(3)
- 2021年10月(4)
- 2021年9月(9)
- 2021年8月(14)
- 2021年7月(8)
- 2021年6月(5)
- 2021年5月(2)
- 2021年4月(3)
- 2021年3月(7)
- 2021年2月(2)
- 2021年1月(8)
- 2020年12月(7)
- 2020年11月(2)
- 2020年10月(6)
- 2020年9月(9)
- 2020年8月(10)
- 2020年7月(9)
- 2020年6月(18)
- 2020年5月(4)
- 2020年4月(25)
- 2020年3月(38)
- 2020年1月(21)
- 2019年12月(13)
- 2019年11月(29)
- 2019年10月(44)
- 2019年9月(17)
- 2019年8月(18)
- 2019年7月(25)
- 2019年6月(25)
- 2019年5月(17)
- 2019年4月(10)
- 2019年3月(36)
- 2019年2月(35)
- 2019年1月(28)
- 2018年12月(30)
- 2018年11月(22)
- 2018年10月(4)
- 2018年9月(7)
- 2018年8月(13)
- 2018年7月(13)
- 2018年6月(6)
- 2018年5月(5)
- 2018年4月(13)
- 2018年3月(5)
- 2018年2月(3)
- 2018年1月(8)
- 2017年12月(35)
- 2017年11月(17)
- 2017年10月(16)
- 2017年9月(17)
- 2017年8月(20)
- 2017年7月(34)
- 2017年6月(17)
- 2017年5月(15)
- 2017年4月(32)
- 2017年3月(8)
- 2017年2月(2)
- 2017年1月(5)
- 2016年12月(14)
- 2016年11月(26)
- 2016年10月(12)
- 2016年9月(25)
- 2016年8月(32)
- 2016年7月(14)
- 2016年6月(21)
- 2016年5月(17)
- 2016年4月(13)
- 2016年3月(8)
- 2016年2月(8)
- 2016年1月(18)
- 2015年12月(13)
- 2015年11月(15)
- 2015年10月(12)
- 2015年9月(18)
- 2015年8月(21)
- 2015年7月(35)
- 2015年6月(13)
- 2015年5月(9)
- 2015年4月(4)
- 2015年3月(5)
- 2015年2月(4)
- 2015年1月(13)
- 2014年12月(7)
- 2014年11月(5)
- 2014年10月(4)
- 2014年9月(8)
- 2014年8月(16)
- 2014年7月(26)
- 2014年6月(22)
- 2014年5月(28)
- 2014年4月(15)
友情链接
- Unity官网
- Unity圣典
- Unity在线手册
- Unity中文手册(圣典)
- Unity官方中文论坛
- Unity游戏蛮牛用户文档
- Unity下载存档
- Unity引擎源码下载
- Unity服务
- Unity Ads
- wiki.unity3d
- Visual Studio Code官网
- SenseAR开发文档
- MSDN
- C# 参考
- C# 编程指南
- .NET Framework类库
- .NET 文档
- .NET 开发
- WPF官方文档
- uLua
- xLua
- SharpZipLib
- Protobuf-net
- Protobuf.js
- OpenSSL
- OPEN CASCADE
- JSON
- MessagePack
- C在线工具
- 游戏蛮牛
- GreenVPN
- 聚合数据
- 热云
- 融云
- 腾讯云
- 腾讯开放平台
- 腾讯游戏服务
- 腾讯游戏开发者平台
- 腾讯课堂
- 微信开放平台
- 腾讯实时音视频
- 腾讯即时通信IM
- 微信公众平台技术文档
- 白鹭引擎官网
- 白鹭引擎开放平台
- 白鹭引擎开发文档
- FairyGUI编辑器
- PureMVC-TypeScript
- 讯飞开放平台
- 亲加通讯云
- Cygwin
- Mono开发者联盟
- Scut游戏服务器引擎
- KBEngine游戏服务器引擎
- Photon游戏服务器引擎
- 码云
- SharpSvn
- 腾讯bugly
- 4399原创平台
- 开源中国
- Firebase
- Firebase-Admob-Unity
- google-services-unity
- Firebase SDK for Unity
- Google-Firebase-SDK
- AppsFlyer SDK
- android-repository
- CQASO
- Facebook开发者平台
- gradle下载
- GradleBuildTool下载
- Android Developers
- Google中国开发者
- AndroidDevTools
- Android社区
- Android开发工具
- Google Play Games Services
- Google商店
- Google APIs for Android
- 金钱豹VPN
- TouchSense SDK
- MakeHuman
- Online RSA Key Converter
- Windows UWP应用
- Visual Studio For Unity
- Open CASCADE Technology
- 慕课网
- 阿里云服务器ECS
- 在线免费文字转语音系统
- AI Studio
- 网云穿
- 百度网盘开放平台
- 迅捷画图
- 菜鸟工具
- [CSDN] 程序员研修院
- 华为人脸识别
- 百度AR导航导览SDK
- 海康威视官网
- 海康开放平台
- 海康SDK下载
- git download
- Open CASCADE
- CascadeStudio
- OpenClaw中文社区
- three.js manual
- SVG官方文档
交流QQ群
-
Flash游戏设计: 86184192
Unity游戏设计: 171855449
游戏设计订阅号








