PyTorch 神经网络核心构建:nn.Module 与卷积操作底层解析
在 PyTorch 中,torch.nn 模块用于搭建深度学习模型。本文总结了 nn.Module 的基本结构,以及卷积操作的底层计算过程。
一、 nn.Module 基础结构
在 PyTorch 中,所有自定义的神经网络模型都需要继承 nn.Module 基类。
1. 构建步骤
构建自定义神经网络包含三个基本步骤:
-
继承基类:创建自定义类并继承
nn.Module。 -
初始化 (
__init__):定义模型所需的网络层组件。注意:此处必须调用父类的初始化方法super().__init__()。 -
前向传播 (
forward):定义输入数据x在网络层中的计算流向。
2. 基础代码示例
import torchfrom torch import nn
# 1. 继承 nn.Moduleclass MyNeuralNetwork(nn.Module):
def __init__(self): # 2. 调用父类的初始化方法 super().__init__()
# 定义网络层组件 # self.conv1 = nn.Conv2d(...)
def forward(self, x): # 3. 定义前向传播逻辑 # output = self.conv1(x) output = x + 1 return output
# 实例化模型model = MyNeuralNetwork()input_tensor = torch.tensor(1.0)
# 执行模型计算output_tensor = model(input_tensor)
print("输出结果:", output_tensor) # 输出: tensor(2.)二、 前向传播调用机制
调用 model(x) 与 model.forward(x) 的区别
nn.Module 内部实现了 Python 的 __call__ 方法。当执行 model(x) 时,底层会自动处理内部的 Hooks(钩子函数),随后再自动调用 forward 方法。
建议:在实际开发中应直接使用 output = model(input),避免使用 output = model.forward(input),以免跳过 Hooks 的执行导致不可预期的错误。
三、 卷积计算原理 (Convolution)
此处通过 torch.nn.functional.conv2d 演示二维卷积的基础计算过程。
1. 核心要素
-
输入 (Input):图像数据矩阵。
-
卷积核/权重 (Weight/Kernel):在训练中学习的权重矩阵。
-
输出 (Output):卷积核在输入矩阵上滑动计算得到的特征图矩阵。
2. 计算过程
-
卷积核覆盖在输入矩阵对应位置。
-
对应位置元素相乘并求和,得到单一数值,填入输出矩阵。
-
按指定步长(Stride)移动卷积核,遍历输入矩阵。
3. 底层计算代码示例 (F.conv2d)
import torchimport torch.nn.functional as F
input = torch.tensor([[1, 2, 0, 3, 1], [0, 1, 2, 3, 1], [1, 2, 1, 0, 0], [5, 2, 3, 1, 1], [2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1], [0, 1, 0], [2, 1, 0]])
input = torch.reshape(input, (1, 1, 5, 5))kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input.shape)print(kernel.shape)
output = F.conv2d(input, kernel, stride=1)print(output)
output2 = F.conv2d(input, kernel, stride=2)print(output2)
output3 = F.conv2d(input, kernel, stride=1, padding=1)print(output3)4. 关键控制参数
-
Stride (步长):卷积核每次移动的格数。步长越大,输出尺寸越小。
-
Padding (填充):在输入矩阵边缘填充数值(通常为
0),用于控制输出尺寸或保留边缘信息。
四、 注意事项
-
父类初始化:在重写
__init__方法时,必须包含super().__init__(),否则会引发AttributeError。 -
输入维度限制:进行二维卷积操作时,输入数据和卷积核必须严格满足 4 维形状:
(Batch_Size, Channels, Height, Width)。遇到维度报错时,建议优先检查tensor.shape。 -
nn与nn.functional的区分:-
torch.nn.Conv2d:网络层(类),内部自动实例化并管理权重参数(Weight),主要用于搭建模型。 -
torch.nn.functional.conv2d:纯数学计算函数,需手动传入输入和权重张量,主要用于自定义底层计算逻辑。
-
正在加载评论...
链上评论区