1. Principles and Derivation of the Linear Regression Model
Given a dataset of inputs $\boldsymbol{x}$ and outputs $y$,

D = \{ (x_1,y_1), (x_2,y_2), \ldots, (x_m,y_m) \}, \quad \text{where } x_i = (x_{i1};x_{i2};\ldots;x_{id}),\; y_i \in \mathbb{R},

linear regression learns a linear model that fits the output $y$ from the input $\boldsymbol{x}$ as closely as possible.
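For reference, the linear model itself takes the familiar form (stated here because the derivation below refers to its parameters $\boldsymbol{w}$ and $b$):

f(\boldsymbol{x}_i) = \boldsymbol{w}^{\mathrm{T}}\boldsymbol{x}_i + b, \quad \text{such that } f(\boldsymbol{x}_i) \simeq y_i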
The key problem in learning a linear regression model is determining the parameters $w$ and $b$ so that the fitted output $\hat{y}$ is as close as possible to the true output $y$. In regression tasks the mean squared error is the usual measure of the discrepancy between predictions and labels, so the optimization objective is to minimize the mean squared error between the fitted and true outputs:
\begin{align*} (w^*,b^*)&=\arg\min_{(w,b)}\sum_{i = 1}^{m}(f(x_i) - y_i)^2\\ &=\arg\min_{(w,b)}\sum_{i = 1}^{m}(wx_i + b - y_i)^2 \end{align*}
To find the minimizers $w^*$ and $b^*$, we take the first-order partial derivatives of the squared loss $L(w,b)=\sum_{i=1}^{m}(wx_i + b - y_i)^2$ with respect to $w$ and $b$ and set them to zero. The derivation for $w$ is:
\begin{align*} \frac{\partial L(w,b)}{\partial w} &= \frac{\partial }{\partial w}\left[\sum_{i = 1}^{m}(wx_{i}+b - y_{i})^{2}\right] \\ &= \sum_{i = 1}^{m}\frac{\partial }{\partial w}[(y_{i}-wx_{i}-b)^{2}] \\ &= \sum_{i = 1}^{m}[2\cdot(y_{i}-wx_{i}-b)\cdot(-x_{i})] \\ &= \sum_{i = 1}^{m}[2\cdot(wx_{i}^{2}-y_{i}x_{i}+bx_{i})] \\ &= 2\cdot\left(w\sum_{i = 1}^{m}x_{i}^{2}-\sum_{i = 1}^{m}y_{i}x_{i}+b\sum_{i = 1}^{m}x_{i}\right) \end{align*}
Similarly, differentiating with respect to $b$:

\begin{align*} \frac{\partial L(w,b)}{\partial b} &= \frac{\partial }{\partial b}\left[\sum_{i = 1}^{m}(wx_{i}+b - y_{i})^{2}\right] \\ &= \sum_{i = 1}^{m}\frac{\partial }{\partial b}[(y_{i}-wx_{i}-b)^{2}] \\ &= \sum_{i = 1}^{m}[2\cdot(y_{i}-wx_{i}-b)\cdot(-1)] \\ &= \sum_{i = 1}^{m}[2\cdot(-y_{i}+wx_{i}+b)] \\ &= 2\cdot\left(-\sum_{i = 1}^{m}y_{i}+\sum_{i = 1}^{m}wx_{i}+\sum_{i = 1}^{m}b\right) \\ &= 2\left(mb - \sum_{i = 1}^{m}(y_{i}-wx_{i})\right) \end{align*}
Setting both derivatives to zero and solving yields the optimal $w$ and $b$, where $\bar{x}=\frac{1}{m}\sum_{i=1}^{m}x_i$ denotes the mean of the inputs:
w^{*}=\frac{\sum_{i = 1}^{m}y_{i}(x_{i}-\bar{x})}{\sum_{i = 1}^{m}x_{i}^{2}-\frac{1}{m}(\sum_{i = 1}^{m}x_{i})^{2}}
b^{*}=\frac{1}{m}\sum_{i = 1}^{m}(y_{i}-wx_{i})
This approach of estimating the linear regression parameters by minimizing the mean squared error is the well-known least squares method, also called ordinary least squares (OLS) in statistical analysis.
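As a quick illustration of the closed-form expressions above, here is a minimal NumPy sketch for one-dimensional data (the toy values and variable names are my own, purely for illustration):

import numpy as np

# Toy one-dimensional data (made-up values)
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 3.9, 6.2, 8.1, 9.8])

m = len(x)
x_bar = x.mean()

# w* = sum(y_i * (x_i - x_bar)) / (sum(x_i^2) - (sum(x_i))^2 / m)
w_star = np.sum(y * (x - x_bar)) / (np.sum(x ** 2) - np.sum(x) ** 2 / m)
# b* = (1/m) * sum(y_i - w* x_i)
b_star = np.mean(y - w_star * x)

print(w_star, b_star)  # slope and intercept recovered from the data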
2. Extending the Derivation to Multivariate Regression
We now put the derivation in matrix form to handle multivariate linear regression. A multivariate problem is one whose input has $d$ variables; in the earlier salary example these included city, education, age, and experience. To simplify the matrix form of the least squares derivation, the parameters $w$ and $b$ can be combined into a single vector $\widehat{\boldsymbol{w}} = (\boldsymbol{w}; b)$. The input part of the training set $D$ can then be written as an $m \times (d+1)$ matrix $\boldsymbol{X}$, where $d$ is the number of input variables. The matrix $\boldsymbol{X}$ and the label vector $\boldsymbol{y}$ are:
\boldsymbol{X} = \begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1d} & 1 \\ x_{21} & x_{22} & \cdots & x_{2d} & 1 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ x_{m1} & x_{m2} & \cdots & x_{md} & 1 \end{pmatrix} = \begin{pmatrix} \boldsymbol{x}_{1}^{\mathrm{T}} & 1 \\ \vdots & \vdots \\ \boldsymbol{x}_{m}^{\mathrm{T}} & 1 \end{pmatrix} \quad \boldsymbol{y} = (y_{1}; y_{2}; \cdots ; y_{m})
The matrix form of the parameter optimization objective is:
\widehat{\boldsymbol{w}}^{*} = \arg\min_{\widehat{\boldsymbol{w}}} (\boldsymbol{y} - \boldsymbol{X}\widehat{\boldsymbol{w}})^{\mathrm{T}}(\boldsymbol{y} - \boldsymbol{X}\widehat{\boldsymbol{w}})
Let $L = (\boldsymbol{y} - \boldsymbol{X}\widehat{\boldsymbol{w}})^{\mathrm{T}}(\boldsymbol{y} - \boldsymbol{X}\widehat{\boldsymbol{w}})$ and differentiate with respect to $\widehat{\boldsymbol{w}}$. Expanding $L$ gives:
L = \boldsymbol{y}^{\mathrm{T}}\boldsymbol{y} - \boldsymbol{y}^{\mathrm{T}}\boldsymbol{X}\widehat{\boldsymbol{w}} - \widehat{\boldsymbol{w}}^{\mathrm{T}}\boldsymbol{X}^{\mathrm{T}}\boldsymbol{y} + \widehat{\boldsymbol{w}}^{\mathrm{T}}\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X}\widehat{\boldsymbol{w}}
By the matrix differentiation identities

\frac{\partial \boldsymbol{a}^{\mathrm{T}}\boldsymbol{x}}{\partial \boldsymbol{x}} = \frac{\partial \boldsymbol{x}^{\mathrm{T}}\boldsymbol{a}}{\partial \boldsymbol{x}} = \boldsymbol{a} \quad \frac{\partial \boldsymbol{x}^{\mathrm{T}}\boldsymbol{A}\boldsymbol{x}}{\partial \boldsymbol{x}} = (\boldsymbol{A} + \boldsymbol{A}^{\mathrm{T}})\boldsymbol{x}
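These identities are easy to sanity-check numerically; the short sketch below (a self-contained check of my own, not part of the original derivation) compares the analytic gradient of $\boldsymbol{x}^{\mathrm{T}}\boldsymbol{A}\boldsymbol{x}$ with a finite-difference estimate:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
x = rng.normal(size=4)

# Analytic gradient from the identity d(x^T A x)/dx = (A + A^T) x
analytic = (A + A.T) @ x

# Central finite-difference approximation of the same gradient
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(len(x)):
    e = np.zeros_like(x)
    e[i] = eps
    numeric[i] = ((x + e) @ A @ (x + e) - (x - e) @ A @ (x - e)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-4))  # expected: True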
applying these identities to $L$, we obtain:
\frac{\partial L}{\partial \widehat{\boldsymbol{w}}} = 0 - \boldsymbol{X}^{\mathrm{T}}\boldsymbol{y} - \boldsymbol{X}^{\mathrm{T}}\boldsymbol{y} + (\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X} + \boldsymbol{X}^{\mathrm{T}}\boldsymbol{X})\widehat{\boldsymbol{w}} = 2\boldsymbol{X}^{\mathrm{T}}(\boldsymbol{X}\widehat{\boldsymbol{w}} - \boldsymbol{y})
When the matrix $(\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X})$ is full rank or positive definite, setting the derivative above to zero solves for the parameters:
\widehat{\boldsymbol{w}}^{*} = (\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X})^{-1}\boldsymbol{X}^{\mathrm{T}}\boldsymbol{y}
The resulting multivariate linear regression model is:
f(\widehat{\boldsymbol{x}}_{i}) = \widehat{\boldsymbol{x}}_{i}^{\mathrm{T}}(\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X})^{-1}\boldsymbol{X}^{\mathrm{T}}\boldsymbol{y}, \quad \text{where } \widehat{\boldsymbol{x}}_{i} = (\boldsymbol{x}_{i}; 1)
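To make the normal-equation solution concrete, here is a minimal NumPy sketch (synthetic data and names of my own choosing, not the Boston example used below) that appends a column of ones to the inputs and solves for the parameter vector:

import numpy as np

rng = np.random.default_rng(42)
m, d = 100, 3
X_raw = rng.normal(size=(m, d))                           # m samples, d input variables
true_w = np.array([1.5, -2.0, 0.7])
y = X_raw @ true_w + 3.0 + rng.normal(scale=0.1, size=m)  # linear data plus a little noise

# Append a column of ones so the bias b is absorbed into w_hat = (w; b)
X = np.hstack([X_raw, np.ones((m, 1))])

# Normal equation (X^T X)^{-1} X^T y; np.linalg.solve avoids forming the inverse explicitly
w_hat = np.linalg.solve(X.T @ X, X.T @ y)
print(w_hat)  # approximately [1.5, -2.0, 0.7, 3.0]

If $\boldsymbol{X}^{\mathrm{T}}\boldsymbol{X}$ is singular or ill-conditioned, np.linalg.lstsq or the pseudo-inverse np.linalg.pinv can be used in place of the explicit solve.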
3. Evaluating the Performance of Linear Regression Models
In terms of the three elements of machine learning, linear regression can be viewed as a method whose model is a linear function, whose strategy is minimizing the mean squared loss, and whose algorithm is least squares.
Common performance metrics for linear regression models include the mean absolute error (MAE), the mean squared error and root mean squared error (MSE/RMSE), and the coefficient of determination $R^2$:
\begin{align*} \mathrm{MAE}(y, \hat{y}) &= \frac{1}{n_{\text{samples}}} \sum_{i = 1}^{n_{\text{samples}}} |y_i - \hat{y}_i| \\ \mathrm{MSE}(y, \hat{y}) &= \frac{1}{n_{\text{samples}}} \sum_{i = 1}^{n_{\text{samples}}} (y_i - \hat{y}_i)^2 \\ \bar{y} &= \frac{1}{n} \sum_{i = 1}^{n} y_i \\ SS_{tot} &= \sum_{i = 1}^{n} (y_i - \bar{y})^2 \\ SS_{reg} &= \sum_{i = 1}^{n} (f_i - \bar{y})^2 \\ SS_{res} &= \sum_{i = 1}^{n} (y_i - f_i)^2 = \sum_{i = 1}^{n} e_i^2 \\ R^2 &= 1 - \frac{SS_{res}}{SS_{tot}} \end{align*}
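As a compact check of these definitions, the metrics can be implemented directly with NumPy (a minimal sketch with made-up y_true/y_pred values, mirroring the equations above):

import numpy as np

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.1, 7.8])

mae = np.mean(np.abs(y_true - y_pred))            # MAE
mse = np.mean((y_true - y_pred) ** 2)             # MSE
rmse = np.sqrt(mse)                               # RMSE
ss_res = np.sum((y_true - y_pred) ** 2)           # SS_res
ss_tot = np.sum((y_true - y_true.mean()) ** 2)    # SS_tot
r2 = 1 - ss_res / ss_tot                          # R^2
print(mae, mse, rmse, r2)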
4. Hands-On Code
Linear regression
from sklearn.model_selection import train_test_split
import pandas as pd
import warnings
import numpy as np
warnings.filterwarnings("ignore")
The Boston housing dataset has the following 13 attributes:
CRIM: per-capita crime rate by town
ZN: proportion of residential land zoned for large lots
INDUS: proportion of non-retail business acres per town
CHAS: Charles River dummy variable (1 if the tract bounds the river, 0 otherwise)
NOX: nitric oxide concentration
RM: average number of rooms per dwelling
AGE: proportion of owner-occupied units built before 1940
DIS: weighted distances to five Boston employment centers
RAD: index of accessibility to radial highways
TAX: full-value property tax rate per $10,000
PTRATIO: pupil-teacher ratio by town
B: proportion of Black residents by town
LSTAT: percentage of lower-status population
The prediction target MEDV is the median home value in thousands of dollars (stored below in the price column).
# Load the Boston housing dataset
feature_name = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B','LSTAT']
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
# 数据准备
X = data
y = target
#dataset = load_boston()
df = pd.DataFrame(X,columns=feature_name)
df['price'] = y
df.head()
| | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | price |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.00632 | 18.0 | 2.31 | 0.0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1.0 | 296.0 | 15.3 | 396.90 | 4.98 | 24.0 |
| 1 | 0.02731 | 0.0 | 7.07 | 0.0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2.0 | 242.0 | 17.8 | 396.90 | 9.14 | 21.6 |
| 2 | 0.02729 | 0.0 | 7.07 | 0.0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2.0 | 242.0 | 17.8 | 392.83 | 4.03 | 34.7 |
| 3 | 0.03237 | 0.0 | 2.18 | 0.0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3.0 | 222.0 | 18.7 | 394.63 | 2.94 | 33.4 |
| 4 | 0.06905 | 0.0 | 2.18 | 0.0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3.0 | 222.0 | 18.7 | 396.90 | 5.33 | 36.2 |
X = df['RM'].values
y = df['price'].values
X[:5], y[:5]
(array([6.575, 6.421, 7.185, 6.998, 7.147]), array([24. , 21.6, 34.7, 33.4, 36.2]))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, test_size=0.2, shuffle=True)
X_train =X_train.reshape((-1,1))
X_test = X_test.reshape((-1,1))
y_train = y_train.reshape((-1,1))
y_test = y_test.reshape((-1,1))
X_train.shape, y_train.shape
((404, 1), (404, 1))
### Initialize the model parameters
def initialize_params(dims):
    '''
    Input:
        dims: dimensionality of the training data
    Output:
        w: initialized weight parameters
        b: initialized bias parameter
    '''
    # Initialize the weights as a zero matrix
    w = np.zeros((dims, 1))
    # Initialize the bias to zero
    b = 0
    print('w.shape', w.shape)
    return w, b
### Define the core of the model:
### the linear regression formula, the mean squared loss, and the parameter gradients
def linear_loss(X, y, w, b):
    '''
    Input:
        X: input feature matrix
        y: output label vector
        w: weight matrix
        b: bias term
    Output:
        y_hat: model predictions
        loss: mean squared loss
        dw: first-order gradient of the loss w.r.t. the weights
        db: first-order gradient of the loss w.r.t. the bias
    '''
    # Number of training samples
    num_train = X.shape[0]
    # Number of features
    num_feature = X.shape[1]
    # Linear regression predictions
    y_hat = np.dot(X, w) + b
    # Mean squared loss between predictions and labels
    loss = np.sum((y_hat - y) ** 2) / num_train
    # Gradient of the loss w.r.t. the weights
    dw = np.dot(X.T, (y_hat - y)) / num_train
    # Gradient of the loss w.r.t. the bias
    db = np.sum((y_hat - y)) / num_train
    return y_hat, loss, dw, db
### Define the training procedure for the linear regression model
def linear_train(X, y, learning_rate=0.01, epochs=10000):
    '''
    Input:
        X: input feature matrix
        y: output label vector
        learning_rate: learning rate
        epochs: number of training iterations
    Output:
        loss_his: mean squared loss at each iteration
        params: dictionary of optimized parameters
        grads: dictionary of parameter gradients
    '''
    # List for recording the training loss
    loss_his = []
    # Initialize the model parameters
    w, b = initialize_params(X.shape[1])
    # Iterative training
    for i in range(1, epochs):
        # Compute the predictions, loss, and gradients for this iteration
        y_hat, loss, dw, db = linear_loss(X, y, w, b)
        # Gradient descent parameter update
        w += -learning_rate * dw
        b += -learning_rate * db
        # Record the current loss
        loss_his.append(loss)
        # Print the loss every 1000 iterations
        if i % 1000 == 0:
            print('epoch %d loss %f' % (i, loss))
        # Save the current parameters to a dictionary
        params = {'w': w, 'b': b}
        # Save the current gradients to a dictionary
        grads = {'dw': dw, 'db': db}
    return loss_his, params, grads
linear_train(X_train, y_train)
w.shape (1, 1)
epoch 1000 loss 54.673255
epoch 2000 loss 52.544814
epoch 3000 loss 50.857963
epoch 4000 loss 49.521084
epoch 5000 loss 48.461568
epoch 6000 loss 47.621871
epoch 7000 loss 46.956387
epoch 8000 loss 46.428972
epoch 9000 loss 46.010979
([588.0343564356435, 243.58805442113618, 122.70684288122406, 80.28309345783622, 65.39304088550571, 60.16560630941734, 58.329150981462085, 57.68271921745599, ...],
 {'w': array([[7.13563861]]), 'b': -22.08974647222114},
 {'dw': array([[-0.01889231]]), 'db': 0.11982501271954342})
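To actually use the parameters learned by gradient descent, the call above can be rerun with an assignment and the learned weight and bias applied to the test set. This is a minimal sketch of my own, not part of the original notebook:

loss_his, params, grads = linear_train(X_train, y_train)
# Predict on the held-out test set with the learned weight and bias
y_pred_gd = np.dot(X_test, params['w']) + params['b']
# R^2 of the gradient-descent model on the test data
from sklearn.metrics import r2_score
print(r2_score(y_test, y_pred_gd))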
from sklearn.linear_model import LinearRegression as LR
lr = LR()
model = lr.fit(X_train, y_train)
y_pre = model.predict(X_test)
model.coef_
array([[8.76050748]])
model.intercept_
array([-32.39552265])
from sklearn.metrics import r2_score
# Positional arguments follow the (y_true, y_pred) order
r2_score(y_test, y_pre)
# Or pass the arguments by keyword
r2_score(y_true=y_test, y_pred=y_pre)
0.5877214395051775
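The remaining metrics from Section 3 can be computed for the scikit-learn model in the same way (a short sketch reusing the y_pre predictions above):

from sklearn.metrics import mean_absolute_error, mean_squared_error

mae = mean_absolute_error(y_test, y_pre)
mse = mean_squared_error(y_test, y_pre)
rmse = np.sqrt(mse)  # RMSE is the square root of MSE
print(mae, mse, rmse)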