1
0
Files
ai-school/dl-exp/exp1/source.py

62 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
class Net(torch.nn.Module): #继承 torch 的module
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__() #继承_init_功能
#定理每层用什么样的形式
self.hidden1 = torch.nn.Linear(n_feature, n_hidden) #隐藏层线性输出
self.hidden2 = torch.nn.Linear(n_hidden, n_hidden) #输出层线性输出
self.hidden3 = torch.nn.Linear(n_hidden, n_hidden) #输出层线性输出
self.predict = torch.nn.Linear(n_hidden, n_output) #输出层线性输出
def forward(self, x): #这同时也是module中的forward功能
#正向传播输入值,神经网络分析出输出值
x = F.relu(self.hidden1(x)) #激励函数(隐藏层的线性值)
x = F.relu(self.hidden2(x))
x = F.relu(self.hidden3(x))
x = self.predict(x) #输出值
return x
def main():
x = torch.unsqueeze(torch.linspace(-1, 1, 100),
dim=1) #x data(tensor),shape=(100,1)
y = -x.pow(3) + 2 * x.pow(2) + 0.2 * torch.rand(x.size())
#y=math.sinx)+o.2*torch.rand(x.size())
net = Net(n_feature=1, n_hidden=20, n_output=1)
#optimizer是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.01) #传入net的所有参数学习率
loss_func = torch.nn.MSELoss() #预测值和真实值的误差计算公式(均方差)
for t in range(2000):
prediction = net(x) #喂给net训练数据x输出预测值
loss = loss_func(prediction, y) #计算两者的误差
optimizer.zero_grad() #清空上一步的残余更新参数值
loss.backward() #误差反向传播,计算参数更新值
optimizer.step() #将参数更新值施加到net的parameters上
if t % 5 == 0:
#plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.scatter(x.data.numpy(), prediction.data.numpy())
plt.text(0.5,
0,
'Loss=%.4f' % loss.data.numpy(),
fontdict={
'size': 20,
'color': 'red'
})
plt.show()
if __name__ == "__main__":
main()