機器學習 - 多項式回歸
多項式回歸(Polynomial Regression)
如果您的數據點顯然不適合線性回歸(穿過數據點之間的直線),那么多項式回歸可能是理想的選擇。
像線性回歸一樣,多項式回歸使用變量 x 和 y 之間的關系來找到繪制數據點線的最佳方法。

工作原理
Python 有一些方法可以找到數據點之間的關系并畫出多項式回歸線。我們將向您展示如何使用這些方法而不是通過數學公式。
在下面的例子中,我們注冊了 18 輛經過特定收費站的汽車。
我們已經記錄了汽車的速度和通過時間(小時)。
x 軸表示一天中的小時,y 軸表示速度:
實例
首先繪制散點圖:
import matplotlib.pyplot as plt x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22] y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100] plt.scatter(x, y) plt.show()
結果:

實例
導入 numpy
和 matplotlib
,然后畫出多項式回歸線:
import numpy import matplotlib.pyplot as plt x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22] y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100] mymodel = numpy.poly1d(numpy.polyfit(x, y, 3)) myline = numpy.linspace(1, 22, 100) plt.scatter(x, y) plt.plot(myline, mymodel(myline)) plt.show()
結果:

例子解釋
導入所需模塊:
import numpy import matplotlib.pyplot as plt
創建表示 x 和 y 軸值的數組:
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22] y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
NumPy 有一種方法可以讓我們建立多項式模型:
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
然后指定行的顯示方式,我們從位置 1 開始,到位置 22 結束:
myline = numpy.linspace(1, 22, 100)
繪制原始散點圖:
plt.scatter(x, y)
畫出多項式回歸線:
plt.plot(myline, mymodel(myline))
顯示圖表:
plt.show()
R-Squared
重要的是要知道 x 軸和 y 軸的值之間的關系有多好,如果沒有關系,則多項式回歸不能用于預測任何東西。
該關系用一個稱為 r 平方( r-squared)的值來度量。
r 平方值的范圍是 0 到 1,其中 0 表示不相關,而 1 表示 100% 相關。
Python 和 Sklearn 模塊將為您計算該值,您所要做的就是將 x 和 y 數組輸入:
實例
我的數據在多項式回歸中的擬合度如何?
import numpy from sklearn.metrics import r2_score x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22] y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100] mymodel = numpy.poly1d(numpy.polyfit(x, y, 3)) print(r2_score(y, mymodel(x)))
注釋:結果 0.94 表明存在很好的關系,我們可以在將來的預測中使用多項式回歸。
預測未來值
現在,我們可以使用收集到的信息來預測未來的值。
例如:讓我們嘗試預測在晚上 17 點左右通過收費站的汽車的速度:
為此,我們需要與上面的實例相同的 mymodel 數組:
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
實例
預測下午 17 點過車的速度:
import numpy from sklearn.metrics import r2_score x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22] y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100] mymodel = numpy.poly1d(numpy.polyfit(x, y, 3)) speed = mymodel(17) print(speed)
該例預測速度為 88.87,我們也可以在圖中看到:

糟糕的擬合度?
讓我們創建一個實例,其中多項式回歸不是預測未來值的最佳方法。
實例
x 和 y 軸的這些值會導致多項式回歸的擬合度非常差:
import numpy import matplotlib.pyplot as plt x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40] y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15] mymodel = numpy.poly1d(numpy.polyfit(x, y, 3)) myline = numpy.linspace(2, 95, 100) plt.scatter(x, y) plt.plot(myline, mymodel(myline)) plt.show()
結果:

r-squared 值呢?
實例
您應該得到一個非常低的 r-squared 值。
import numpy from sklearn.metrics import r2_score x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40] y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15] mymodel = numpy.poly1d(numpy.polyfit(x, y, 3)) print(r2_score(y, mymodel(x)))
結果:0.00995 表示關系很差,并告訴我們該數據集不適合多項式回歸。