機器學習 - 線性回歸

回歸

當您嘗試找到變量之間的關系時,會用到術語“回歸”(regression)。

在機器學習和統計建模中,這種關系用于預測未來事件的結果。

線性回歸

線性回歸使用數據點之間的關系在所有數據點之間畫一條直線。

這條線可以用來預測未來的值。


在機器學習中,預測未來非常重要。

工作原理

Python 提供了一些方法來查找數據點之間的關系并繪制線性回歸線。我們將向您展示如何使用這些方法而不是通過數學公式。

在下面的示例中,x 軸表示車齡,y 軸表示速度。我們已經記錄了 13 輛汽車通過收費站時的車齡和速度。讓我們看看我們收集的數據是否可以用于線性回歸:

實例

首先繪制散點圖:

import matplotlib.pyplot as plt
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
plt.scatter(x, y)
plt.show()

結果:


運行實例

實例

導入 scipy 并繪制線性回歸線:

import matplotlib.pyplot as plt
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

結果:


運行實例

例子解釋

導入所需模塊:

import matplotlib.pyplot as plt
from scipy import stats

創建表示 x 和 y 軸值的數組:

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

執行一個方法,該方法返回線性回歸的一些重要鍵值:

slope, intercept, r, p, std_err = stats.linregress(x, y)

創建一個使用 slopeintercept 值的函數返回新值。這個新值表示相應的 x 值將在 y 軸上放置的位置:

def myfunc(x):
  return slope * x + intercept

通過函數運行 x 數組的每個值。這將產生一個新的數組,其中的 y 軸具有新值:

mymodel = list(map(myfunc, x))

繪制原始散點圖:

plt.scatter(x, y)

繪制線性回歸線:

plt.plot(x, mymodel)

顯示圖:

plt.show()

R-Squared

重要的是要知道 x 軸的值和 y 軸的值之間的關系有多好,如果沒有關系,則線性回歸不能用于預測任何東西。

該關系用一個稱為 r 平方(r-squared)的值來度量。

r 平方值的范圍是 0 到 1,其中 0 表示不相關,而 1 表示 100% 相關。

Python 和 Scipy 模塊將為您計算該值,您所要做的就是將 x 和 y 值提供給它:

實例

我的數據在線性回歸中的擬合度如何?

from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)

運行實例

注釋:結果 -0.76 表明存在某種關系,但不是完美的關系,但它表明我們可以在將來的預測中使用線性回歸。

預測未來價值

現在,我們可以使用收集到的信息來預測未來的值。

例如:讓我們嘗試預測一輛擁有 10 年歷史的汽車的速度。

為此,我們需要與上例中相同的 myfunc() 函數:

def myfunc(x):
  return slope * x + intercept

實例

預測一輛有 10年車齡的汽車的速度:

from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
speed = myfunc(10)
print(speed)

運行實例

該例預測速度為 85.6,我們也可以從圖中讀取:


糟糕的擬合度?

讓我們創建一個實例,其中的線性回歸并不是預測未來值的最佳方法。

實例

x 和 y 軸的這些值將導致線性回歸的擬合度非常差:

import matplotlib.pyplot as plt
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

結果:


運行實例

以及 r-squared 值?

實例

您應該得到了一個非常低的 r-squared 值。

import numpy
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)

運行實例

結果:0.013 表示關系很差,并告訴我們該數據集不適合線性回歸。