Liniowa regresja w uczeniu maszynowym

Regresja

Gdy próbujesz znaleźć relacje między zmiennymi, używa się terminu „regresja” (regression).

W uczeniu maszynowym i modelowaniu statystycznym, ta relacja jest używana do przewidywania wyników przyszłych wydarzeń.

Regresja liniowa

Regresja liniowa rysuje prostą, łączącą wszystkie punkty danych na podstawie zależności między punktami.

Ta linia może być używana do przewidywania przyszłych wartości.


W uczeniu maszynowym przewidywanie przyszłości jest bardzo ważne.

Działanie

Python oferuje różne metody do wyszukiwania zależności między punktami danych i rysowania linii regresji. Pokażemy Ci, jak używać tych metod zamiast matematycznych wzorów.

W poniższym przykładzie, oś x reprezentuje wiek pojazdu, a oś y prędkość. Zarejestrowaliśmy wiek i prędkość 13 samochodów, które przeszły przez bramę płatną. Zobaczmy, czy zebrane dane mogą być użyte do regresji liniowej:

Przykład

Najpierw narysuj wykres punktów:

import matplotlib.pyplot as plt
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
plt.scatter(x, y)
plt.show()

Wynik:


Uruchom przykład

Przykład

zaimportuj scipy i narysuj linię regresji:

import matplotlib.pyplot as plt
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

Wynik:


Uruchom przykład

Przykład wyjaśnienie

Zaimportuj wymagane moduły:

import matplotlib.pyplot as plt
from scipy import stats

Stwórz tablicę reprezentującą wartości osi x i osi y:

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

Wykonaj metodę, która zwraca niektóre ważne kluczowe wartości regresji liniowej:

slope, intercept, r, p, std_err = stats.linregress(x, y)

Stwórz coś, co używa slope i intercept Funkcja wartości zwraca nowe wartości. Nowa wartość oznacza pozycję, w której odpowiednia wartość x zostanie umieszczona na osi y:

def myfunc(x):
  return slope * x + intercept

Przeprowadź funkcję dla każdego wartości z tablicy x. To wygeneruje nową tablicę, w której osa y ma nowe wartości:

mymodel = list(map(myfunc, x))

Narysuj oryginalny wykres punktów:

plt.scatter(x, y)

Narysuj linię regresji:

plt.plot(x, mymodel)

Wyświetl wykres:

plt.show()

R-Squared

Ważne jest, aby wiedzieć, jak dobrze są ze sobą powiązane wartości osi x i osi y, ponieważ jeśli nie ma związku, regresja liniowa nie może być użyta do przewidywania niczego.

Relacja ta jest mierzona wartością nazywaną r kwadrat (r-squared).

Zakres wartości r kwadratowego wynosi od 0 do 1, gdzie 0 oznacza brak korelacji, a 1 oznacza 100% korelacji.

Moduł Python i Scipy obliczą dla Ciebie tę wartość, musisz tylko dostarczyć mu wartości x i y:

Przykład

Jakie jest dopasowanie moich danych w regresji liniowej?

from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)

Uruchom przykład

Komentarz:Wynik -0.76 wskazuje na pewne powiązanie, ale nie idealne, ale wskazuje, że możemy użyć regresji liniowej do przewidywań w przyszłości.

Przewidywanie przyszłych wartości

Teraz możemy użyć zebranych informacji do przewidywania przyszłych wartości.

Na przykład: spróbujmy przewidzieć prędkość samochodu o 10-letnim przebiegu.

Dla tego potrzebujemy tego samego co w poprzednim przykładzie myfunc() Funkcja:

def myfunc(x):
  return slope * x + intercept

Przykład

Przewiduj prędkość samochodu o 10-letnim przebiegu:

from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
speed = myfunc(10)
print(speed)

Uruchom przykład

Przewidywana prędkość wynosi 85.6, możemy to również odczytać z wykresu:


Zła dopasowanie?

Pozwólmy stworzyć przykład, w którym regresja liniowa nie jest najlepszym sposobem przewidywania przyszłych wartości.

Przykład

Te wartości osi x i y spowodują bardzo słabe dopasowanie regresji liniowej:

import matplotlib.pyplot as plt
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(x):
  return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

Wynik:


Uruchom przykład

i wartość r-squared?

Przykład

Powinieneś otrzymać bardzo niską wartość r-squared.

import numpy
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)

Uruchom przykład

Wynik: 0,013 oznacza bardzo słabe powiązanie, i mówi nam, że zestaw danych nie jest odpowiedni do regresji liniowej.