Học máy - Cây quyết định
- Trang Trước Huấn Luyện/Kiểm Tra
- Trang Tiếp Theo MySQL Cơ Bản

Cây quyết định (Decision Tree)
Trong chương này, chúng ta sẽ hướng dẫn bạn cách tạo ra 'Cây quyết định'. Cây quyết định là một quy trình đồ họa giúp bạn ra quyết định dựa trên kinh nghiệm trước đây.
Trong ví dụ này, một người sẽ thử quyết định liệu họ có nên tham gia chương trình hài kịch hay không.
May mắn rằng, trong ví dụ của chúng ta, nhân vật mỗi lần tổ chức chương trình hài kịch ở thị trấn đều đăng ký, và đăng ký một số thông tin về diễn viên hài, và còn ghi lại xem họ có đi qua hay không.
Tuổi | Kinh nghiệm | Rank | Nationality | Go |
---|---|---|---|---|
36 | 10 | 9 | UK | NO |
42 | 12 | 4 | USA | NO |
23 | 4 | 6 | N | NO |
52 | 4 | 4 | USA | NO |
43 | 21 | 8 | USA | YES |
44 | 14 | 5 | UK | NO |
66 | 3 | 7 | N | YES |
35 | 14 | 9 | UK | YES |
52 | 13 | 7 | N | YES |
35 | 5 | 9 | N | YES |
24 | 3 | 5 | USA | NO |
18 | 3 | 7 | UK | YES |
45 | 9 | 9 | UK | YES |
Hiện tại, dựa trên bộ dữ liệu này, Python có thể tạo ra cây quyết định, cây quyết định này có thể được sử dụng để quyết định liệu có đáng tham gia bất kỳ buổi biểu diễn mới nào hay không.
Cách hoạt động
Trước tiên, nhập các mô-đun cần thiết và sử dụng pandas để đọc bộ dữ liệu:
Ví dụ
Đọc và in bộ dữ liệu:
import pandas from sklearn import tree import pydotplus from sklearn.tree import DecisionTreeClassifier import matplotlib.pyplot as plt import matplotlib.image as pltimg df = pandas.read_csv("shows.csv") print(df)
Để tạo ra cây quyết định, tất cả dữ liệu phải là số.
Chúng ta phải chuyển đổi các cột không phải số 'Nationality' và 'Go' thành số.
Pandas có một map()
Phương pháp, phương pháp này chấp nhận từ điển chứa thông tin về cách chuyển đổi giá trị.
{'UK': 0, 'USA': 1, 'N': 2}
Điều này có nghĩa là chuyển đổi giá trị 'UK' thành 0, chuyển đổi 'USA' thành 1, chuyển đổi 'N' thành 2.
Ví dụ
Chuyển đổi giá trị chuỗi thành số:
d = {'UK': 0, 'USA': 1, 'N': 2} df['Nationality'] = df['Nationality'].map(d) d = {'YES': 1, 'NO': 0} df['Go'] = df['Go'].map(d) print(df)
Sau đó, chúng ta phải tách riêng cột đặc trưng và cột mục tiêu.
Cột đặc trưng là cột mà chúng ta thử dự đoán, cột mục tiêu là cột có giá trị mà chúng ta thử dự đoán.
Ví dụ
X là cột đặc trưng, y là cột mục tiêu:
features = ['Age', 'Experience', 'Rank', 'Nationality'] X = df[features] y = df['Go'] print(X) print(y)
Hiện tại, chúng ta có thể tạo ra cây quyết định thực tế, phù hợp với chi tiết của chúng ta, sau đó lưu một tệp .png trên máy tính:
Ví dụ
Tạo một cây quyết định, lưu nó dưới dạng ảnh và hiển thị ảnh đó:
dtree = DecisionTreeClassifier() dtree = dtree.fit(X, y) data = tree.export_graphviz(dtree, out_file=None, feature_names=features) graph = pydotplus.graph_from_dot_data(data) graph.write_png('mydecisiontree.png') img=pltimg.imread('mydecisiontree.png') imgplot = plt.imshow(img) plt.show()
Giải thích kết quả
Cây quyết định sử dụng quyết định trước đó của bạn để tính toán xác suất bạn có muốn đi xem diễn viên hài.
Hãy đọc các góc độ khác nhau của cây quyết định:

Rank
Rank <= 6.5
Chỉ ra rằng các diễn viên hài có mức độRank dưới 6.5 sẽ tuân theo True
Mũi tên (đi về trái), còn lại thì tuân theo False
Mũi tên (đi về phải).
gini = 0.497
Chỉ ra chất lượng chia mẫu, và luôn là số từ 0.0 đến 0.5, trong đó 0.0 biểu thị tất cả các mẫu đều nhận được cùng một kết quả, còn 0.5 biểu thị chia hoàn toàn ở giữa.
samples = 13
Chỉ ra rằng tại điểm quyết định này vẫn còn 13 diễn viên hài, vì đây là bước đầu tiên,所以他们 đều là diễn viên hài.
value = [6, 7]
Chỉ ra rằng trong 13 diễn viên hài này, có 6 người sẽ nhận được "NO", còn 7 người sẽ nhận được "GO".
Gini
Có rất nhiều phương pháp để chia mẫu, trong hướng dẫn này chúng ta sử dụng phương pháp GINI.
Phương pháp Gini sử dụng công thức sau:
Gini = 1 - (x/n)2 - (y/n)2
Trong đó, x là số lượng câu trả lời tích cực ("GO"), n là số lượng mẫu, y là số lượng câu trả lời tiêu cực ("NO"), sử dụng công thức sau để tính toán:
1 - (7 / 13)2 - (6 / 13)2 = 0.497

Bước tiếp theo bao gồm hai khung, một khung cho diễn viên hài có 'Rank' là 6.5 hoặc thấp hơn, còn lại là một khung.
True - 5 diễn viên hài kết thúc ở đây:
gini = 0.0
Chỉ ra rằng tất cả các mẫu đều nhận được cùng một kết quả.
samples = 5
Chỉ ra rằng trong nhánh này vẫn còn 5 diễn viên hài (5 diễn viên hài có mức độ 6.5 hoặc thấp hơn).
value = [5, 0]
Điều này có nghĩa là 5 nhận được "NO" và 0 nhận được "GO".
False - 8 tên diễn viên kịch tiếp tục:
Nationality (quốc tịch)
Nationality <= 0.5
Điều này có nghĩa là các diễn viên hài có giá trị quốc tịch nhỏ hơn 0.5 sẽ theo mũi tên bên trái (điều này có nghĩa là tất cả mọi người đến từ Anh), còn lại sẽ theo mũi tên bên phải.
gini = 0.219
Điều này có nghĩa là khoảng 22% mẫu sẽ di chuyển theo một hướng.
samples = 8
Điều này có nghĩa là trong nhánh này còn lại 8 diễn viên hài (8 diễn viên hài có评级 cao hơn 6.5).
value = [1, 7]
Điều này có nghĩa là trong 8 diễn viên hài này, 1 sẽ nhận được "NO" và 7 sẽ nhận được "GO".

True - 4 tên diễn viên kịch tiếp tục:
Age (tuổi)
Age <= 35.5
Điều này có nghĩa là các diễn viên hài dưới 35.5 tuổi sẽ theo mũi tên bên trái, còn lại sẽ theo mũi tên bên phải.
gini = 0.375
Điều này có nghĩa là khoảng 37.5% mẫu sẽ di chuyển theo một hướng.
samples = 4
Điều này có nghĩa là trong nhánh này còn lại 4 diễn viên hài (4 diễn viên hài đến từ Anh).
value = [1, 3]
Điều này có nghĩa là trong 4 diễn viên hài này, 1 sẽ nhận được "NO" và 3 sẽ nhận được "GO".
False - 4 tên diễn viên hài kết thúc ở đây:
gini = 0.0
Biểu thị rằng tất cả các mẫu đều nhận được kết quả tương tự.
samples = 4
Điều này có nghĩa là trong nhánh này còn lại 4 diễn viên hài (4 diễn viên hài đến từ Anh).
value = [0, 4]
Điều này có nghĩa là trong 4 diễn viên hài này, 0 sẽ nhận được "NO" và 4 sẽ nhận được "GO".

True - 2 tên diễn viên hài kết thúc ở đây:
gini = 0.0
Biểu thị rằng tất cả các mẫu đều nhận được kết quả tương tự.
samples = 2
Điều này có nghĩa là trong nhánh này còn lại 2 diễn viên hài (2 diễn viên hài 35.5 tuổi hoặc trẻ hơn).
value = [0, 2]
Điều này có nghĩa là trong 2 diễn viên hài này, 0 sẽ nhận được "NO" và 2 sẽ nhận được "GO".
False - 2 tên diễn viên kịch tiếp tục:
Experience (kinh nghiệm)
Experience <= 9.5
Điều này có nghĩa là các diễn viên hài có kinh nghiệm 9.5 năm hoặc hơn sẽ theo mũi tên bên trái, còn lại sẽ theo mũi tên bên phải.
gini = 0.5
Điều này có nghĩa là 50% mẫu sẽ di chuyển theo một hướng.
samples = 2
Điều này có nghĩa là trong nhánh này còn lại 2 diễn viên hài (2 diễn viên hài trên 35.5 tuổi).
value = [1, 1]
Điều này có nghĩa là trong 2 diễn viên hài này, 1 sẽ nhận được "NO" và 1 sẽ nhận được "GO".

True - 1 tên diễn viên hài kết thúc ở đây:
gini = 0.0
Biểu thị rằng tất cả các mẫu đều nhận được kết quả tương tự.
samples = 1
Biểu thị rằng còn lại 1 diễn viên hài trong nhánh này (1 diễn viên hài có 9.5 năm hoặc ít kinh nghiệm hơn).
value = [0, 1]
Biểu thị 0 là "KHÔNG", 1 là "ĐI".
False - 1 diễn viên hài đến đây:
gini = 0.0
Biểu thị rằng tất cả các mẫu đều nhận được kết quả tương tự.
samples = 1
Biểu thị rằng còn lại 1 diễn viên hài trong nhánh này (trong đó 1 diễn viên hài có kinh nghiệm hơn 9.5 năm).
value = [1, 0]
Biểu thị 1 là "KHÔNG", 0 là "ĐI".
Giá trị dự đoán
Chúng ta có thể sử dụng cây quyết định để dự đoán giá trị mới.
Ví dụ: Tôi có nên xem một chương trình do diễn viên hài Mỹ 40 tuổi với 10 năm kinh nghiệm và xếp hạng hài kịch là 7 không?
Ví dụ
Sử dụng predict()
Cách để dự đoán giá trị mới:
print(dtree.predict([[40, 10, 7, 1]]))
Ví dụ
Nếu cấp độ hài kịch là 6, câu trả lời là gì?
print(dtree.predict([[40, 10, 6, 1]]))
Kết quả khác nhau
Nếu chạy đủ lần, ngay cả khi bạn nhập dữ liệu tương tự, cây quyết định cũng sẽ cung cấp cho bạn kết quả khác nhau.
Đây là vì cây quyết định không thể cung cấp cho chúng ta câu trả lời 100% chắc chắn. Nó dựa trên khả năng của kết quả, câu trả lời sẽ khác nhau.
- Trang Trước Huấn Luyện/Kiểm Tra
- Trang Tiếp Theo MySQL Cơ Bản