3.1.2 向量化加速 1 2 3 4 5 6 %matplotlib inline import mathimport timeimport numpy as npimport torchfrom d2l import torch as d2l
# Two all-ones vectors of length n; the cells below add them element by
# element and then in one vectorized operation to compare the timings.
n = 10000
a = torch.ones(n)
b = torch.ones(n)
a, b  # bare expression: displayed as the cell result in a notebook
(tensor([1., 1., 1., ..., 1., 1., 1.]),
tensor([1., 1., 1., ..., 1., 1., 1.]))
class Timer:
    """Record the wall-clock duration of multiple runs.

    Each ``stop()`` appends the time elapsed since the most recent
    ``start()`` to ``self.times``.
    """

    def __init__(self):
        self.times = []  # elapsed seconds, one entry per stop() call
        self.start()     # the timer is running as soon as it is created

    def start(self):
        """Start (or restart) the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer, record the elapsed time in the list, and return it."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average of the recorded times.

        Raises ZeroDivisionError if no time has been recorded yet.
        """
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total of the recorded times."""
        # Inside a method body the bare name `sum` resolves to the builtin,
        # not to this method: class-level names are not in a method's scope.
        return sum(self.times)

    def cumsum(self):
        """Return the cumulative recorded times as a plain Python list."""
        return np.array(self.times).cumsum().tolist()
# Baseline: add the two vectors one element at a time in a Python loop.
# The loop is intentionally left unvectorized; it is the slow side of the
# timing comparison.
c = torch.zeros(n)
timer = Timer()
for i in range(n):
    c[i] = a[i] + b[i]
# The format spec must be `.5f`; a space inside it (`.5 f`) raises ValueError.
f'{timer.stop():.5f} sec'
'0.05257 sec'
# Same addition as a single vectorized operation, timed with the same timer.
timer.start()
d = a + b
# The format spec must be `.5f`; a space inside it (`.5 f`) raises ValueError.
f'{timer.stop():.5f} sec'
'0.00014 sec'
# 3.1.3 Normal distribution and squared loss
def normal(x, mu, sigma):
    """Evaluate the normal (Gaussian) probability density at x.

    Args:
        x: Point(s) at which to evaluate the density; passed through
            ``np.exp``, so a scalar or a NumPy array both work.
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution.

    Returns:
        The density value(s), ``1 / sqrt(2*pi*sigma^2) *
        exp(-(x - mu)^2 / (2*sigma^2))``.
    """
    p = 1 / math.sqrt(2 * math.pi * sigma**2)  # normalizing constant
    return p * np.exp(-0.5 / sigma**2 * (x - mu)**2)
# Plot the normal density for three (mean, std) pairs over [-7, 7).
x = np.arange(-7, 7, 0.01)
params = [(0, 1), (0, 2), (0, 3)]  # (mu, sigma) pairs
d2l.plot(x, [normal(x, mu, sigma) for mu, sigma in params], xlabel='x',
         ylabel='p(x)', figsize=(4.5, 2.5),
         legend=[f'mean {mu} , std {sigma} ' for mu, sigma in params])
3.2 linear regression from zero 1 2 3 4 %matplotlib inline import randomimport torchfrom d2l import torch as d2l
def synthetic_data(w, b, num_examples):
    """Generate y = Xw + b + noise.

    Args:
        w: 1-D weight tensor; its length sets the number of features.
        b: Bias added to every example.
        num_examples: Number of rows to generate.

    Returns:
        A pair ``(X, y)``: ``X`` has shape ``(num_examples, len(w))`` with
        entries drawn from a standard normal, and ``y`` is reshaped to a
        column of shape ``(num_examples, 1)``.
    """
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    # Additive Gaussian noise with standard deviation 0.01.
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))
# Ground-truth parameters and a 1000-example synthetic dataset built from them.
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print('features:', features[0], '\nlabel:', labels[0])
# Prints the whole second feature column (all 1000 values).
print(features[:, 1])
features: tensor([-0.4575, 1.1481])
label: tensor([-0.6028])
tensor([ 1.1481e+00, 9.8841e-01, 3.5535e-01, 2.0784e+00, 9.6879e-01,
1.5259e+00, 2.0957e+00, -1.8084e+00, -4.5611e-02, -2.8632e-01,
1.1861e-01, 1.2378e+00, -1.4917e+00, -1.4077e-02, -8.8317e-01,
-3.6147e-01, 9.3682e-01, -8.3408e-03, 2.8196e-01, 1.5797e+00,
1.8961e-01, -1.1522e+00, 2.0520e-01, -6.0677e-01, 2.8371e-01,
-1.8548e+00, -1.0169e+00, -1.2417e+00, -7.8944e-01, 6.5647e-01,
-1.3798e+00, 8.5492e-01, 1.3620e+00, 2.1012e-01, -2.8030e-02,
-7.9160e-02, -1.8883e-01, -1.4650e+00, -1.5059e+00, 5.3730e-01,
-3.6419e-01, -2.5970e-01, -2.9555e-01, 7.3052e-01, 2.0045e+00,
-6.2671e-01, -1.5374e-01, -1.8446e+00, -1.9459e-01, 6.2921e-01,
4.2467e-02, -1.8700e+00, -1.8505e+00, -1.5815e+00, 9.1741e-01,
-2.7313e+00, -4.8553e-01, 1.0087e+00, -5.2039e-01, -2.9656e-01,
-9.1277e-01, -1.8571e+00, -1.1047e+00, 1.3227e+00, 8.7071e-01,
2.9789e-01, 2.1700e+00, -2.6391e+00, 1.2253e-01, -1.7574e-01,
8.7608e-01, -4.0556e-01, 1.9379e+00, -1.8671e+00, 3.0901e-01,
-6.7684e-01, 3.1759e-01, 3.7318e-01, -1.2070e+00, -2.8530e-01,
1.7684e-01, 8.0286e-01, 2.3324e+00, -1.0233e+00, 5.9065e-01,
-1.1620e+00, -7.2939e-01, -1.0113e+00, -2.4611e-01, 2.1603e+00,
-4.6869e-01, 4.5667e-01, 8.1412e-02, -1.5250e+00, 9.1282e-01,
-6.6413e-02, 1.7847e+00, 1.0840e+00, -1.4668e+00, -7.6567e-01,
-1.0015e+00, -2.3842e-03, 2.2501e+00, 7.9039e-01, -2.3140e-02,
2.3391e+00, -6.3166e-01, 1.5859e+00, 4.9477e-01, -7.8896e-01,
-1.1146e+00, -1.3595e+00, 3.8490e-01, 5.7789e-02, 1.8755e-01,
3.0663e-02, 1.5449e-01, 5.4152e-01, 1.0721e+00, -7.6712e-01,
8.5919e-01, -3.2355e-01, 8.8787e-01, 8.9284e-01, -1.1509e+00,
-1.0984e+00, 6.9687e-01, 4.8890e-01, -1.6014e-01, 1.1933e-01,
1.8930e-01, -1.5231e+00, -1.1571e+00, -1.5938e+00, 3.7390e-02,
1.1513e+00, -1.9291e+00, 1.7362e+00, -2.2732e-02, 3.8519e-01,
-5.1800e-01, 6.5560e-01, 5.2895e-01, 6.2111e-01, 7.1219e-01,
-1.5480e+00, -8.2423e-02, 1.2572e+00, -1.6600e+00, -6.6887e-01,
-6.4396e-01, -1.8052e+00, 4.6123e-01, 1.3803e+00, -7.3163e-01,
3.9166e-01, -1.0357e+00, -2.5712e-02, 2.6065e+00, 1.4915e-01,
1.4722e+00, -4.2426e-01, 1.5963e-01, -1.2018e+00, -2.8747e-01,
5.6584e-01, -4.2296e-01, 2.0545e-01, -3.6279e-01, 2.4457e+00,
-4.3528e-02, 1.8527e+00, -1.4745e+00, 9.0513e-01, -1.1770e+00,
3.2235e-01, -1.6264e+00, 2.9787e-01, -2.5515e-01, 3.8579e-01,
2.9109e-01, -1.8794e+00, -3.8222e-01, 1.4327e+00, 1.2353e+00,
2.1096e+00, -3.6490e-01, -2.1201e+00, -1.2036e+00, 3.1987e-01,
5.9609e-01, 1.2725e-01, -9.6513e-01, 7.3614e-01, 1.1720e+00,
-1.1179e+00, -2.4735e-01, -9.3840e-01, -5.9294e-01, -7.0450e-01,
-1.6842e-01, -9.8896e-01, -1.8894e-01, -7.6063e-02, -1.1151e+00,
1.8027e-01, -1.2681e+00, -5.0569e-01, 4.2060e-01, -5.1863e-01,
-6.1278e-01, -7.2373e-01, 1.6779e-02, 1.0886e+00, 4.3431e-01,
-1.5948e+00, 9.9059e-01, -1.4882e-01, 2.8052e-01, 2.1443e+00,
1.9607e+00, 3.8806e-01, 2.8489e+00, 2.7184e-01, -3.0007e-01,
2.9903e-01, 2.8759e-01, -1.2914e+00, 7.9686e-01, -1.0355e+00,
4.1506e-01, 2.7759e-02, 2.9032e-01, 1.9015e-01, 6.4119e-01,
1.0285e+00, -1.7690e+00, -6.2829e-01, 6.2969e-01, -9.9775e-01,
1.0963e-01, 1.4626e-01, -7.9280e-01, 7.6115e-01, 1.3151e+00,
-5.0517e-01, -1.5206e+00, 3.3245e-01, 1.8158e+00, -1.1382e+00,
1.3609e-01, 2.0802e-01, 2.3387e+00, 1.2389e+00, -1.9891e+00,
1.4855e+00, -2.0487e+00, -1.0958e+00, 6.7396e-02, -3.1978e-01,
2.0079e-01, -1.2860e+00, -2.6958e-01, 5.9093e-01, -1.0039e+00,
1.0405e+00, -1.6445e+00, -2.3928e+00, -2.5411e+00, 5.7786e-02,
-5.6163e-01, -1.3880e+00, -1.3622e+00, 2.0252e-01, -3.8418e-01,
1.0342e+00, 3.3516e-01, -5.5250e-01, 2.6750e-01, -1.6293e+00,
-2.2705e+00, -3.6701e-01, -4.4601e-01, 1.0332e-01, -9.0555e-01,
-7.8224e-01, 6.5012e-01, 2.0365e-01, -9.7166e-01, -8.8770e-01,
-1.1458e-01, 6.1896e-01, 3.8177e-02, 7.9368e-01, 8.5608e-01,
-7.3723e-01, 4.4987e-01, 7.0388e-02, -2.3882e+00, -1.2298e+00,
-9.4504e-01, 7.4099e-01, 1.6731e-01, 1.4965e+00, 5.2760e-01,
3.8883e-01, 1.3643e+00, -4.6518e-01, 2.4443e-01, 3.3352e-01,
1.2131e+00, 7.7936e-01, 1.6513e+00, 3.9003e-01, 2.6580e+00,
-2.0143e+00, 4.6442e-01, 1.0394e+00, 6.2862e-01, -1.1107e-01,
-3.1953e-01, 2.8409e-01, -2.2856e-01, -1.2910e+00, 4.0124e-01,
-1.4591e+00, -6.3800e-01, -7.6559e-02, 1.5400e+00, -4.4240e-01,
7.6727e-01, 5.2881e-01, 3.3930e-01, -5.6261e-01, 9.2551e-02,
1.0922e+00, -8.5927e-03, -1.9315e+00, -8.0120e-01, -3.7330e-01,
-9.0129e-01, -5.0834e-01, 1.8708e+00, -2.5837e-02, -1.8937e-01,
-2.2318e+00, 1.1952e+00, -2.0461e+00, -8.2401e-01, -1.2479e+00,
1.6429e-01, 1.0243e+00, 2.4229e-01, -4.3179e-01, -5.3647e-01,
-1.2427e+00, 9.8076e-01, 2.1812e+00, -4.8656e-01, -1.6169e+00,
5.3849e-01, 1.5744e+00, 1.6007e+00, 8.7652e-01, 1.2158e+00,
2.0958e+00, -1.6757e+00, -5.7488e-01, -1.6686e+00, -8.7088e-01,
-1.0723e+00, 2.4962e-01, 3.1641e-01, -7.9529e-01, -4.0261e-01,
-3.8990e-01, -4.7629e-01, 9.0027e-01, -6.2792e-01, 1.2507e+00,
-2.0434e-01, 1.1350e+00, -5.6942e-01, 3.3200e-01, -6.2645e-01,
6.6778e-01, 1.2263e+00, 1.4614e+00, 9.3139e-01, -5.3332e-01,
-4.6097e-02, -1.9035e+00, -1.0132e-01, 1.7642e-02, -9.5453e-01,
7.9801e-02, -7.4119e-01, -1.8454e-01, -9.8458e-01, 2.6712e-01,
-6.9740e-01, 2.0131e-01, 6.5156e-01, 5.0564e-02, -6.6864e-01,
-5.5951e-01, 1.6883e+00, 5.6423e-01, -1.2320e+00, 6.2235e-01,
-3.6802e-01, 1.7606e-02, -8.9760e-02, 8.4216e-02, 2.3550e+00,
-6.1954e-01, 1.1986e+00, 1.4307e-01, -1.2371e+00, 6.1353e-01,
-4.7802e-01, -8.9870e-01, 2.4927e-01, 1.2539e+00, 4.6585e-01,
-5.7510e-01, -1.4058e+00, 6.8614e-01, 1.0716e+00, -1.1346e+00,
2.0603e-01, 3.1547e-01, 3.3887e-01, 1.0074e+00, 1.8888e+00,
-6.0657e-01, -6.4736e-01, 1.1325e+00, 1.1908e+00, 2.2778e+00,
2.0696e+00, -5.1126e-01, 6.0378e-01, -9.2561e-01, 8.2603e-01,
-5.6075e-01, 1.2561e-02, 4.7810e-01, 5.9240e-01, 1.0609e+00,
-6.5395e-01, -3.8559e-02, -3.1292e+00, 2.4227e+00, 6.3595e-01,
-2.8610e-01, 6.5506e-01, 1.6298e+00, -5.9440e-01, -2.3867e+00,
3.7490e-01, -2.8954e-01, 1.9887e+00, 2.1178e+00, -9.4124e-01,
-2.1529e-01, -3.3774e-01, -2.5555e-01, 4.5399e-01, 8.2870e-01,
-1.9606e+00, 1.8591e+00, 8.9065e-01, -1.1209e+00, 8.5078e-01,
-5.0887e-01, 1.0496e-01, 7.8327e-01, 1.6164e+00, -9.9667e-01,
-1.2325e+00, -3.5175e-01, -1.2760e-01, 1.2511e+00, 5.7518e-01,
5.7423e-01, 3.7567e-02, 1.8558e-01, 5.8916e-01, -9.9956e-01,
-1.1871e+00, -1.3949e-01, 7.8632e-01, 5.4783e-01, -2.5953e-01,
-7.0179e-01, 1.3462e+00, -8.7076e-01, 1.5781e+00, -1.6149e+00,
-9.4347e-01, -1.0031e+00, -9.5817e-01, 5.4513e-01, 1.8137e-01,
7.3919e-01, 8.8540e-01, -2.0286e-01, 7.7634e-01, 7.3887e-01,
-9.2417e-01, -5.8629e-01, -9.6122e-01, -6.4402e-01, 1.0313e+00,
9.3293e-01, -2.0455e+00, 6.5080e-01, -2.3454e+00, -4.2221e-01,
-1.7840e+00, -1.5330e+00, -3.0028e-01, -9.0652e-01, 5.7490e-01,
-1.3215e+00, -5.2164e-02, -1.3944e+00, 4.6915e-01, -1.3685e+00,
-2.2256e+00, 2.6487e-01, -2.6672e-02, 9.9761e-01, -1.3467e+00,
8.3699e-01, -1.7368e-01, -2.7459e-01, 1.1095e-01, -2.1293e+00,
-8.2726e-01, 1.2120e+00, -5.2558e-01, -1.9627e-01, 4.5909e-02,
1.2375e-01, 1.1702e+00, -4.9554e-02, 2.8394e-01, -1.5373e+00,
1.6364e+00, -7.3641e-01, -9.1173e-01, 2.0838e+00, -6.4926e-01,
-3.9731e-01, 2.7819e-01, 4.8103e-01, 1.9605e+00, -7.4808e-01,
8.9170e-01, -6.3474e-01, -7.3605e-01, 2.7312e+00, -1.2479e+00,
5.9305e-01, -1.9004e+00, 5.2773e-01, 1.2009e+00, 2.4878e-01,
3.9960e-02, 1.1254e+00, -1.2421e+00, 1.6638e+00, -4.2698e-01,
-8.6493e-01, 7.9775e-01, 1.5701e+00, -1.4982e+00, -3.6172e-01,
2.0635e+00, 3.1505e-01, -1.3707e+00, 6.8623e-02, -8.3243e-02,
1.6411e+00, -8.7246e-01, 1.2962e-01, 2.5374e-01, 9.3570e-01,
-2.1959e-01, -6.3309e-01, 5.2649e-01, 6.5101e-01, -8.0604e-01,
2.3749e-01, -2.2694e-01, 6.4538e-01, 3.5337e-01, -9.6457e-01,
-1.2333e+00, 2.7997e-01, -2.0663e-01, 2.3671e+00, 9.3082e-01,
7.3876e-01, 1.8297e-01, 3.0784e-01, -2.0485e-01, -1.4161e+00,
1.2005e+00, -1.3524e+00, -9.8312e-01, -2.8585e-02, 8.3377e-01,
2.0012e-01, -1.0599e+00, -1.6392e+00, -1.1237e+00, -7.3808e-01,
9.0606e-02, -3.8058e-01, 1.4705e+00, 8.4527e-01, 9.8091e-02,
1.3605e+00, -5.5692e-03, -1.0698e+00, 6.9970e-01, -2.0957e+00,
-1.3866e+00, 1.0081e+00, -9.4732e-01, -1.5614e+00, 1.0418e+00,
-1.2773e+00, -6.1403e-01, -7.7416e-01, -3.1242e-01, -7.1166e-01,
6.4152e-01, 1.0537e+00, 2.6068e-01, 9.9155e-02, 3.6387e-01,
-5.6042e-02, -9.2726e-01, 7.4153e-01, -1.3334e+00, -3.5452e-01,
1.8230e-01, -1.4518e-01, -4.9744e-01, 2.9926e-01, -1.1999e+00,
-1.0276e+00, -2.0264e-01, 9.9026e-02, -5.8485e-03, -9.3419e-01,
1.1419e+00, -1.0526e+00, -4.2708e-02, 1.7832e+00, 9.0260e-01,
-8.5001e-02, 7.0745e-01, -1.5511e+00, -1.0399e+00, 1.4423e+00,
2.5840e-01, 1.3650e+00, -4.1277e-01, 1.1814e+00, 1.9051e-01,
-8.4195e-01, -6.3608e-01, -8.2525e-01, 2.9556e-02, 1.5895e+00,
5.4681e-01, -7.4028e-01, -1.8888e+00, -3.8418e-01, 6.2783e-01,
2.6117e-01, 1.7963e+00, -1.0656e+00, 2.9323e-01, 8.8196e-01,
-1.6574e+00, -8.4709e-01, -3.0478e-01, -1.4116e+00, 2.0713e+00,
-7.1377e-01, -3.1828e-01, 6.0991e-01, -1.2600e-01, -1.1872e+00,
2.5949e+00, -2.1323e-01, -4.9387e-01, 1.3590e-01, 2.7005e-01,
4.9349e-02, 1.7641e+00, 5.5424e-01, -2.8919e-01, -1.1024e+00,
-5.4493e-01, 2.8734e-01, 4.9375e-02, 2.7105e+00, -1.4239e+00,
-6.9621e-01, 6.8720e-02, -8.6293e-01, 1.0441e+00, -8.4265e-01,
3.4621e-01, -9.4755e-01, 1.6345e+00, 2.4400e-01, -6.6802e-01,
-5.8259e-01, -3.4806e-02, -7.5941e-01, 1.0871e+00, 5.7423e-01,
1.4147e+00, 9.9963e-01, 8.7708e-01, 9.2887e-01, -3.8470e-01,
7.3699e-01, -9.9289e-02, -1.8127e-01, 8.4384e-01, -1.7517e-01,
-8.2891e-01, 3.9462e-01, 2.2936e-01, -1.4852e+00, -1.0159e+00,
-3.3372e-01, -4.3101e-01, -6.1192e-01, -1.0163e+00, 1.0189e+00,
-1.4249e-01, -3.1265e-01, -2.6392e-01, -2.6849e+00, -4.8196e-01,
-4.6253e-02, -2.6939e-01, 4.9775e-01, 2.2392e+00, -6.4311e-01,
-2.4017e-02, 4.2888e-01, 8.5980e-01, -1.1387e-01, -4.3800e-01,
-7.3540e-01, 1.0186e+00, -1.0442e+00, -1.1324e+00, -6.1361e-01,
-7.2583e-01, -9.3297e-01, 5.7577e-01, -9.6112e-01, 8.0874e-01,
2.5687e-01, -6.4977e-01, 4.4468e-02, -4.6040e-01, -6.9195e-01,
2.6161e-01, -7.3509e-01, -1.5686e+00, -5.9142e-01, 3.1386e-01,
6.3763e-01, 4.0360e-01, -5.1863e-01, 8.7735e-01, -1.0085e+00,
-4.3077e-01, -2.3868e+00, -9.3748e-01, -4.3363e-01, 1.3374e+00,
1.0526e+00, -1.5137e+00, 1.5175e+00, 1.9621e+00, 5.9761e-02,
-6.6585e-01, -5.7567e-01, 1.8836e+00, -4.0392e-01, -1.6843e+00,
-6.2651e-02, -5.6046e-01, -2.2240e+00, 2.2366e+00, 7.4441e-01,
8.6763e-01, 2.7160e-02, 6.1294e-01, -1.3371e+00, -1.4363e+00,
3.7063e-01, 3.2868e-01, 1.4252e+00, 1.9193e+00, 4.6192e-01,
-6.5172e-01, 1.0483e+00, -2.0439e+00, -4.3385e-01, -4.1280e-01,
2.1602e+00, -3.0209e-01, -1.3598e+00, -9.2362e-01, 6.8442e-01,
1.5070e+00, -1.3786e+00, 2.0884e-01, -1.2019e+00, 1.5919e-01,
-5.1237e-01, -6.4286e-01, -5.4308e-01, -1.6116e-01, 3.2348e-01,
-9.9109e-01, 3.9232e-01, -5.3482e-01, -1.1709e+00, -7.9754e-01,
3.1636e-01, -1.1140e+00, -8.9587e-01, -1.0187e+00, 3.2878e-01,
-8.7841e-01, -1.0270e+00, -4.0047e-01, 9.1419e-01, -4.1257e-01,
-3.4638e-01, -1.0262e+00, 2.5113e+00, -1.7556e+00, -8.8387e-01,
1.2796e+00, 5.0357e-02, -8.0814e-02, 1.1154e+00, -4.7198e-01,
3.3411e-01, -3.3916e-01, 6.1048e-01, 1.1607e+00, -1.4151e-01,
1.1891e+00, -1.4370e-02, -4.7600e-01, 9.1887e-01, 1.1997e-01,
1.2459e+00, -1.6652e+00, 7.1569e-01, -9.8907e-02, 1.0175e+00,
-8.4436e-01, 1.2284e+00, -1.4385e-01, -1.1966e+00, 6.0931e-01,
-5.9934e-01, -2.0330e-01, 2.1682e-01, -1.5576e+00, -7.3551e-01,
8.3409e-02, -1.2729e+00, 3.7821e-01, -6.0331e-01, -5.4615e-01,
4.1880e-01, -1.2772e+00, 5.8423e-01, 6.4823e-01, -9.0434e-01,
-1.7290e-01, 1.0629e-01, -4.2987e-01, 2.7938e+00, 3.7136e-01,
-4.6170e-01, -4.2099e-01, 1.3325e-01, -3.4060e-01, 9.4501e-01,
4.2017e-01, -8.6550e-01, 7.3243e-02, 1.9991e-01, 8.7975e-01,
-1.3850e+00, -1.4591e+00, -2.0303e-01, -1.0262e+00, -9.9105e-01,
1.5270e+00, -7.0185e-02, -9.4997e-01, -1.0364e-01, -1.1562e+00,
8.3660e-01, -8.9686e-01, -1.1466e+00, -2.6425e+00, -1.1398e+00,
2.0886e-01, 7.1526e-01, -3.0458e-01, -7.2149e-01, -7.1204e-02,
-5.7747e-01, -5.9964e-01, -3.0580e-01, 1.9363e-01, -1.2074e+00,
-1.2838e-01, -2.5453e-01, -1.8684e-01, 1.4170e+00, -1.9690e+00,
1.9070e+00, 9.0862e-01, -3.0649e-01, 6.8586e-01, 1.9581e-01,
2.9553e-01, 2.0426e-02, -4.9903e-01, 7.2440e-01, -2.3443e+00,
5.7277e-01, 3.9135e-01, 6.4137e-01, 7.8038e-01, 3.2200e-01,
-7.8534e-01, 2.1155e+00, -2.3740e-01, -8.0613e-01, 7.0055e-01,
2.1210e-01, -1.0981e+00, 2.7863e-01, -4.4150e-02, 2.5538e-01,
8.5349e-01, 6.0327e-01, 7.5350e-01, 9.8726e-01, -1.6799e+00,
-1.6166e-01, 1.1238e-01, -7.9880e-02, -2.5753e-01, 1.1932e-01,
1.2365e+00, 8.5960e-02, 5.5032e-01, 7.7533e-01, -1.1113e-01,
2.1430e-01, -6.1338e-02, -1.9932e-01, -6.8872e-01, -2.0505e-01,
-1.8301e+00, -7.6811e-01, -7.4837e-02, -6.5999e-01, 3.8283e-01,
-4.1343e-01, -3.0340e-01, 1.2524e+00, -6.2270e-01, -4.8251e-01])
# Scatter plot of the second feature against the labels (marker size 1).
d2l.set_figsize()
# The trailing semicolon suppresses the returned object's repr in a notebook.
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1);
# 3.2.2 Read the dataset
def data_iter(batch_size, features, labels):
    """Yield shuffled minibatches of (features, labels).

    Args:
        batch_size: Maximum number of examples per batch.
        features: Tensor indexed by example along its first dimension.
        labels: Tensor of labels indexed the same way as ``features``.

    Yields:
        ``(features[batch], labels[batch])`` pairs. Every example appears
        exactly once per pass; the last batch may be smaller than
        ``batch_size``.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    # Examples are read in random order, reshuffled on every call.
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so no explicit min() is
        # needed for the final, possibly shorter, batch.
        batch_indices = torch.tensor(indices[i:i + batch_size])
        yield features[batch_indices], labels[batch_indices]
# Read and print only the first minibatch as a sanity check.
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
tensor([[-1.5650, 1.0609],
[-0.9305, -1.8794],
[-0.4093, -0.9996],
[ 0.7869, -0.9384],
[ 0.1738, -0.6689],
[ 1.6200, -0.0613],
[-0.5328, 0.2143],
[-0.2533, 2.0045],
[ 0.9031, -0.9043],
[ 1.3816, -1.5059]])
tensor([[-2.5447],
[ 8.7360],
[ 6.7868],
[ 8.9856],
[ 6.8253],
[ 7.6404],
[ 2.4112],
[-3.1235],
[ 9.0686],
[12.0905]])
# 3.2.3 Initialize model parameters
# Weights and bias start at zero; requires_grad=True makes autograd track
# them so that backward() fills in their .grad.
w = torch.zeros((2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b  # bare expression: displayed as the cell result in a notebook
(tensor([[0.],
[0.]], requires_grad=True),
tensor([0.], requires_grad=True))
# 3.2.4 Define the model
def linreg(X, w, b):
    """Linear regression model: return ``X @ w + b``.

    ``b`` is added with broadcasting to every row of the product.
    """
    return torch.matmul(X, w) + b
# 3.2.5 Define the loss function
def squared_loss(y_hat, y):
    """Squared loss, ``(y_hat - y)^2 / 2``, computed per element.

    ``y`` is reshaped to the shape of ``y_hat`` before subtracting. No
    reduction is applied: the result has the same shape as ``y_hat``, and
    the caller takes the sum or mean.
    """
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
# 3.2.6 Define the optimization algorithm
def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent step.

    Updates every parameter in place and then clears its gradient.

    Args:
        params: Iterable of tensors whose ``.grad`` has been populated.
        lr: Learning rate.
        batch_size: Divisor applied to the gradient; the training loop
            backpropagates a summed loss, so this turns it into an average.
    """
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            # Reset so the next backward() does not accumulate onto this one.
            param.grad.zero_()
# Training hyperparameters, plus aliases so the training loop can refer to
# the model and the loss by generic names.
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
# Training loop: one pass over the shuffled minibatches per epoch, then
# report the loss on the full dataset.
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # per-example loss on this minibatch
        # backward() needs a scalar, so sum the per-example losses first.
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # update parameters using their gradients
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        # NOTE(review): the literal text is kept as found; the recorded output
        # reads 'epoch1, loss…' with no space before the comma — confirm the
        # intended spacing.
        print(f'epoch{epoch + 1} , loss{float(train_l.mean()):f} ')
epoch1, loss0.029253
epoch2, loss0.000097
epoch3, loss0.000049
# Compare the learned parameters with the ground truth used to generate
# the data (the printed labels read "estimation error of w" / "of b").
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)} ')
print(f'b的估计误差:{true_b - b} ')
w的估计误差:tensor([ 0.0002, -0.0004], grad_fn=<SubBackward0>)
b的估计误差:tensor([0.0003], grad_fn=<RsubBackward1>)