3.1.2 Vectorization for speed

%matplotlib inline
import math
import time
import numpy as np
import torch
from d2l import torch as d2l
n = 10000  # two 10000-dimensional vectors
a = torch.ones(n)
b = torch.ones(n)
a, b
(tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 tensor([1., 1., 1.,  ..., 1., 1., 1.]))
class Timer:  #@save
    """Record multiple running times."""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()  # record the current time

    def stop(self):
        """Stop the timer and record the elapsed time in the list."""
        self.times.append(time.time() - self.tik)  # elapsed time of this run
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total time."""
        return sum(self.times)

    def cumsum(self):
        """Return the cumulative times as a list."""
        return np.array(self.times).cumsum().tolist()

c = torch.zeros(n)
timer = Timer()
for i in range(n):
    c[i] = a[i] + b[i]
f'{timer.stop():.5f} sec'
'0.05257 sec'
timer.start()
d = a + b
f'{timer.stop():.5f} sec'
# Vectorized addition is clearly much faster than the elementwise loop.
'0.00014 sec'
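
The Timer class above also defines avg(), sum() and cumsum(), which these notes never call. A minimal sketch (not from the original) that exercises them on the same vectorized addition:

timer = Timer()
for _ in range(5):
    timer.start()
    d = a + b  # repeat the vectorized addition of the two 10000-dimensional vectors
    timer.stop()
print(f'avg: {timer.avg():.6f} sec, total: {timer.sum():.6f} sec')
print('cumulative:', [f'{t:.6f}' for t in timer.cumsum()])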

3.1.3 Normal distribution and squared loss

# Define the probability density of a normal distribution
def normal(x, mu, sigma):
    p = 1 / math.sqrt(2 * math.pi * sigma**2)
    return p * np.exp(-0.5 / sigma**2 * (x - mu)**2)
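
For reference, the density that normal implements is

$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right).$$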
x = np.arange(-7, 7, 0.01)

# means and standard deviations, i.e., several (mu, sigma) pairs
params = [(0, 1), (0, 2), (0, 3)]
d2l.plot(x, [normal(x, mu, sigma) for mu, sigma in params],
         xlabel='x', ylabel='p(x)', figsize=(4.5, 2.5),
         legend=[f'mean {mu}, std {sigma}' for mu, sigma in params])
# Plot the three curves with the function defined above
# (in MATLAB one might build the same family of curves with a Kronecker product).

[Figure: probability density curves p(x) for (mean, std) = (0, 1), (0, 2), (0, 3)]
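
Why this subsection pairs the normal distribution with squared loss: if we assume the labels are generated as y = wᵀx + b + ε with Gaussian noise ε ~ N(0, σ²), the negative log-likelihood of one example is

$$-\log p(y \mid \mathbf{x}) = \frac{1}{2}\log(2\pi\sigma^2) + \frac{1}{2\sigma^2}\left(y - \mathbf{w}^\top\mathbf{x} - b\right)^2,$$

so maximizing the likelihood over the data is equivalent to minimizing the squared loss (the first term and the factor 1/σ² do not depend on w or b).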

3.2 Linear regression from scratch

%matplotlib inline
import random
import torch
from d2l import torch as d2l
# We generate a dataset with noise; the task is to recover the parameters from the finite samples in it.
# We use a low-dimensional dataset so that it is easy to visualize.
# The true model parameters are w = [2, -3.4]^T and b = 4.2, with normally distributed noise epsilon.
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = torch.normal(0, 1, (num_examples, len(w)))  # the last argument is the matrix shape
    y = torch.matmul(X, w) + b  # matmul is matrix multiplication
    y += torch.normal(0, 0.01, y.shape)  # noise with standard deviation 0.01
    return X, y.reshape((-1, 1))
# Return the features and the labels as a column vector.
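
In symbols, the synthetic dataset follows

$$\mathbf{y} = \mathbf{X}\mathbf{w} + b + \boldsymbol{\epsilon}, \qquad \epsilon_i \sim \mathcal{N}(0,\, 0.01^2),$$

where X is a num_examples × len(w) matrix of standard-normal features.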
true_w = torch.tensor([2, -3.4])  # these are the true parameters
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print('features:', features[0], '\nlabel:', labels[0])
print(features[:, 1])  # the full second-feature column
features: tensor([-0.4575,  1.1481]) 
label: tensor([-0.6028])
[Output truncated: the full 1000-element tensor of the second feature column, beginning 1.1481, 0.9884, 0.3554, ...]
# Scatter plot of the second feature against the labels
d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1);

[Figure: scatter plot of the second feature features[:, 1] against the labels]

3.2.2 Read the dataset

# When training the model, we need to traverse the whole dataset in minibatches.
def data_iter(batch_size, features, labels):  # yield minibatches of size batch_size
    num_examples = len(features)
    indices = list(range(num_examples))
    # The examples are read in random order, with no particular structure.
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i:min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
# Every minibatch consists of batch_size rows of features and the corresponding labels.
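
As a quick check (a small sketch, not from the original notes), the generator yields ceil(num_examples / batch_size) minibatches per full pass over the data:

n_batches = sum(1 for _ in data_iter(10, features, labels))
print(n_batches)  # 100 minibatches of 10 examples each for the 1000-example dataset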
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
# Print one minibatch; iterating further would keep producing minibatches until the whole dataset has been traversed.
tensor([[-1.5650,  1.0609],
        [-0.9305, -1.8794],
        [-0.4093, -0.9996],
        [ 0.7869, -0.9384],
        [ 0.1738, -0.6689],
        [ 1.6200, -0.0613],
        [-0.5328,  0.2143],
        [-0.2533,  2.0045],
        [ 0.9031, -0.9043],
        [ 1.3816, -1.5059]]) 
 tensor([[-2.5447],
        [ 8.7360],
        [ 6.7868],
        [ 8.9856],
        [ 6.8253],
        [ 7.6404],
        [ 2.4112],
        [-3.1235],
        [ 9.0686],
        [12.0905]])

3.2.3 Initialize model parameters

# For example, initialize the bias to 0 and the weights w with random numbers drawn from a
# normal distribution with mean 0 and standard deviation 0.01:
# w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
# b = torch.zeros(1, requires_grad=True)
w = torch.zeros((2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b  # starting values; we then update these parameters with minibatch stochastic gradient descent until they fit the data.
(tensor([[0.],
         [0.]], requires_grad=True),
 tensor([0.], requires_grad=True))

3.2.4 Define the model

# Define the model
def linreg(X, w, b):  #@save
    """The linear regression model."""
    return torch.matmul(X, w) + b
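
In matrix form, linreg computes

$$\hat{\mathbf{y}} = \mathbf{X}\mathbf{w} + b,$$

with the scalar bias b broadcast across all rows of Xw.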

3.2.5 Define the loss function

def squared_loss(y_hat, y):  #@save
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2  # elementwise squared error, halved
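
For a single example this is the squared loss

$$l^{(i)}(\mathbf{w}, b) = \frac{1}{2}\left(\hat{y}^{(i)} - y^{(i)}\right)^2,$$

computed elementwise over the minibatch (the reshape ensures y_hat and y have the same shape).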

3.2.6 Define the optimization algorithm

def sgd(params, lr, batch_size):  #@save  params are w and b; sgd stands for (minibatch) stochastic gradient descent.
    """Minibatch stochastic gradient descent."""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size  # the update step uses the gradient computed by autograd
            param.grad.zero_()  # reset the gradient so it does not accumulate across steps
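
Each call to sgd therefore applies the minibatch stochastic gradient descent update

$$(\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \partial_{(\mathbf{w}, b)}\, l^{(i)}(\mathbf{w}, b),$$

where η is the learning rate lr and |B| is batch_size; dividing by the batch size keeps the step magnitude comparable across batch sizes.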
# We use data_iter to traverse the whole dataset, repeating this num_epochs times.
# num_epochs and the learning rate lr are hyperparameters.
lr = 0.03
num_epochs = 3  # make 3 passes over the data
net = linreg  # net is just linreg
loss = squared_loss  # loss is just squared_loss
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # minibatch loss: compare the predictions on X with the true labels y
        # net is the model: it takes the input X and the current weight w and bias b and returns predictions.
        l.sum().backward()  # gradient of the summed loss w.r.t. the parameters (this fills the param.grad used in sgd)
        sgd([w, b], lr, batch_size)  # update the parameters using their gradients
    with torch.no_grad():  # evaluation only, no gradient tracking needed
        train_l = loss(net(features, w, b), labels)  # loss over the whole dataset
        # Evaluate the current parameters on all examples and compare with the original labels.
        print(f'epoch{epoch + 1}, loss{float(train_l.mean()):f}')  # print the loss after each epoch
epoch1, loss0.029253
epoch2, loss0.000097
epoch3, loss0.000049
# Because we synthesized the dataset ourselves, we know the true parameters w and b,
# so we can measure how close the learned parameters are to them.
print(f'error in estimating w: {true_w - w.reshape(true_w.shape)}')
print(f'error in estimating b: {true_b - b}')
error in estimating w: tensor([ 0.0002, -0.0004], grad_fn=<SubBackward0>)
error in estimating b: tensor([0.0003], grad_fn=<RsubBackward1>)
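
Since the learned parameters are so close to the true ones, they can be used directly for prediction. A minimal sketch (not from the original notes; new_X is a hypothetical set of query points) comparing the learned model with the noise-free generating process:

new_X = torch.tensor([[1.0, 1.0], [-0.5, 2.0]])  # hypothetical query points
pred = net(new_X, w, b)  # predictions from the learned parameters
truth = torch.matmul(new_X, true_w.reshape(-1, 1)) + true_b  # noise-free ground truth
print(pred.detach().squeeze())
print(truth.squeeze())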