Extracting Rules from a Random Forest Classifier in PySpark and Visualizing Them with GraphViz
Shorya Sharma
Assistant Manager at Bank Of America | Ex - Data Engineer at IBM | Azure data engineer certified | AWS CP certified
In the ever-evolving field of fraud detection, understanding the intricate patterns and relationships between variables is crucial. Leveraging machine learning models like Random Forest classifiers can significantly enhance our ability to detect fraudulent activities. In this article, I will demonstrate how to automatically extract rules from a Random Forest classifier using PySpark, providing valuable insights into the variables that influence our fraud detection models.
While I'll be using Google Colab for this demonstration, you can apply the same principles and techniques with any PySpark environment.
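If you are also starting from a fresh Colab notebook, a minimal setup sketch (assuming the standard PyPI package names) is all that's needed before the code below:

# Minimal environment setup; Colab already ships the Graphviz binaries,
# so we only add the Python packages used in this walkthrough
!pip install pyspark graphviz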
Download the dataset from:
To get started, we'll initialize our Spark session and define the schema for our dataset. Specifying the schema ensures each column is read with the appropriate data type. After defining the schema, we'll load the CSV file into a DataFrame so the data is structured and ready for analysis. This DataFrame will serve as the basis for training our Random Forest classifier to detect fraudulent activities.
from pyspark.sql import SparkSession
from pyspark.sql.types import *

spark = SparkSession.builder.getOrCreate()

# Explicit schema so each column is read with the intended type
Schema = StructType([
    StructField('Account ID', FloatType(), True),
    StructField('Subscription', StringType(), True),
    StructField('Subscription Year', StringType(), True),
    StructField('user age', FloatType(), True),
    StructField('job', StringType(), True),
    StructField('marital', StringType(), True),
    StructField('education', StringType(), True),
    StructField('housing', StringType(), True),
    StructField('loan', StringType(), True),
    StructField('contact', StringType(), True),
    StructField('Account Age', FloatType(), True),
    StructField('Count of linking accounts', FloatType(), True),
    StructField('default', FloatType(), True)
])
df = spark.read.csv("subscription.csv", header=True, schema=Schema)
df = df.dropna()
df.show(10)
Next, we'll focus on preparing our data for the Random Forest classifier by selecting important features and transforming categorical variables.
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# Keep only the columns used for modelling
df_important_features = df.select('job', 'marital', 'education', 'user age',
                                  'Account Age', 'Count of linking accounts', 'default')

categorical_columns = ['job', 'marital', 'education']
numerical_columns = ['user age', 'Account Age', 'Count of linking accounts']

# Index each categorical column, then one-hot encode the resulting indices
indexers = [StringIndexer(inputCol=col, outputCol=col + "_index") for col in categorical_columns]
encoders = [OneHotEncoder(inputCol=col + "_index", outputCol=col + "_vec") for col in categorical_columns]

# Combine the encoded categorical vectors and the numerical columns into one feature vector
assembler = VectorAssembler(
    inputCols=[col + "_vec" for col in categorical_columns] + numerical_columns,
    outputCol="features"
)

pipeline = Pipeline(stages=indexers + encoders + [assembler])
pipeline_model = pipeline.fit(df_important_features)
After crafting a tailored pipeline to encode categorical variables and assemble features, we apply the fitted pipeline model to transform the selected features into a new DataFrame. This step ensures that our dataset is properly prepared for training the Random Forest classifier, with categorical variables encoded and numerical variables combined into a unified feature representation.
df_transformed = pipeline_model.transform(df_important_features)
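Because the tree splits we extract later refer to positions in this assembled vector rather than to column names, it helps to recover the feature names. One common pattern, assuming VectorAssembler has populated the column metadata as usual, is sketched below:

# Sketch: map indices in the "features" vector back to readable names
attrs = df_transformed.schema["features"].metadata["ml_attr"]["attrs"]
feature_names = sorted(
    (attr["idx"], attr["name"]) for group in attrs.values() for attr in group
)
print(feature_names)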
With our dataset transformed and prepared, we're now poised to train our Random Forest classifier for fraud detection.
df_ml = df_transformed.select("features", "default")
from pyspark.ml.classification import RandomForestClassifier
train_df, test_df = df_ml.randomSplit([0.7, 0.3], seed=42)
rf = RandomForestClassifier(labelCol="default", featuresCol="features", numTrees=10)
rf_model = rf.fit(train_df)
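Before evaluating on the held-out set, it's worth a quick look at which features the forest actually relies on. The model's featureImportances vector is aligned with the assembler's feature order; a small sketch for ranking them:

# Sketch: rank assembled features by the model's impurity-based importance scores
importances = rf_model.featureImportances.toArray()
for idx, score in sorted(enumerate(importances), key=lambda pair: -pair[1]):
    print(f"feature[{idx}]: {score:.4f}")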
To evaluate the performance of our trained Random Forest classifier, we employ a Multiclass Classification Evaluator from the PySpark ML library.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
predictions = rf_model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(labelCol="default", predictionCol="prediction", metricName="accuracy")
evaluator.evaluate(predictions)
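Because fraud labels are usually heavily imbalanced, accuracy alone can be flattering. For a complementary view, PySpark's BinaryClassificationEvaluator reports the area under the ROC curve; a minimal sketch:

from pyspark.ml.evaluation import BinaryClassificationEvaluator
# AUC is less sensitive to class imbalance than raw accuracy
auc_evaluator = BinaryClassificationEvaluator(labelCol="default",
                                              rawPredictionCol="rawPrediction",
                                              metricName="areaUnderROC")
auc_evaluator.evaluate(predictions)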
To gain deeper insights into the Random Forest classifier's decision-making process, we delve into the conversion of Java-based decision trees into Python-readable structures.
from collections import namedtuple
import numpy as np

LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))

def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()
        stats = np.array(list(jnode.impurityStats().stats()))
        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())
            return InternalNode(left, right, prediction, stats, split)
        else:
            return LeafNode(prediction, stats)

    return jnode_to_python(jtree.rootNode())
nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]
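With the trees now represented as plain Python objects, we can already read rules straight off them. The helper below is a minimal sketch (not part of the code above) that walks each converted tree and prints every root-to-leaf path as an if-then rule, following Spark's convention that the left branch takes values less than or equal to the threshold, or inside the category set; feature indices refer to positions in the assembled features vector.

# Sketch: turn a converted tree into readable if-then rules
def extract_rules(node, path=None, rules=None):
    if path is None:
        path, rules = [], []
    if isinstance(node, LeafNode):
        rules.append((list(path), node.prediction))
        return rules
    split = node.split
    if isinstance(split, ContinuousSplit):
        left_cond = f"feature[{split.feature_index}] <= {split.threshold:.3f}"
        right_cond = f"feature[{split.feature_index}] > {split.threshold:.3f}"
    else:
        left_cond = f"feature[{split.feature_index}] in {split.categories}"
        right_cond = f"feature[{split.feature_index}] not in {split.categories}"
    extract_rules(node.left, path + [left_cond], rules)
    extract_rules(node.right, path + [right_cond], rules)
    return rules

for conditions, prediction in extract_rules(nodes[0]):
    print(" AND ".join(conditions), "=>", prediction)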
from graphviz import Digraph
Next, we introduce a function, visualize_tree, that lets us render each decision tree in a human-readable format directly within our Python environment. Let's explore how it works:
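A minimal sketch of such a function, built on the namedtuple structures above and the graphviz Digraph API, could look like this:

# Sketch: render one converted tree as a Graphviz diagram;
# internal nodes show the split condition, leaves show the predicted class
def visualize_tree(root):
    dot = Digraph()
    counter = {"i": 0}

    def add(node):
        node_id = str(counter["i"])
        counter["i"] += 1
        if isinstance(node, LeafNode):
            dot.node(node_id, f"predict = {node.prediction}", shape="box")
            return node_id
        split = node.split
        if isinstance(split, ContinuousSplit):
            label = f"feature[{split.feature_index}] <= {split.threshold:.3f}"
        else:
            label = f"feature[{split.feature_index}] in {split.categories}"
        dot.node(node_id, label)
        dot.edge(node_id, add(node.left), label="yes")
        dot.edge(node_id, add(node.right), label="no")
        return node_id

    add(root)
    return dot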
for i in range(len(nodes)):
    tree_visual = visualize_tree(nodes[i])
    tree_visual.render(f"DecisionTree_{i}", format="png", cleanup=True)
Let's download any one of the decision trees and look at the graph:
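If you're following along in Google Colab, one way to grab a rendered tree is Colab's files helper:

from google.colab import files
files.download("DecisionTree_0.png")  # pick any of the rendered trees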
In summary, our journey from data preparation to visualization equips us with powerful tools for insightful analysis. By leveraging PySpark, we've transformed data into actionable insights, trained robust models, and visualized decision trees to understand predictive logic. This holistic approach enhances accuracy and deepens our understanding of data relationships, paving the way for innovation in data-driven decision-making.