Python ResNet 训练二分类模型
作者:追风剑情 发布于:2026-5-21 18:11 分类:AI
用 ResNet 训练一个图像二分类模型。
训练脚本:train_binary.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader, random_split
import os
# ------------------- 1. 超参数 -------------------
# 每次同时处理几张图片
BATCH_SIZE = 8
# 训练次数
EPOCHS = 20
# 学习率
LEARNING_RATE = 0.001
# 当前目录,下面有 circuit_boards 和 non_circuit_boards 两个文件夹
DATA_DIR = "./data"
# 二分类
NUM_CLASSES = 2
# ------------------- 2. 数据预处理 -------------------
train_transforms = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# ------------------- 3. 加载数据集 -------------------
# 假设文件夹结构:
# circuit_boards/ -> 正样本,自动分配标签 0(电路板)
# non_circuit_boards/ -> 负样本,自动分配标签 1(非电路板)
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transforms)
# 检查类别映射
# 预期 ['circuit_boards', 'non_circuit_boards']
print("类别映射:", full_dataset.classes)
# 划分训练集和验证集(80% 训练,20% 验证)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
# 为验证集使用不同的预处理(无数据增强)
val_dataset.dataset.transform = val_transforms
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(f"训练样本数: {len(train_dataset)}, 验证样本数: {len(val_dataset)}")
# ------------------- 4. 加载预训练模型(本地权重) -------------------
model = models.resnet50(weights=None)
state_dict = torch.load("resnet50-0676ba61.pth", map_location='cpu')
model.load_state_dict(state_dict)
# 替换最后的全连接层为二分类头
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)
# 冻结前面的所有层(可选,也可以全部微调,这里演示只训练新头部)
for param in model.parameters():
param.requires_grad = False
# 只让最后的 fc 层可训练
for param in model.fc.parameters():
param.requires_grad = True
# 判断模型是加载进GPU还是CPU运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# ------------------- 5. 损失函数和优化器 -------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)
# ------------------- 6. 训练与验证 -------------------
best_acc = 0.0
for epoch in range(EPOCHS):
# 训练阶段
model.train()
running_loss = 0.0
correct = 0
total = 0
for inputs, labels in train_loader:
# 将输入图像和标签数据从CPU内存移动到指定的计算设备(GPU或CPU)上
# .to(device) 确保后续计算在同一个设备上进行,避免跨设备数据传输
inputs, labels = inputs.to(device), labels.to(device)
# 将优化器中所有待更新参数的梯度清零
# 因为梯度在每次反向传播后会累积,不手动清零会导致梯度叠加,影响参数更新
optimizer.zero_grad()
# 前向传播:将输入图像送入模型,得到原始输出logits(尚未经过softmax)
# 对于分类任务,输出形状通常为 [batch_size, num_classes]
outputs = model(inputs)
# 计算损失:使用损失函数(例如交叉熵)比较模型预测值outputs与真实标签labels
# 得到的loss是一个标量张量,代表当前batch的平均损失
loss = criterion(outputs, labels)
# 反向传播:根据损失值自动计算模型中每个可训练参数的梯度
# 梯度会累积到各参数的 .grad 属性中,用于后续优化器更新
loss.backward()
# 优化器更新参数:根据计算出的梯度和预设的学习率等超参数,更新模型参数
# 这是模型学习的关键步骤,参数沿梯度下降方向调整
optimizer.step()
# 累加当前 batch 的总损失(非平均损失)
# loss.item() 是当前 batch 的平均损失(标量),乘以 batch 大小 inputs.size(0) 得到该 batch 的总损失
# 这样便于后面计算整个 epoch 的平均损失(总损失 / 总样本数)
running_loss += loss.item() * inputs.size(0)
# 对模型输出的 logits 沿着类别维度(dim=1)取最大值
# torch.max(outputs, 1) 返回一个元组:(最大值, 最大值所在的索引)
# 使用下划线 _ 忽略最大值本身,只保留索引,赋值给 predicted
# predicted 形状为 [batch_size],每个元素是模型预测的类别标签(0 或 1,对于二分类)
_, predicted = torch.max(outputs, 1)
# 累加当前 batch 的样本总数
# labels.size(0) 等于 inputs.size(0),即 batch 大小
total += labels.size(0)
# 计算当前 batch 中预测正确的样本数
# predicted == labels 返回布尔型张量,逐元素比较是否相等
# .sum() 对布尔张量求和(True 视为 1,False 视为 0),得到该 batch 正确的个数(张量)
# .item() 将该张量的值转换为 Python 标量(整数),便于累加到 correct 变量
correct += (predicted == labels).sum().item()
epoch_loss = running_loss / total
epoch_acc = correct / total
# 验证阶段
model.eval()
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
val_total += labels.size(0)
val_correct += (predicted == labels).sum().item()
val_acc = val_correct / val_total
print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f} | Val Acc: {val_acc:.4f}")
# 保存最佳模型
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_binary_circuit.pth")
print(f" -> 保存最佳模型 (验证准确率: {best_acc:.4f})")
print("训练完成!最佳模型已保存为 best_binary_circuit.pth")
测试脚本:test_binary.py
import torch
from torchvision import transforms, models
from PIL import Image
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载模型结构
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_binary_circuit.pth"))
model.eval()
def predict(img_path):
img = Image.open(img_path).convert('RGB')
#在索引0处插入一个维度
img_t = transform(img).unsqueeze(0)
with torch.no_grad():
out = model(img_t)
prob = torch.softmax(out[0], dim=0)
circuit_prob = prob[0].item() # 索引0对应电路板类别
non_circuit_prob = prob[1].item()
print(f"图片: {img_path}")
print(f" 电路板概率: {circuit_prob:.4f}, 非电路板概率: {non_circuit_prob:.4f}")
if circuit_prob > non_circuit_prob:
print(" 结论: 电路板")
else:
print(" 结论: 非电路板")
# 测试
predict("data/circuit_boards/51.jpg") # 训练集中的电路板
predict("new_circuit.jpg") # 新电路板(如果有)
predict("cat.png") # 猫图片
标签: AI
日历
最新文章
随机文章
热门文章
分类
存档
- 2026年5月(15)
- 2026年4月(7)
- 2026年3月(15)
- 2026年2月(3)
- 2026年1月(6)
- 2025年12月(1)
- 2025年11月(1)
- 2025年9月(3)
- 2025年7月(4)
- 2025年6月(5)
- 2025年5月(1)
- 2025年4月(5)
- 2025年3月(4)
- 2025年2月(3)
- 2025年1月(1)
- 2024年12月(5)
- 2024年11月(5)
- 2024年10月(5)
- 2024年9月(3)
- 2024年8月(3)
- 2024年7月(11)
- 2024年6月(3)
- 2024年5月(9)
- 2024年4月(10)
- 2024年3月(11)
- 2024年2月(24)
- 2024年1月(12)
- 2023年12月(3)
- 2023年11月(9)
- 2023年10月(7)
- 2023年9月(2)
- 2023年8月(7)
- 2023年7月(9)
- 2023年6月(6)
- 2023年5月(7)
- 2023年4月(11)
- 2023年3月(6)
- 2023年2月(11)
- 2023年1月(8)
- 2022年12月(2)
- 2022年11月(4)
- 2022年10月(10)
- 2022年9月(2)
- 2022年8月(13)
- 2022年7月(7)
- 2022年6月(11)
- 2022年5月(18)
- 2022年4月(29)
- 2022年3月(5)
- 2022年2月(6)
- 2022年1月(8)
- 2021年12月(5)
- 2021年11月(3)
- 2021年10月(4)
- 2021年9月(9)
- 2021年8月(14)
- 2021年7月(8)
- 2021年6月(5)
- 2021年5月(2)
- 2021年4月(3)
- 2021年3月(7)
- 2021年2月(2)
- 2021年1月(8)
- 2020年12月(7)
- 2020年11月(2)
- 2020年10月(6)
- 2020年9月(9)
- 2020年8月(10)
- 2020年7月(9)
- 2020年6月(18)
- 2020年5月(4)
- 2020年4月(25)
- 2020年3月(38)
- 2020年1月(21)
- 2019年12月(13)
- 2019年11月(29)
- 2019年10月(44)
- 2019年9月(17)
- 2019年8月(18)
- 2019年7月(25)
- 2019年6月(25)
- 2019年5月(17)
- 2019年4月(10)
- 2019年3月(36)
- 2019年2月(35)
- 2019年1月(28)
- 2018年12月(30)
- 2018年11月(22)
- 2018年10月(4)
- 2018年9月(7)
- 2018年8月(13)
- 2018年7月(13)
- 2018年6月(6)
- 2018年5月(5)
- 2018年4月(13)
- 2018年3月(5)
- 2018年2月(3)
- 2018年1月(8)
- 2017年12月(35)
- 2017年11月(17)
- 2017年10月(16)
- 2017年9月(17)
- 2017年8月(20)
- 2017年7月(34)
- 2017年6月(17)
- 2017年5月(15)
- 2017年4月(32)
- 2017年3月(8)
- 2017年2月(2)
- 2017年1月(5)
- 2016年12月(14)
- 2016年11月(26)
- 2016年10月(12)
- 2016年9月(25)
- 2016年8月(32)
- 2016年7月(14)
- 2016年6月(21)
- 2016年5月(17)
- 2016年4月(13)
- 2016年3月(8)
- 2016年2月(8)
- 2016年1月(18)
- 2015年12月(13)
- 2015年11月(15)
- 2015年10月(12)
- 2015年9月(18)
- 2015年8月(21)
- 2015年7月(35)
- 2015年6月(13)
- 2015年5月(9)
- 2015年4月(4)
- 2015年3月(5)
- 2015年2月(4)
- 2015年1月(13)
- 2014年12月(7)
- 2014年11月(5)
- 2014年10月(4)
- 2014年9月(8)
- 2014年8月(16)
- 2014年7月(26)
- 2014年6月(22)
- 2014年5月(28)
- 2014年4月(15)
友情链接
- Unity官网
- Unity圣典
- Unity在线手册
- Unity中文手册(圣典)
- Unity官方中文论坛
- Unity游戏蛮牛用户文档
- Unity下载存档
- Unity引擎源码下载
- Unity服务
- Unity Ads
- wiki.unity3d
- Visual Studio Code官网
- SenseAR开发文档
- MSDN
- C# 参考
- C# 编程指南
- .NET Framework类库
- .NET 文档
- .NET 开发
- WPF官方文档
- uLua
- xLua
- SharpZipLib
- Protobuf-net
- Protobuf.js
- OpenSSL
- OPEN CASCADE
- JSON
- MessagePack
- C在线工具
- 游戏蛮牛
- GreenVPN
- 聚合数据
- 热云
- 融云
- 腾讯云
- 腾讯开放平台
- 腾讯游戏服务
- 腾讯游戏开发者平台
- 腾讯课堂
- 微信开放平台
- 腾讯实时音视频
- 腾讯即时通信IM
- 微信公众平台技术文档
- 白鹭引擎官网
- 白鹭引擎开放平台
- 白鹭引擎开发文档
- FairyGUI编辑器
- PureMVC-TypeScript
- 讯飞开放平台
- 亲加通讯云
- Cygwin
- Mono开发者联盟
- Scut游戏服务器引擎
- KBEngine游戏服务器引擎
- Photon游戏服务器引擎
- 码云
- SharpSvn
- 腾讯bugly
- 4399原创平台
- 开源中国
- Firebase
- Firebase-Admob-Unity
- google-services-unity
- Firebase SDK for Unity
- Google-Firebase-SDK
- AppsFlyer SDK
- android-repository
- CQASO
- Facebook开发者平台
- gradle下载
- GradleBuildTool下载
- Android Developers
- Google中国开发者
- AndroidDevTools
- Android社区
- Android开发工具
- Google Play Games Services
- Google商店
- Google APIs for Android
- 金钱豹VPN
- TouchSense SDK
- MakeHuman
- Online RSA Key Converter
- Windows UWP应用
- Visual Studio For Unity
- Open CASCADE Technology
- 慕课网
- 阿里云服务器ECS
- 在线免费文字转语音系统
- AI Studio
- 网云穿
- 百度网盘开放平台
- 迅捷画图
- 菜鸟工具
- [CSDN] 程序员研修院
- 华为人脸识别
- 百度AR导航导览SDK
- 海康威视官网
- 海康开放平台
- 海康SDK下载
- git download
- Open CASCADE
- CascadeStudio
- OpenClaw中文社区
- three.js manual
- SVG官方文档
交流QQ群
-
Flash游戏设计: 86184192
Unity游戏设计: 171855449
游戏设计订阅号








