Decoding the Black Box: An Important Introduction to Interpretable Machine Learning Models in Python

Here, we will implement both of the methods we covered above.

We will use the Big Mart Sales problem hosted on our DataHack platform.

The problem involves predicting sales for different items sold at different outlets.

You can download the dataset from the above link.

Note: You can go through this course to fully understand how to build models using this data.

Here, our focus is on interpretability.

Building and Understanding Interpretable Machine Learning Models

Let us first look at how to interpret machine learning models that are inherently interpretable.

Importing the Required Libraries

```python
# importing the required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from xgboost.sklearn import XGBRegressor
%matplotlib inline
```

Reading Data

```python
# reading the data
df = pd.read_csv('data.csv')
```

Missing Value Treatment

```python
# imputing missing values in Item_Weight with the median and Outlet_Size with the mode
df['Item_Weight'] = df['Item_Weight'].fillna(df['Item_Weight'].median())
df['Outlet_Size'] = df['Outlet_Size'].fillna(df['Outlet_Size'].mode()[0])
```

Feature Engineering

```python
# creating a broad category of item type from the first two characters of Item_Identifier
df['Item_Type_Combined'] = df['Item_Identifier'].str[0:2]
df['Item_Type_Combined'] = df['Item_Type_Combined'].map({'FD': 'Food', 'NC': 'Non-Consumable', 'DR': 'Drinks'})
df['Item_Type_Combined'].value_counts()

# operating years of the store, counted up to 2013
df['Outlet_Years'] = 2013 - df['Outlet_Establishment_Year']

# merging inconsistent spellings in Item_Fat_Content
df['Item_Fat_Content'] = df['Item_Fat_Content'].replace({'LF': 'Low Fat', 'reg': 'Regular', 'low fat': 'Low Fat'})
df['Item_Fat_Content'].value_counts()
```

Data Preprocessing

```python
# label encoding the categorical variables
le = LabelEncoder()
df['Outlet'] = le.fit_transform(df['Outlet_Identifier'])
var_mod = ['Item_Fat_Content', 'Outlet_Location_Type', 'Outlet_Size', 'Item_Type_Combined', 'Outlet_Type', 'Outlet']
for i in var_mod:
    df[i] = le.fit_transform(df[i])

# one-hot encoding the label-encoded variables (dtype=int keeps the dummies numeric)
df = pd.get_dummies(df, columns=['Item_Fat_Content', 'Outlet_Location_Type', 'Outlet_Size', 'Outlet_Type', 'Item_Type_Combined', 'Outlet'], dtype=int)
```

Train-Test Split

```python
# dropping the ID variables and the variables that were used to derive new ones
df.drop(['Item_Type', 'Outlet_Establishment_Year', 'Item_Identifier', 'Outlet_Identifier'], axis=1, inplace=True)

# separating the dependent and independent variables
X = df.drop('Item_Outlet_Sales', axis=1)
y = df['Item_Outlet_Sales']

# creating the training and validation sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
```

Training a Decision Tree Model

```python
dt = DecisionTreeRegressor(max_depth=5, random_state=10)

# fitting the decision tree model on the training set
dt.fit(X_train, y_train)
```

We use the Graphviz library to visualize the decision tree:

```python
# exporting the top two levels of the decision tree to a .dot file
tree.export_graphviz(dt, out_file='tree.dot', feature_names=list(X_train.columns), filled=True, max_depth=2)

# converting the dot file to png format (requires the Graphviz `dot` command)
!dot -Tpng tree.dot -o tree.png

# plotting the decision tree
image = plt.imread('tree.png')
plt.figure(figsize=(25, 25))
plt.imshow(image)
```

This visualization of our decision tree clearly displays the rules it uses to make a prediction.

Here, Item_MRP and Outlet_Type are the first features the tree splits on, so they are the ones that most strongly affect the predicted sales of items at each outlet.

To see the complete decision tree, increase (or remove) the max_depth argument passed to export_graphviz. That argument only limits how many levels are drawn; it does not change the fitted model.
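If Graphviz is not installed, scikit-learn's `export_text` gives a plain-text view of the same kind of rules, and the fitted tree's `tree_` arrays expose the root split directly. The sketch below runs on a small synthetic stand-in for `X_train`/`y_train`; the data and its coefficients are invented for illustration and are not the Big Mart data, so the printed rules will differ from the tree above.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical synthetic stand-in for X_train / y_train (not the Big Mart data)
rng = np.random.default_rng(42)
n = 1000
X_train = pd.DataFrame({
    "Item_MRP": rng.uniform(30, 270, n),
    "Outlet_Type_0": rng.integers(0, 2, n),
    "Item_Visibility": rng.uniform(0, 0.3, n),
})
# Invented relationship: price drives sales, the Outlet_Type_0 flag lowers them
y_train = 15 * X_train["Item_MRP"] - 500 * X_train["Outlet_Type_0"] + rng.normal(0, 100, n)

dt = DecisionTreeRegressor(max_depth=5, random_state=10)
dt.fit(X_train, y_train)

# Plain-text rules for the top two levels; no Graphviz install needed
rules = export_text(dt, feature_names=list(X_train.columns), max_depth=2)
print(rules)

# Node 0 is the root: its feature index and threshold form the first rule applied
root_feature = X_train.columns[dt.tree_.feature[0]]
root_threshold = dt.tree_.threshold[0]
print(f"Root split: {root_feature} <= {root_threshold:.2f}")
```

On the real data, the same two lines after `dt.fit` would report the root feature of the tree trained above.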

Feature Importance

Now, we will look at the importance of each feature in a random forest model.

```python
# creating the Random Forest Regressor model
rf = RandomForestRegressor(n_estimators=200, max_depth=5, min_samples_leaf=100, n_jobs=-1, random_state=10)

# fitting the model (feature_importances_ only exists after fitting)
rf.fit(X_train, y_train)

# feature importance of the random forest model
feature_importance = pd.DataFrame()
feature_importance['variable'] = X_train.columns
feature_importance['importance'] = rf.feature_importances_

# feature importance values in descending order
feature_importance.sort_values(by='importance', ascending=False).head(10)
```

The random forest model gives a similar interpretation.

Item_MRP remains the most important feature, just as in the decision tree above.
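To check that the two models rank features similarly, their `feature_importances_` can be placed side by side. This is a minimal sketch on invented synthetic data with the same hyperparameters as the walkthrough; on the real data you would reuse the `dt` and `rf` already fitted above instead of refitting.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor

# Hypothetical synthetic data standing in for X_train / y_train
rng = np.random.default_rng(42)
n = 1000
X_train = pd.DataFrame({
    "Item_MRP": rng.uniform(30, 270, n),
    "Outlet_Type_0": rng.integers(0, 2, n),
    "Item_Visibility": rng.uniform(0, 0.3, n),
})
y_train = 15 * X_train["Item_MRP"] - 500 * X_train["Outlet_Type_0"] + rng.normal(0, 100, n)

# Same hyperparameters as the models in the walkthrough
dt = DecisionTreeRegressor(max_depth=5, random_state=10).fit(X_train, y_train)
rf = RandomForestRegressor(n_estimators=200, max_depth=5, min_samples_leaf=100,
                           n_jobs=-1, random_state=10).fit(X_train, y_train)

# One row per feature; each importance column sums to 1
comparison = pd.DataFrame({
    "variable": X_train.columns,
    "decision_tree": dt.feature_importances_,
    "random_forest": rf.feature_importances_,
}).sort_values(by="random_forest", ascending=False)
print(comparison)
```

Because both columns are normalized, their values can be compared directly even though one model is a single tree and the other averages 200 trees.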

Relative importance also lets us compare features with one another.

For example, Outlet_Type_0 is much more important than the other outlet type columns.
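The column is called Outlet_Type_0 rather than something readable because Outlet_Type was label encoded before one-hot encoding. `LabelEncoder` assigns integer codes in sorted order of the labels, so the mapping can be recovered from its `classes_` attribute. In this sketch the category names are an assumption for illustration; on the real data, check `df['Outlet_Type'].unique()` before encoding.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Assumed Outlet_Type labels (illustrative; verify against the raw data)
outlet_type = pd.Series(
    ["Supermarket Type1", "Grocery Store", "Supermarket Type3",
     "Supermarket Type2", "Supermarket Type1"],
    name="Outlet_Type",
)

le = LabelEncoder()
codes = le.fit_transform(outlet_type)

# classes_ is sorted, so position i in classes_ is the label behind code i
mapping = {label: code for code, label in enumerate(le.classes_)}
print(mapping)

# One-hot encoding the codes produces the Outlet_Type_<code> column names
dummies = pd.get_dummies(pd.Series(codes, name="Outlet_Type"), prefix="Outlet_Type", dtype=int)
print(list(dummies.columns))
```

Under these assumed labels, Outlet_Type_0 would be the Grocery Store column.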

Exercise

As an exercise, try calculating the feature importance for the decision tree we fit earlier and compare it with the random forest's ranking.

Global Surrogate

Next, we will create a surrogate decision tree model for this random forest model and see what we get.

```python
# saving the predictions of the random forest as the new target
new_target = rf.predict(X_train)

# defining the interpretable decision tree model
dt_model = DecisionTreeRegressor(max_depth=5, random_state=10)

# fitting the surrogate decision tree model using the training set and the new target
dt_model.fit(X_train, new_target)
```

This decision tree performs well on the new target and can be used as a surrogate model to explain the predictions of the random forest model.

The same approach works for any other complex model.

Just make sure the surrogate tree fits the black-box predictions well; otherwise, you might get wrong interpretations (a nightmare!).
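One way to check that a surrogate "fits well" is to measure its fidelity: the R² between the surrogate's predictions and the black box's predictions on rows that neither model was trained on. This is a minimal sketch on invented synthetic data; the 0.25 split and the hyperparameters mirror the walkthrough.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Hypothetical synthetic data (not Big Mart)
rng = np.random.default_rng(42)
n = 2000
X = pd.DataFrame({
    "Item_MRP": rng.uniform(30, 270, n),
    "Outlet_Type_0": rng.integers(0, 2, n),
    "Item_Visibility": rng.uniform(0, 0.3, n),
})
y = 15 * X["Item_MRP"] - 500 * X["Outlet_Type_0"] + rng.normal(0, 100, n)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Black box trained on the true sales
rf = RandomForestRegressor(n_estimators=200, max_depth=5, min_samples_leaf=100,
                           n_jobs=-1, random_state=10).fit(X_train, y_train)

# Surrogate trained on the black box's predictions, not on the true sales
dt_model = DecisionTreeRegressor(max_depth=5, random_state=10).fit(X_train, rf.predict(X_train))

# Fidelity: the random forest's predictions are treated as the "truth"
fidelity = r2_score(rf.predict(X_test), dt_model.predict(X_test))
print(f"Surrogate fidelity (R^2 vs. random forest): {fidelity:.3f}")
```

A fidelity close to 1 means the tree's rules are a faithful summary of the forest; a low value means its splits should not be read as explanations. Fidelity says nothing about how accurate either model is on the true sales.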

Implementing LIME in Python to generate local interpretations of black-box models

We can implement the LIME technique in both R and Python using the lime package.

Let's jump into the implementation and check the local interpretation of a given prediction using LIME:

```python
# installing the lime library (notebook shell command)
!pip install lime

# importing the explainer class from the lime_tabular module of the lime library
from lime.lime_tabular import LimeTabularExplainer

# training the random forest model
rf_model = RandomForestRegressor(n_estimators=200, max_depth=5, min_samples_leaf=100, n_jobs=-1, random_state=10)
rf_model.fit(X_train, y_train)

# creating the explainer
explainer = LimeTabularExplainer(X_train.values, mode='regression', feature_names=list(X_train.columns))

# storing a new observation
i = 6
X_observation = X_test.iloc[[i], :]
print(f'* RF prediction: {rf_model.predict(X_observation)[0]}')
```

Generate Explanations using LIME

```python
# explanation using the random forest model
explanation = explainer.explain_instance(X_observation.values[0], rf_model.predict)
explanation.show_in_notebook(show_table=True, show_all=False)

# R^2 of the local linear model that LIME fitted around this observation
print(explanation.score)
```

The predicted value for sales is 185.40.

Each feature's contribution to this prediction is shown in the bar plot on the right.

Orange bars indicate a positive impact of that feature on the predicted sales, and blue bars a negative impact.

For example, Item_MRP has a positive impact on sales.
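To see where such bar weights come from, here is a from-scratch sketch of the core idea behind LIME: perturb the observation, query the black box, weight the fake samples by their proximity to the observation, and fit a weighted linear model whose coefficients act as local contributions. This is a simplification, not the lime package's implementation (lime, for example, discretizes continuous features by default); the synthetic data, the hypothetical `explain_locally` helper, and the kernel width are assumptions for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge

# Hypothetical synthetic data (not Big Mart): price raises sales, the flag lowers them
rng = np.random.default_rng(0)
n = 2000
X = np.column_stack([
    rng.uniform(30, 270, n),   # Item_MRP
    rng.integers(0, 2, n),     # Outlet_Type_0 (0/1 flag)
    rng.uniform(0, 0.3, n),    # Item_Visibility
])
feature_names = ["Item_MRP", "Outlet_Type_0", "Item_Visibility"]
y = 15 * X[:, 0] - 500 * X[:, 1] + rng.normal(0, 100, n)

black_box = RandomForestRegressor(n_estimators=100, max_depth=5, random_state=10)
black_box.fit(X, y)


def explain_locally(predict_fn, x0, X_train, names, num_samples=5000, seed=0):
    """Hypothetical helper: weighted linear explanation of predict_fn around x0."""
    local_rng = np.random.default_rng(seed)
    scale = X_train.std(axis=0)
    # 1. Fake data: Gaussian perturbations of x0, in units of each feature's std
    Z_scaled = local_rng.normal(0.0, 1.0, size=(num_samples, len(x0)))
    Z = x0 + Z_scaled * scale
    # 2. Black-box predictions on the fake data
    preds = predict_fn(Z)
    # 3. Proximity weights from an exponential kernel on the scaled distance
    kernel_width = 0.75 * np.sqrt(len(x0))
    distances = np.linalg.norm(Z_scaled, axis=1)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # 4. Weighted linear model; coefficients are effects per one-std change
    local_model = Ridge(alpha=1.0)
    local_model.fit(Z_scaled, preds, sample_weight=weights)
    return dict(zip(names, local_model.coef_))


x0 = np.array([150.0, 1.0, 0.1])  # hypothetical observation to explain
contributions = explain_locally(black_box.predict, x0, X, feature_names)
for name, weight in sorted(contributions.items(), key=lambda kv: -abs(kv[1])):
    print(f"{name:>16}: {weight:+.1f}")
```

The sign of each coefficient plays the role of the bar colour (positive for orange, negative for blue), and its magnitude the bar length.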

Exercise 2

As a further exercise, try applying LIME to more complex models such as XGBoost and LightGBM, reusing the preprocessing and train-test split above.

For more details on LIME and its implementation, you can go through this article.

End Notes

LIME is a powerful technique, but it has drawbacks: it relies on locally generated synthetic data and uses simple linear models to explain predictions.

However, it can be used for text and image data as well.

As I mentioned earlier, interpretable machine learning is part of our utterly comprehensive end-to-end course, Applied Machine Learning. Make sure you check it out.

And if you have any questions or feedback regarding this article, let me know in the comments section below.

You can also read this article on Analytics Vidhya's Android app.