1
0
Files
ai-school/exp2/modified/model.py

54 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn.functional as F
class Cnn(torch.nn.Module):
"""卷积神经网络模型"""
def __init__(self):
super(Cnn, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
self.flatten = torch.nn.Flatten()
# 28x28过第一轮卷积后变为26x26过第一轮池化后变为13x13
# 过第二轮卷积后变为11x11过第二轮池化后变为5x5
# 过第三轮卷积后变为3x3。
# 最后一轮卷积核个数为64。
self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
self.fc2 = torch.nn.Linear(64, 10)
# 初始化模型参数
self._initialize_weights()
def _initialize_weights(self):
# YYC MARK:
# 把两个全连接线性层按tensorflow默认设置初始化
# - kernel_initializer='glorot_uniform'
# - bias_initializer='zeros'
torch.nn.init.xavier_normal_(self.fc1.weight)
torch.nn.init.zeros_(self.fc1.bias)
torch.nn.init.xavier_normal_(self.fc2.weight)
torch.nn.init.zeros_(self.fc2.bias)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = F.relu(self.conv3(x))
x = self.flatten(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
# YYC MARK:
# 绝对不要在这里用F.softmax(x, dim=1)输出!
# 由于这些代码是从tensorflow里转换过来的
# tensorflow的loss function是接受possibility作为交叉熵计算的
# 而pytorch要求接受logits即模型softmax之前的参数作为交叉熵计算。
# 所以这里直接输出模型结果。
return x