Python ResNet 训练二分类模型

作者:追风剑情 发布于:2026-5-21 18:11 分类:AI

用 ResNet 训练一个图像二分类模型。

训练脚本:train_binary.py

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader, random_split
import os

# ------------------- 1. 超参数 -------------------
# 每次同时处理几张图片
BATCH_SIZE = 8
# 训练次数
EPOCHS = 20
# 学习率
LEARNING_RATE = 0.001
# 当前目录,下面有 circuit_boards 和 non_circuit_boards 两个文件夹
DATA_DIR = "./data"
# 二分类
NUM_CLASSES = 2

# ------------------- 2. 数据预处理 -------------------
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# ------------------- 3. 加载数据集 -------------------
# 假设文件夹结构:
#   circuit_boards/       -> 正样本,自动分配标签 0(电路板)
#   non_circuit_boards/   -> 负样本,自动分配标签 1(非电路板)
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transforms)
# 检查类别映射
# 预期 ['circuit_boards', 'non_circuit_boards']
print("类别映射:", full_dataset.classes)

# 划分训练集和验证集(80% 训练,20% 验证)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# 为验证集使用不同的预处理(无数据增强)
val_dataset.dataset.transform = val_transforms

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"训练样本数: {len(train_dataset)}, 验证样本数: {len(val_dataset)}")

# ------------------- 4. 加载预训练模型(本地权重) -------------------
model = models.resnet50(weights=None)
state_dict = torch.load("resnet50-0676ba61.pth", map_location='cpu')
model.load_state_dict(state_dict)

# 替换最后的全连接层为二分类头
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)

# 冻结前面的所有层(可选,也可以全部微调,这里演示只训练新头部)
for param in model.parameters():
    param.requires_grad = False
# 只让最后的 fc 层可训练
for param in model.fc.parameters():
    param.requires_grad = True

# 判断模型是加载进GPU还是CPU运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# ------------------- 5. 损失函数和优化器 -------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

# ------------------- 6. 训练与验证 -------------------
best_acc = 0.0
for epoch in range(EPOCHS):
    # 训练阶段
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        # 将输入图像和标签数据从CPU内存移动到指定的计算设备(GPU或CPU)上
        # .to(device) 确保后续计算在同一个设备上进行,避免跨设备数据传输
        inputs, labels = inputs.to(device), labels.to(device)
        # 将优化器中所有待更新参数的梯度清零
        # 因为梯度在每次反向传播后会累积,不手动清零会导致梯度叠加,影响参数更新
        optimizer.zero_grad()
        # 前向传播:将输入图像送入模型,得到原始输出logits(尚未经过softmax)
        # 对于分类任务,输出形状通常为 [batch_size, num_classes]
        outputs = model(inputs)
        # 计算损失:使用损失函数(例如交叉熵)比较模型预测值outputs与真实标签labels
        # 得到的loss是一个标量张量,代表当前batch的平均损失
        loss = criterion(outputs, labels)
        # 反向传播:根据损失值自动计算模型中每个可训练参数的梯度
        # 梯度会累积到各参数的 .grad 属性中,用于后续优化器更新
        loss.backward()
        # 优化器更新参数:根据计算出的梯度和预设的学习率等超参数,更新模型参数
        # 这是模型学习的关键步骤,参数沿梯度下降方向调整
        optimizer.step()
        
        # 累加当前 batch 的总损失(非平均损失)
        # loss.item() 是当前 batch 的平均损失(标量),乘以 batch 大小 inputs.size(0) 得到该 batch 的总损失
        # 这样便于后面计算整个 epoch 的平均损失(总损失 / 总样本数)
        running_loss += loss.item() * inputs.size(0)

        # 对模型输出的 logits 沿着类别维度(dim=1)取最大值
        # torch.max(outputs, 1) 返回一个元组:(最大值, 最大值所在的索引)
        # 使用下划线 _ 忽略最大值本身,只保留索引,赋值给 predicted
        # predicted 形状为 [batch_size],每个元素是模型预测的类别标签(0 或 1,对于二分类)
        _, predicted = torch.max(outputs, 1)

        # 累加当前 batch 的样本总数
        # labels.size(0) 等于 inputs.size(0),即 batch 大小
        total += labels.size(0)

        # 计算当前 batch 中预测正确的样本数
        # predicted == labels 返回布尔型张量,逐元素比较是否相等
        # .sum() 对布尔张量求和(True 视为 1,False 视为 0),得到该 batch 正确的个数(张量)
        # .item() 将该张量的值转换为 Python 标量(整数),便于累加到 correct 变量
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    
    # 验证阶段
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_acc = val_correct / val_total
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f} | Val Acc: {val_acc:.4f}")
    
    # 保存最佳模型
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_binary_circuit.pth")
        print(f"  -> 保存最佳模型 (验证准确率: {best_acc:.4f})")

print("训练完成!最佳模型已保存为 best_binary_circuit.pth")

训练结果
111111.png

测试脚本:test_binary.py

import torch
from torchvision import transforms, models
from PIL import Image

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载模型结构
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_binary_circuit.pth"))
model.eval()

def predict(img_path):
    img = Image.open(img_path).convert('RGB')
    #在索引0处插入一个维度
    img_t = transform(img).unsqueeze(0)
    with torch.no_grad():
        out = model(img_t)
    prob = torch.softmax(out[0], dim=0)
    circuit_prob = prob[0].item()   # 索引0对应电路板类别
    non_circuit_prob = prob[1].item()
    print(f"图片: {img_path}")
    print(f"  电路板概率: {circuit_prob:.4f}, 非电路板概率: {non_circuit_prob:.4f}")
    if circuit_prob > non_circuit_prob:
        print("  结论: 电路板")
    else:
        print("  结论: 非电路板")

# 测试
predict("data/circuit_boards/51.jpg")     # 训练集中的电路板
predict("new_circuit.jpg")                # 新电路板(如果有)
predict("cat.png")                        # 猫图片

测试结果
2222.png

标签: AI

Powered by emlog  蜀ICP备18021003号-1   sitemap

川公网安备 51019002001593号