Liniowa regresja w uczeniu maszynowym
- Poprzednia strona Dyskretna diagram
- Następna strona Regressja wielokrotna
Regresja
Gdy próbujesz znaleźć relacje między zmiennymi, używa się terminu „regresja” (regression).
W uczeniu maszynowym i modelowaniu statystycznym, ta relacja jest używana do przewidywania wyników przyszłych wydarzeń.
Regresja liniowa
Regresja liniowa rysuje prostą, łączącą wszystkie punkty danych na podstawie zależności między punktami.
Ta linia może być używana do przewidywania przyszłych wartości.

W uczeniu maszynowym przewidywanie przyszłości jest bardzo ważne.
Działanie
Python oferuje różne metody do wyszukiwania zależności między punktami danych i rysowania linii regresji. Pokażemy Ci, jak używać tych metod zamiast matematycznych wzorów.
W poniższym przykładzie, oś x reprezentuje wiek pojazdu, a oś y prędkość. Zarejestrowaliśmy wiek i prędkość 13 samochodów, które przeszły przez bramę płatną. Zobaczmy, czy zebrane dane mogą być użyte do regresji liniowej:
Przykład
Najpierw narysuj wykres punktów:
import matplotlib.pyplot as plt x = [5,7,8,7,2,17,2,9,4,11,12,9,6] y = [99,86,87,88,111,86,103,87,94,78,77,85,86] plt.scatter(x, y) plt.show()
Wynik:

Przykład
zaimportuj scipy
i narysuj linię regresji:
import matplotlib.pyplot as plt from scipy import stats x = [5,7,8,7,2,17,2,9,4,11,12,9,6] y = [99,86,87,88,111,86,103,87,94,78,77,85,86] slope, intercept, r, p, std_err = stats.linregress(x, y) def myfunc(x): return slope * x + intercept mymodel = list(map(myfunc, x)) plt.scatter(x, y) plt.plot(x, mymodel) plt.show()
Wynik:

Przykład wyjaśnienie
Zaimportuj wymagane moduły:
import matplotlib.pyplot as plt from scipy import stats
Stwórz tablicę reprezentującą wartości osi x i osi y:
x = [5,7,8,7,2,17,2,9,4,11,12,9,6] y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
Wykonaj metodę, która zwraca niektóre ważne kluczowe wartości regresji liniowej:
slope, intercept, r, p, std_err = stats.linregress(x, y)
Stwórz coś, co używa slope
i intercept
Funkcja wartości zwraca nowe wartości. Nowa wartość oznacza pozycję, w której odpowiednia wartość x zostanie umieszczona na osi y:
def myfunc(x): return slope * x + intercept
Przeprowadź funkcję dla każdego wartości z tablicy x. To wygeneruje nową tablicę, w której osa y ma nowe wartości:
mymodel = list(map(myfunc, x))
Narysuj oryginalny wykres punktów:
plt.scatter(x, y)
Narysuj linię regresji:
plt.plot(x, mymodel)
Wyświetl wykres:
plt.show()
R-Squared
Ważne jest, aby wiedzieć, jak dobrze są ze sobą powiązane wartości osi x i osi y, ponieważ jeśli nie ma związku, regresja liniowa nie może być użyta do przewidywania niczego.
Relacja ta jest mierzona wartością nazywaną r kwadrat (r-squared).
Zakres wartości r kwadratowego wynosi od 0 do 1, gdzie 0 oznacza brak korelacji, a 1 oznacza 100% korelacji.
Moduł Python i Scipy obliczą dla Ciebie tę wartość, musisz tylko dostarczyć mu wartości x i y:
Przykład
Jakie jest dopasowanie moich danych w regresji liniowej?
from scipy import stats x = [5,7,8,7,2,17,2,9,4,11,12,9,6] y = [99,86,87,88,111,86,103,87,94,78,77,85,86] slope, intercept, r, p, std_err = stats.linregress(x, y) print(r)
Komentarz:Wynik -0.76 wskazuje na pewne powiązanie, ale nie idealne, ale wskazuje, że możemy użyć regresji liniowej do przewidywań w przyszłości.
Przewidywanie przyszłych wartości
Teraz możemy użyć zebranych informacji do przewidywania przyszłych wartości.
Na przykład: spróbujmy przewidzieć prędkość samochodu o 10-letnim przebiegu.
Dla tego potrzebujemy tego samego co w poprzednim przykładzie myfunc()
Funkcja:
def myfunc(x): return slope * x + intercept
Przykład
Przewiduj prędkość samochodu o 10-letnim przebiegu:
from scipy import stats x = [5,7,8,7,2,17,2,9,4,11,12,9,6] y = [99,86,87,88,111,86,103,87,94,78,77,85,86] slope, intercept, r, p, std_err = stats.linregress(x, y) def myfunc(x): return slope * x + intercept speed = myfunc(10) print(speed)
Przewidywana prędkość wynosi 85.6, możemy to również odczytać z wykresu:

Zła dopasowanie?
Pozwólmy stworzyć przykład, w którym regresja liniowa nie jest najlepszym sposobem przewidywania przyszłych wartości.
Przykład
Te wartości osi x i y spowodują bardzo słabe dopasowanie regresji liniowej:
import matplotlib.pyplot as plt from scipy import stats x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40] y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15] slope, intercept, r, p, std_err = stats.linregress(x, y) def myfunc(x): return slope * x + intercept mymodel = list(map(myfunc, x)) plt.scatter(x, y) plt.plot(x, mymodel) plt.show()
Wynik:

i wartość r-squared?
Przykład
Powinieneś otrzymać bardzo niską wartość r-squared.
import numpy from scipy import stats x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40] y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15] slope, intercept, r, p, std_err = stats.linregress(x, y) print(r)
Wynik: 0,013 oznacza bardzo słabe powiązanie, i mówi nam, że zestaw danych nie jest odpowiedni do regresji liniowej.
- Poprzednia strona Dyskretna diagram
- Następna strona Regressja wielokrotna