Unlocking Decision-Making: An In-Depth Analysis of Entropy in Decision Trees
Decision trees are a popular machine learning algorithm used for classification and regression tasks. They work by splitting data into subsets based on feature values, ultimately leading to decisions. A crucial concept in decision trees is entropy, which measures the impurity or disorder of a dataset. This article explores how entropy is calculated and its significance in decision-making processes, using a Python implementation to illustrate these concepts.
Understanding Entropy
Entropy quantifies the uncertainty or randomness in a dataset. For a set S containing c classes, it is defined mathematically as:
Entropy(S) = -Σ p_i * log2(p_i), summed over i = 1, ..., c
Where:
- p_i is the proportion of instances in S that belong to class i, and
- c is the number of classes (2 for the binary labels used in this article).
A node that contains only one class has an entropy of 0, while a node whose classes are evenly balanced has the maximum entropy (1 bit for two classes).
In decision trees, we use entropy to determine the best feature to split the data. The goal is to minimize entropy after the split, leading to more homogeneous subsets.
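As a minimal illustration of the formula (separate from the implementation below), entropy can be computed directly from a list of class labels:

import math
from collections import Counter

def entropy(labels):
    # Shannon entropy of a list of class labels, in bits.
    total = len(labels)
    return sum(-(count / total) * math.log2(count / total)
               for count in Counter(labels).values())

print(entropy([1, 1, 0, 0]))  # 1.0 -> evenly mixed node, maximum uncertainty
print(entropy([1, 1, 1, 1]))  # 0.0 -> pure node, no uncertainty

A split that turns one mixed node into two purer child nodes lowers the average entropy, which is exactly what the code below searches for.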
Code Implementation
The following Python code calculates the entropy of decision tree splits based on a dataset. We will break down the code step by step.
import numpy as np
import math
# Sample dataset: features and labels
df = np.array([[1, 0, 18, 1],
[1, 1, 15, 1],
[0, 1, 65, 0],
[0, 0, 33, 0],
[1, 0, 37, 1],
[0, 1, 45, 1],
[0, 1, 50, 0],
[1, 0, 75, 0],
[1, 0, 67, 1],
[1, 1, 60, 1],
[0, 1, 55, 1],
[0, 0, 69, 0],
[0, 0, 80, 0],
[0, 1, 87, 1],
[1, 0, 38, 1]
])
Dataset Explanation
The dataset consists of three features (the first three columns) and a label (the last column). The first two columns hold binary features, the third holds a continuous numeric feature, and each row represents an instance whose label (0 or 1) is the outcome we want to predict.
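As a quick sanity check (not part of the original script), NumPy can report the class balance of the label column, which determines the entropy of the dataset before any split:

values, counts = np.unique(df[:, -1], return_counts=True)
print(dict(zip(values.tolist(), counts.tolist())))  # {0: 6, 1: 9} for this sample

With 9 positive and 6 negative instances, the starting entropy is roughly 0.971 bits, so a useful split needs to produce subsets whose weighted entropy falls below that value.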
Entropy Calculation Functions
The following functions calculate the weighted average entropy and the entropy for a given dataset.
def calc_weighted_average(enp1, enp1_count, enp2, enp2_count):
    # Weight each subset's entropy by the number of rows it contains, then normalise.
    return round(((enp1 * enp1_count) + (enp2 * enp2_count)) / (enp1_count + enp2_count), 3)
def calc_entropy(data):
    # data has exactly two columns: the candidate feature in column 0 and the label in column 1.
    if len(np.unique(data[:, 0])) > 2:
        # Continuous feature: try the midpoint between every pair of adjacent sorted
        # values as a candidate threshold and keep the one with the lowest weighted entropy.
        sorted_data = data[data[:, 0].argsort()]
        main_dict = {}
        for i in range(1, len(sorted_data)):
            avg = (sorted_data[i - 1, 0] + sorted_data[i, 0]) / 2

            # Rows below the threshold (the "true" side of the split).
            true_xs = data[data[:, 0] < avg]
            count_true_xs = len(true_xs)
            true_xs_true_ys = len(true_xs[true_xs[:, 1] == 1])
            true_xs_false_ys = len(true_xs[true_xs[:, 1] == 0])
            try:
                entrp1 = round(-(true_xs_true_ys / count_true_xs) * math.log2(true_xs_true_ys / count_true_xs)
                               - (true_xs_false_ys / count_true_xs) * math.log2(true_xs_false_ys / count_true_xs), 3)
            except (ValueError, ZeroDivisionError):
                # A pure (or empty) subset contributes zero entropy.
                entrp1 = 0

            # Rows at or above the threshold (the "false" side of the split).
            false_xs = data[data[:, 0] >= avg]
            count_false_xs = len(false_xs)
            false_xs_true_ys = len(false_xs[false_xs[:, 1] == 1])
            false_xs_false_ys = len(false_xs[false_xs[:, 1] == 0])
            try:
                entrp2 = round(-(false_xs_true_ys / count_false_xs) * math.log2(false_xs_true_ys / count_false_xs)
                               - (false_xs_false_ys / count_false_xs) * math.log2(false_xs_false_ys / count_false_xs), 3)
            except (ValueError, ZeroDivisionError):
                entrp2 = 0

            main_dict[str(avg)] = calc_weighted_average(entrp1, count_true_xs, entrp2, count_false_xs)

        # Return the best threshold together with its weighted entropy.
        best_split = min(main_dict, key=main_dict.get)
        return {best_split: main_dict[best_split]}
    else:
        # Binary feature: split directly on its two values (1 and 0).
        true_xs = data[data[:, 0] == 1]
        count_true_xs = len(true_xs)
        if count_true_xs == 0:
            entrp1 = 0
        else:
            true_xs_true_ys = len(true_xs[true_xs[:, 1] == 1])
            true_xs_false_ys = len(true_xs[true_xs[:, 1] == 0])
            try:
                entrp1 = round(-(true_xs_true_ys / count_true_xs) * math.log2(true_xs_true_ys / count_true_xs)
                               - (true_xs_false_ys / count_true_xs) * math.log2(true_xs_false_ys / count_true_xs), 3)
            except ValueError:
                entrp1 = 0

        false_xs = data[data[:, 0] == 0]
        count_false_xs = len(false_xs)
        if count_false_xs == 0:
            entrp2 = 0
        else:
            false_xs_true_ys = len(false_xs[false_xs[:, 1] == 1])
            false_xs_false_ys = len(false_xs[false_xs[:, 1] == 0])
            try:
                entrp2 = round(-(false_xs_true_ys / count_false_xs) * math.log2(false_xs_true_ys / count_false_xs)
                               - (false_xs_false_ys / count_false_xs) * math.log2(false_xs_false_ys / count_false_xs), 3)
            except ValueError:
                entrp2 = 0

        return calc_weighted_average(entrp1, count_true_xs, entrp2, count_false_xs)
Function Explanation
calc_weighted_average: This function calculates the weighted average of two entropy values based on their respective counts. The formula for the weighted average is:
Weighted Entropy = (n1 * E1 + n2 * E2) / (n1 + n2)
Where:
- E1 and E2 are the entropies of the two subsets produced by a split, and
- n1 and n2 are the number of instances in each subset.
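As a quick illustrative check with arbitrary numbers (not taken from the dataset): if the left subset has 5 rows with entropy 0.722 and the right subset has 10 rows with entropy 1.0, then

print(calc_weighted_average(0.722, 5, 1.0, 10))  # (5 * 0.722 + 10 * 1.0) / 15 ≈ 0.907

Larger subsets pull the weighted entropy toward their own value, which prevents a tiny, accidentally pure subset from dominating the choice of split.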
calc_entropy: This function computes the weighted entropy of splitting a two-column array (feature, label). For a continuous feature it sorts the data, tries the midpoint between every pair of adjacent values as a candidate threshold, and returns the threshold with the lowest weighted entropy; for a binary feature it simply splits on the two values and returns the weighted entropy directly.
Example Usage
The following lines demonstrate how to calculate entropy for different features in the dataset:
print(calc_entropy(df[:, [0, -1]]))  # Weighted entropy for feature 0 (binary)
print(calc_entropy(df[:, [1, -1]]))  # Weighted entropy for feature 1 (binary)
print(calc_entropy(df[:, [2, -1]]))  # Best threshold and weighted entropy for feature 2 (continuous)
This code prints the weighted entropy for each of the two binary features and, for the continuous third feature, a dictionary containing the best threshold and its weighted entropy, allowing us to determine which feature provides the most informative split.
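To turn these numbers into the information gain that decision-tree literature usually reports, the weighted entropy of a split can be subtracted from the entropy of the unsplit label column. A short sketch (not part of the original script), reusing df and calc_entropy from above:

# Entropy of the label column before any split.
label_values, label_counts = np.unique(df[:, -1], return_counts=True)
probs = label_counts / label_counts.sum()
parent_entropy = -np.sum(probs * np.log2(probs))

# calc_entropy returns a plain number for the binary feature in column 0.
weighted = calc_entropy(df[:, [0, -1]])
print("Information gain for feature 0:", round(parent_entropy - weighted, 3))

The feature (or threshold) with the highest gain, equivalently the lowest weighted entropy, is the one a decision tree would split on first.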
Analyzing Subsets
The code further analyzes subsets of the data based on a threshold on the third feature (68 in this example):
true_side_df = df[df[:, 2] < 68]
print(true_side_df)
print(calc_entropy(true_side_df[:, [0, -1]]))
print(calc_entropy(true_side_df[:, [1, -1]]))
false_side_df = df[df[:, 2] > 68]
print(false_side_df)
print(calc_entropy(false_side_df[:, [0, -1]]))
print(calc_entropy(false_side_df[:, [1, -1]]))
This part of the code separates the dataset into two branches based on whether the third feature is less than or greater than 68, and then recalculates the entropy of the remaining features within each branch. This mirrors how a decision tree grows: after a split, the best feature for the next split is chosen independently inside each branch.
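The same idea can be wrapped in a small helper to avoid repeating the print statements. This is a sketch (not part of the original script) that assumes the calc_entropy function defined above and simply picks the remaining column with the lowest weighted entropy inside a branch:

def best_feature(subset, feature_columns):
    # Score each candidate column by the weighted entropy of splitting on it.
    scores = {}
    for col in feature_columns:
        result = calc_entropy(subset[:, [col, -1]])
        # Continuous features return {threshold: entropy}; binary features return a number.
        scores[col] = list(result.values())[0] if isinstance(result, dict) else result
    return min(scores, key=scores.get), scores

print(best_feature(true_side_df, [0, 1]))
print(best_feature(false_side_df, [0, 1]))

Applying this repeatedly, splitting, scoring the remaining features, and splitting again until the branches are pure, is essentially how an entropy-based decision tree is grown.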
Conclusion
Entropy is a vital concept in decision trees that helps in making data-driven decisions. By calculating entropy for different features and their splits, we can enhance the efficiency and accuracy of classification tasks. The provided Python code serves as a practical example of how to implement these concepts in a real-world scenario, paving the way for more informed decision-making processes in machine learning applications.
By understanding and applying these principles, data scientists and machine learning practitioners can build more robust models that effectively leverage the power of data.
For the complete code, please visit my GitHub repository: GitHub Repository.