Customer Churn Prediction with PySpark on IBM Watson Studio, AWS and Databricks

Predicting customer churn for a digital music service using big data tools and cloud computing services

Josh Xin Jie Lee, Feb 4

The code accompanying this article can be found here.

Customer churn refers to the situation when customers stop doing business with a company.

According to an article in Harvard Business Review, acquiring a new customer can be 5 to 25 times more expensive than retaining an existing one.

In fact, research conducted by Frederick Reichheld of Bain & Company showed that increasing customer retention rates by 5% can increase profits by 25% to 95%.

Hence, it should be a company’s priority to minimize the rate of customer churn.

If we can successfully predict customers who will churn in advance, we can entice them to stay by providing discounts and incentives.

In this article, we will tackle a customer churn prediction problem for a fictitious digital music service called Sparkify.

We will train our predictive model on a large dataset (~12GB) of customers’ activities on the service and attempt to predict which customers will churn based on their past behavior.

Since this dataset is too large to fit on a single computer, we will be using Apache Spark to help us analyze this dataset.

Apache Spark is one of the most popular distributed processing frameworks for big data: it spreads the work across multiple computers.

Spark is significantly faster than Hadoop MapReduce, and it has a user-friendly API that can be accessed through a number of popular languages: Scala, Java, Python and R.

For this task, we will be using Spark’s Python API, PySpark.

PySpark provides two approaches to manipulating data frames: the first is similar to Python’s Pandas library, and the second uses SQL queries.
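To make the two styles concrete, here is a minimal sketch of the same per-user row count written both ways. It assumes you already have a SparkSession (`spark`) and a data frame (`df`) with a `userId` column; the function names and the `events` view name are made up for illustration and do not come from the accompanying code.

```python
def plays_per_user_dataframe_api(df):
    """Count rows per user with the Pandas-like DataFrame API."""
    return df.groupBy("userId").count()


def plays_per_user_sql(spark, df):
    """Express the same aggregation as a SQL query over a temporary view."""
    # "events" is a hypothetical view name chosen for this sketch
    df.createOrReplaceTempView("events")
    return spark.sql(
        "SELECT userId, COUNT(*) AS count FROM events GROUP BY userId"
    )
```

With a live session you would call `plays_per_user_dataframe_api(df).show()` or `plays_per_user_sql(spark, df).show()`; both should give one row per user with its row count.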

Spark ML is the data frame-based API for Spark’s machine learning library, and it provides users with popular machine learning algorithms such as Linear Regression, Logistic Regression, Random Forests, Gradient-Boosted Trees, etc.

Setting up PySpark on your local computer can be tricky.

Running PySpark on cloud services will simplify the setup process, but may incur some costs in the process.

Ideally, you should create the model’s prototype on your local machine with a smaller dataset to save cost, and then transfer the code to a cloud-based service when you are ready to analyze the larger dataset.

Ultimately, you will need a cloud service to take full advantage of Spark’s distributed computing framework.

An alternative approach is to set up an account with Databricks and use their Community Edition to build your model’s prototype.

The Databricks Community Edition provides a free 6GB micro-cluster as well as a cluster manager and notebook environment.

The best part is that the access is not time-limited!

If you are interested in reproducing the code or trying out the dataset, I have included the setup instructions for three cloud services: IBM Watson Studio, Amazon AWS and Databricks.

Brief Description of Problem

Sparkify has a free tier and a premium subscription plan, and customers can cancel or downgrade from premium to free tier at any time.

We will define a customer churn event as downgrading from premium to free tier or a cancellation of service.

Ideally, we want to predict churn events using data from the past to avoid any look-ahead biases.

There are two months of customer activity data available.

Hence, we will build a model to predict if users will churn in the second month using their behavioral data from the first month.

The full dataset is 12 GB.

Alternatively, you can try out smaller instances of the dataset, and I have included the datasets’ download links on my GitHub page.

I will now cover the setup instructions for the three cloud computing services: IBM Watson Studio, AWS and Databricks.

Feel free to skip them and continue reading on ‘Dataset’ if you wish.

Setting Up IBM Watson Studio

One of the easiest ways to set up and run Spark is through the IBM Watson Studio platform.

It has a user-friendly interface and there is a free “Lite Plan” available.

You will be given 50 capacity units per month with the “Lite Plan”.

The default Spark Python 3.5 environment consumes 1.5 capacity units per hour, giving you approximately 33 hours to work on a project.
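The 33-hour figure follows directly from the two numbers quoted above. A quick check in plain Python (the variable names are mine, chosen for illustration):

```python
LITE_PLAN_UNITS_PER_MONTH = 50  # capacity units included in the Lite Plan
UNITS_PER_HOUR = 1.5            # consumption of the default Spark Python 3.5 environment

hours_available = LITE_PLAN_UNITS_PER_MONTH / UNITS_PER_HOUR
print(f"{hours_available:.1f} hours per month")  # 33.3 hours per month
```

A larger environment that burns more capacity units per hour would shrink this budget proportionally.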

To set up IBM Watson Studio, you need to register for an IBM Cloud account if you do not have one.

Next, go to the IBM Watson Studio home page and log in.

After logging in, select “Create a project”.

Next, hover over “Data Science” and click “Create Project”.

Enter a name for your project, and select “Create”.

Select “Add to project”.

Select “Notebook” for your asset type.

Give a name to the notebook and select “Default Spark Python 3.5 XS (Driver with 1 vCPU and 4 GB RAM, 2 executors with 1 vCPU and 4 GB RAM each)”.

Next, click “Create Notebook”.

This will create a new notebook where you can start coding.

To insert a data file, select the “Find and add data” icon in the top right-hand corner.

Simply drag and drop your desired data file into the box.

To create a new Spark session and read in the data file, select “Insert to code” and click on “Insert SparkSession DataFrame”.

This will generate a pre-written cell.

Uncomment the last two lines of the cell to read in the data file.

Feel free to change the name of the data frame; I changed it from df_data_1 to df.

You can now build your project!

Should your current cluster terminate for any reason, you can always reconnect your pre-existing notebook to a new cluster and continue your work.

After you have completed your project, you should stop your environment and delete your notebook and data files to avoid any unexpected charges.

To stop your environment, click on the “Environment” tab at the top of your project page.

Click on the three dots to the right of your active environment and select “Stop”.

Next, head over to the “Assets” tab at the top of your project page.

Click on the three dots to the right of your data file and select “Remove”.

Repeat this for your notebook as well.

You can check your billing information by selecting “Manage” and then “Billing and usage”.

An advantage of using Spark on the IBM Watson Studio platform is that it comes pre-installed with commonly used libraries such as Pandas and Matplotlib.

This is in contrast with the AWS EMR service, which does not have these libraries pre-installed.

If your IBM-hosted notebook stops running or your internet browser crashes (yes, it did happen), it probably means that your current setup on IBM Watson Studio has insufficient memory to tackle the task.

You might have to choose a simpler model, apply dimensionality reduction to your dataset, or purchase a paid plan to access a more powerful instance.

Setting Up Amazon AWS

I will now share the setup instructions for Amazon Web Services Elastic MapReduce (EMR).

To begin, you need to sign up for an AWS account.

You will be required to provide a credit card when signing up, but you will not be charged for anything yet.

You will need to choose a support plan, and the free Basic Support Plan should suffice.

Next, go to the Amazon EMR console and click “Get Started with Amazon EMR”.

Once you have signed in with your account, you are ready to create your clusters.

Choose the appropriate location in the top right-hand corner (pick the one closest to you).

Select “Clusters” on the menu on the left and then click “Create cluster”.

Configure your clusters with the following settings:

Release: emr-5.….0 or later
Applications: Spark 2.….0 on Hadoop 2.….5 YARN with Ganglia 3.….2 and Zeppelin 0.….0
Instance type: m3.xlarge
Number of instances: 6
EC2 key pair: Proceed without an EC2 key pair, or feel free to use one if you’d like

If you want to run the code as it is, I recommend using 6 instances of m3.xlarge.

Of course, you can try a lower number of instances (3, for instance).

However, if you encounter errors such as “Session isn’t active”, it probably means that your current cluster setup has insufficient memory for the task.

You will need to create a new cluster with larger instances.

The rest of the settings can be kept at their default values, and you can finalize the setup by clicking “Create cluster”.

Next, you will be brought to a page that shows that the cluster is “Starting”.

The status will then change to “Running” after a few minutes.

Finally, it will change to “Waiting”.

The whole process can take between 3 and 10 minutes.

At this point, you can move on to the next step.

Now you can create your notebook.

Select “Notebooks” on the menu on the left.

Create a name for your notebook.

Select “Choose an existing cluster” and choose the cluster you just created.

Use the default setting for “AWS service role”: this should be “EMR_Notebooks_DefaultRole”, or “Create default role” if you haven’t done this before.

You can keep the rest of the settings as they are, and click “Create notebook” at the bottom right.

Next, wait for the status of the notebook to change from “Starting” or “Pending” to “Ready”.

At this point, you can open the notebook.

Now, you can start coding.

The starter code to create a new Spark session and read in the full dataset is provided below:

# Starter code
from pyspark.sql import SparkSession

# Create spark session
spark = SparkSession \
    .builder \
    .appName("Sparkify") \
    .getOrCreate()

# Read in full sparkify dataset
event_data = "s3n://dsnd-sparkify/sparkify_event_data.json"
df = spark.read.json(event_data)
df.head()

The full dataset is located at: s3n://dsnd-sparkify/sparkify_event_data.json

If you have terminated your cluster and stopped your notebook, and you want to re-run the notebook, you can create a new cluster and reconnect your notebook to the newly created cluster.

Follow the previously given instructions to set up a new cluster and then return to the “Notebooks” option in the left menu.

Click on the existing notebook you have created, select “Change cluster” and choose the newly created cluster using the “Choose” button under the “Choose an existing cluster” option.

Finally, select “Change cluster and start notebook”.

To avoid any unexpected charges on AWS, terminate your clusters and delete your notebook when you have completed your analysis.

You can check this in the “Clusters” and “Notebooks” options in the left menu.

If you have set up clusters in multiple locations, make sure to check all of these locations, as each location has its own list of clusters!

You can check your billing details under the “My Billing Dashboard” option, which can be found under your account name.

When running the code on Amazon EMR, you might encounter these errors or exceptions:

TypeError: object of type ‘NoneType’ has no len()
KeyError: 14933 (the number could be different)

Do not fret if you see these errors.

Generally, they do not mean that your code is faulty, and their occurrence will not affect the execution of your code.

If they occur while you are executing a block of code, you can ignore these error messages.

The code will still be executed.

It cost me around US$20 to run the analysis on AWS EMR.

That takes into account the fact that I had to re-run the analysis numerous times, as my AWS-hosted notebook crashed due to insufficient memory (I initially tried working with 3, 4 and 5 instances of m3.xlarge).

Hence, if you start with a sufficiently large cluster (6 instances of m3.xlarge for this project), you might be able to run the code once without encountering any issues.

This could help lower your cost.

In addition, I developed the model’s prototype on my local machine, which helped lower the eventual cost.

One side note is that AWS EMR clusters do not have libraries such as Scikit-Learn, Pandas and Matplotlib pre-installed.

You should be able to complete the project without these libraries, although you will not be able to perform any data visualizations.

If you do want to use these libraries, you will need to install them by following the instructions provided by this link.

This video would also be helpful.

It is a somewhat complicated procedure that involves writing a bash script to install these libraries, uploading the script to an S3 bucket, and then instantiating a cluster with the script.

Admittedly, I have not tried it out at the time of writing.

If I do get around to installing these libraries, I might update this article in the future.

Setting Up Databricks

In order to register for a Databricks account, you will need an AWS account.

Databricks’ platform depends on AWS for the cloud infrastructure.

You can register for a Databricks account here.

The registration process for Databricks is slightly longer than for the other two cloud services, and it involves shuttling between your AWS account and the Databricks account setup page.

However, once you have completed the setup and activated your account, it is relatively easy to set up the work environment in Databricks.

In addition, the Community Edition offers you free unlimited access to a 6 GB cluster (no need to worry about a time limit!), making it a perfect environment to build your model’s prototype.

As with IBM Watson Studio, the Databricks work environment comes with commonly used libraries such as Pandas and Matplotlib pre-installed.

Once you have created your account, head over to the login page of Databricks Community Edition and log in to your account.

Select “Clusters” on the left menu.

Select “+ Create Cluster” in the top left-hand corner of the page.

Give a name to your cluster.

You can leave the Databricks Runtime Version unmodified.

Select the Python version you want to work with.

I chose Python 3 for my project.

You can leave the Availability Zone field blank.

Select “Create Cluster” when you are done.

Wait for the status to change from “Pending” to “Running”.

Next, select “Workspace” on the left menu, then click on “Users”.

Select the arrow next to your email and click on “Create” then “Notebook”.

Give a name to your notebook, keep the rest of the settings unmodified, and select “Create”.

Next, you will want to upload your dataset.

Select “Data” on the left menu, and click on “Add Data”.

Simply drag and drop the file you want to upload into the grey box.

Do not leave the page while the file is uploading, otherwise you will have to redo the whole process.

When the upload is completed, select “Create table in Notebook”.

This will open an example notebook with a pre-written cell that has the code to read in your data.

Copy the code in this cell, go back to the notebook you created earlier and paste the code there.

Congratulations, you can start coding!

If you want to save the file, click on “File”, then “Export”, then “iPython Notebook”.

If you want to change a code cell into a markdown cell, type “%md” in the first line of the cell.

If you have terminated your cluster, you can always create a new cluster and reconnect an existing notebook to the new cluster.

Even though the Community Edition is a free service, it is still a good idea to terminate your clusters and delete your notebooks to avoid any unexpected AWS charges.

You can check your clusters’ status by clicking on “Clusters” on the left menu.

Do note that if you choose to use the full Databricks platform instead of the Community Edition, you will incur AWS charges.

Dataset

One of the key columns to look out for in this dataset is the ‘page’ column.

The ‘page’ column keeps track of what pages a user visits.

This is a list of the pages that users can visit:

Cancel: User has visited the cancel page. Does not mean cancellation was completed.
Submit Downgrade: User has submitted a downgrade from premium to free tier.
Thumbs Down: User gave a thumbs down.
Home: User visited the home page.
Downgrade: User visited the downgrade page. Does not mean downgrade was submitted.
Roll Advert: Advertisements are played.
Logout: User logged out.
Save Settings: User made some changes to settings and saved them.
Cancellation Confirmation: User cancelled the subscription.
About: User visited the about page.
Submit Registration: User submitted a registration request.
Settings: User visited the settings page.
Login: User logged in.
Register: User visited the registration page. Does not mean registration was completed.
Add to Playlist: User added songs to playlist.
Add Friend: User added a friend.
NextSong: User listened to a song.
Thumbs Up: User gave a thumbs up.
Help: User visited the help page.
Upgrade: User upgraded from free to premium tier.

Of these, the pages we should pay most attention to are perhaps ‘NextSong’, which tracks the songs played by a user; ‘Submit Downgrade’, which tracks when a user submits a downgrade request; and ‘Cancellation Confirmation’, which tracks when a user’s cancellation request is confirmed.

We will know when a customer churns by searching for a ‘Submit Downgrade’ or ‘Cancellation Confirmation’ page.

Note that a user can visit the ‘Downgrade’ and ‘Cancel’ pages without submitting a downgrade or cancellation request.
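The distinction between visiting a page and actually churning can be illustrated with a small plain-Python sketch. This is only an illustration of the rule described above, not the Spark code used later; the helper names and the sample page lists are made up.

```python
# Only these two pages count as churn events under the definition above
CHURN_PAGES = {"Submit Downgrade", "Cancellation Confirmation"}


def is_churn_event(page):
    """Return 1 if a single page visit is a churn event, else 0."""
    return 1 if page in CHURN_PAGES else 0


def user_churned(pages):
    """Return 1 if any page in a user's history is a churn event."""
    return max((is_churn_event(p) for p in pages), default=0)


# Hypothetical page histories for two users
visited_only = ["Home", "Downgrade", "Cancel", "NextSong"]
submitted = ["Home", "Downgrade", "Submit Downgrade"]

print(user_churned(visited_only))  # 0
print(user_churned(submitted))     # 1
```

The first user only looked at the ‘Downgrade’ and ‘Cancel’ pages, so no churn is recorded; the second actually submitted the downgrade.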

The rest of the pages are relatively straightforward.

They indicate that the user has visited the relevant page.

As mentioned earlier, there are 2 months of data available.

Data Preprocessing

Reading in the data is slightly different depending on the cloud computing service you are using.

Each service’s method of reading in the dataset is described above in the setup instructions.

If you are reading in the data on your local computer, this is the code for the small dataset:

from pyspark.sql import SparkSession

# create a Spark session
spark = SparkSession.builder \
    .master("local") \
    .appName("Sparkify") \
    .getOrCreate()

# replace … with the name of your dataset file
df = spark.read.json("….json")

This code assumes that the dataset is located in the same file directory as your code.

You might need to change the path of the dataset depending on where you store your dataset.

Data Cleaning

There are a couple of columns with null values.

Notably, entries missing firstName, gender, lastName, location, registration and userAgent belong to users who did not log in or did not register.

Since we do not know which users these entries are tracking, and most of these entries are associated with the login or home page, they provide little value.

Hence, we can discard them from our analysis.

# filter out all entries with missing names.
# this also removes entries with missing gender, lastName,
# location, registration and userAgent
df = df.filter(df.firstName.isNotNull())

Pages that are not ‘NextSong’ have null values for artist, length and sessionId.

All of these variables are only valid when songs are played (they are only non-null when page=’NextSong’), hence we can leave them unmodified.

Feature Engineering

First, we transform the timestamps to datetime.

The original timestamps are in milliseconds, so we have to divide them by 1000 before converting them.
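To make the unit conversion concrete, here is a standalone plain-Python sketch (not the Spark code used in this article). The function name and the sample timestamp are made up, and the result is expressed in UTC, whereas Spark’s from_unixtime uses the session time zone.

```python
from datetime import datetime, timezone


def ms_to_datetime(ts_ms):
    """Convert a millisecond epoch timestamp to a UTC datetime."""
    # integer-divide by 1000 to go from milliseconds to seconds
    return datetime.fromtimestamp(ts_ms // 1000, tz=timezone.utc)


print(ms_to_datetime(1538352000000))  # 2018-10-01 00:00:00+00:00
```

Skipping the division would interpret the value as seconds and land tens of thousands of years in the future, which is why the adjustment comes first.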

from pyspark.sql.functions import udf, col, from_unixtime
from pyspark.sql.types import IntegerType

# original timestamp in milliseconds, so divide by 1000
adjust_timestamp = udf(lambda x: x // 1000, IntegerType())
df = df.withColumn("ts_adj", adjust_timestamp('ts'))

# convert adjusted timestamp to datetime
df = df.withColumn("datetime", from_unixtime(col("ts_adj")))

# convert registration timestamp to datetime
df = df.withColumn("reg_adj", adjust_timestamp('registration'))

# convert adjusted registration timestamp to datetime
df = df.withColumn("reg_datetime", from_unixtime(col("reg_adj")))

# drop all the timestamp columns. Will not need them
columns_to_drop = ['registration', 'ts', 'ts_adj', 'reg_adj']
df = df.drop(*columns_to_drop)

Next, we can label the months from 0 to N-1, where N represents the total number of months available in the dataset.

We can approximate the start date of the analysis as ‘2018-10-01 00:00:00’.
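As a plain-Python illustration of this month labelling, the sketch below counts whole calendar months since the start date. It assumes the start date falls on the first day of a month at midnight and that no event predates it; under those assumptions it should agree with taking the floor of Spark’s months_between. The function name and sample dates are made up.

```python
from datetime import datetime

ANALYSIS_START = datetime(2018, 10, 1)  # start date quoted in the article


def month_num(event_time, start=ANALYSIS_START):
    """Whole calendar months elapsed between the start month and the event."""
    return (event_time.year - start.year) * 12 + (event_time.month - start.month)


print(month_num(datetime(2018, 10, 15, 8, 30)))   # 0
print(month_num(datetime(2018, 11, 30, 23, 59)))  # 1
```

Every event in October 2018 gets label 0 and every event in November 2018 gets label 1, which is the split between the feature month and the prediction month.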

from pyspark.sql.functions import lit, floor, months_between

# add start date of analysis
df = df.withColumn('analysis_start_date', lit('2018-10-01 00:00:00'))

# number the months starting from the very first month of the analysis
df = df.withColumn("month_num", floor(months_between(df.datetime, df.analysis_start_date)))

Next, replace the string variables for gender and level with binary (0 or 1) variables.

# engineer free or paid binary variable
# free: 0, paid: 1
df = df.replace(["free", "paid"], ["0", "1"], "level")

# engineer male and female binary variable
# male: 0, female: 1
df = df.replace(["M", "F"], ["0", "1"], "gender")

We can define a churn event as any visit to the ‘Cancellation Confirmation’ or ‘Submit Downgrade’ page.

def define_churn(x):
    """
    Defining churn as cancellation of service or downgrading from premium to free tier.
    """
    if x == "Cancellation Confirmation":
        return 1
    elif x == "Submit Downgrade":
        return 1
    else:
        return 0

churn_event = udf(lambda x: define_churn(x), IntegerType())
df = df.withColumn("churn", churn_event("page"))

Generating monthly statistics for each user:

# note: max and min here are the Spark column functions, not the Python built-ins
from pyspark.sql.functions import avg, countDistinct, max, min, unix_timestamp, datediff, expr

# number of register page visits
df_register = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Register") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numRegister')

# number of cancel page visits
df_cancel = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Cancel") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numCancelVisits')

# number of upgrade page visits
df_upgrade = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Upgrade") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numUpgradeVisits')

# number of downgrade page visits
df_downgrade = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Downgrade") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numDowngradeVisits')

# number of home page visits
df_home = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Home") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numHomeVisits')

# number of about page visits
df_about = df.select('userId', 'month_num', 'page') \
    .where(df.page == "About") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numAboutVisits')

# number of setting page visits
df_settings = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Settings") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numSettingsVisits')

# number of times user saved settings changes
df_saveSettings = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Save Settings") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numSaveSettings')

# number of login page visits
df_login = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Login") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numLogins')

# number of logout page visits
df_logout = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Logout") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numLogouts')

# number of songs added to playlist
df_addPlaylist = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Add to Playlist") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numAddPlaylists')

# number of friends added
df_addFriend = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Add Friend") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numFriends')

# number of thumbs up given
df_thumbsUp = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Thumbs Up") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numThumbsUp')

# number of thumbs down given
df_thumbsDown = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Thumbs Down") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numThumbsDown')

# number of advertisements rolled
df_advert = df.select('userId', 'month_num', 'page') \
    .where(df.page == "Roll Advert") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numAdverts')

# number of songs played
df_songsPlayed = df.select('userId', 'month_num', 'page') \
    .where(df.page == "NextSong") \
    .groupBy('userId', 'month_num') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'numSongsPlayed')

# total amount of time user listened to songs
df_totalListen = df.select('userId', 'month_num', 'length') \
    .groupBy('userId', 'month_num') \
    .agg({'length': 'sum'}) \
    .withColumnRenamed('sum(length)', 'totalListenTime')

# number of songs played per session
df_songsPerSession = df.select('userId', 'month_num', 'page', 'sessionId') \
    .where(df.page == "NextSong") \
    .groupBy('userId', 'month_num', 'sessionId') \
    .agg({'page': 'count'}) \
    .withColumnRenamed('count(page)', 'SongsPerSession')

# average number of songs played per session
df_avgSongsPerSession = df_songsPerSession \
    .groupBy('userId', 'month_num') \
    .agg(avg(df_songsPerSession.SongsPerSession).alias('avgSongsPerSession'))

# number of singers listened per month
df_singersPlayed = df.select('userId', 'month_num', 'page', 'artist') \
    .where(df.page == "NextSong") \
    .groupBy('userId', 'month_num') \
    .agg(countDistinct(df.artist).alias('numSingersPlayed'))

# number of singers per session
df_singersPerSession = df.select('userId', 'month_num', 'page', 'artist', 'sessionId') \
    .where(df.page == "NextSong") \
    .groupBy('userId', 'month_num', 'sessionId') \
    .agg(countDistinct(df.artist).alias('SingersPerSession'))

# average number of singers per session
df_avgSingersPerSession = df_singersPerSession \
    .groupBy('userId', 'month_num') \
    .agg(avg(df_singersPerSession.SingersPerSession).alias('avgSingersPerSession'))

# amount of time spent for each session (seconds converted to minutes)
df_userSession = df.groupBy("userId", "month_num", "sessionId") \
    .agg(((max(unix_timestamp(df.datetime)) - min(unix_timestamp(df.datetime))) / 60.0)
         .alias('sessionTimeMins'))

# average time per session
df_avgUserSession = df_userSession \
    .groupBy('userId', 'month_num') \
    .agg(avg(df_userSession.sessionTimeMins).alias('avgSessionMins'))

# number of sessions per month
df_numSession = df.select('userId', 'month_num', 'sessionId') \
    .dropDuplicates() \
    .groupBy('userId', 'month_num') \
    .agg({'sessionId': 'count'}) \
    .withColumnRenamed('count(sessionId)', 'numSessions')

# if user had premium level this month
# if user had premium at any point of the month, assume he/she has
# premium for the whole month for simplicity
df_level = df.select('userId', 'month_num', 'level') \
    .groupBy('userId', 'month_num') \
    .agg({'level': 'max'}) \
    .withColumnRenamed('max(level)', 'level')

# find user's gender
# assuming nobody changes gender midway
df_gender = df.select('userId', 'month_num', 'gender') \
    .groupBy('userId', 'month_num') \
    .agg({'gender': 'max'}) \
    .withColumnRenamed('max(gender)', 'gender')

# start of each month
df = df.withColumn("start_of_month", expr("add_months(analysis_start_date, month_num)"))

# days since registration to start of each month
df = df.withColumn("daysSinceReg", datediff(df.start_of_month, df.reg_datetime))

# the value is constant within each user-month, so min simply picks it out
df_daysReg = df.select('userId', 'month_num', 'daysSinceReg') \
    .groupBy('userId', 'month_num') \
    .agg(min(df.daysSinceReg).alias('daysSinceReg'))

# did user churn this month
df_churn = df.select('userId', 'month_num', 'churn') \
    .groupBy('userId', 'month_num') \
    .agg({'churn': 'max'}) \
    .withColumnRenamed('max(churn)', 'churn')

Concatenating these monthly statistics into a new data frame:

# df_singersPlayed is joined without 'outer' (an inner join),
# so user-months with no songs played are dropped at that step
all_data = df_register \
    .join(df_cancel, ['userId', 'month_num'], 'outer') \
    .join(df_upgrade, ['userId', 'month_num'], 'outer') \
    .join(df_downgrade, ['userId', 'month_num'], 'outer') \
    .join(df_home, ['userId', 'month_num'], 'outer') \
    .join(df_about, ['userId', 'month_num'], 'outer') \
    .join(df_settings, ['userId', 'month_num'], 'outer') \
    .join(df_saveSettings, ['userId', 'month_num'], 'outer') \
    .join(df_login, ['userId', 'month_num'], 'outer') \
    .join(df_logout, ['userId', 'month_num'], 'outer') \
    .join(df_addPlaylist, ['userId', 'month_num'], 'outer') \
    .join(df_addFriend, ['userId', 'month_num'], 'outer') \
    .join(df_thumbsUp, ['userId', 'month_num'], 'outer') \
    .join(df_thumbsDown, ['userId', 'month_num'], 'outer') \
    .join(df_advert, ['userId', 'month_num'], 'outer') \
    .join(df_songsPlayed, ['userId', 'month_num'], 'outer') \
    .join(df_totalListen, ['userId', 'month_num'], 'outer') \
    .join(df_avgSongsPerSession, ['userId', 'month_num'], 'outer') \
    .join(df_singersPlayed, ['userId', 'month_num']) \
    .join(df_avgSingersPerSession, ['userId', 'month_num'], 'outer') \
    .join(df_avgUserSession, ['userId', 'month_num'], 'outer') \
    .join(df_numSession, ['userId', 'month_num'], 'outer') \
    .join(df_level, ['userId', 'month_num'], 'outer') \
    .join(df_gender, ['userId', 'month_num'], 'outer') \
    .join(df_daysReg, ['userId', 'month_num'], 'outer') \
    .join(df_churn, ['userId', 'month_num'], 'outer')

Next, generate 1-month lagged features.

These will be used as input features for the model, rather than the current month’s statistics, since we do not want any look-ahead bias.
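To show what a lagged feature looks like on a tiny example, here is a plain-Python sketch of the idea. It is an illustration only, with made-up rows and a hypothetical helper name: within each user, rows are ordered by month_num and each row receives the value from the row before it.

```python
from collections import defaultdict


def add_lag(rows, column):
    """Attach the previous row's value per user, ordered by month_num."""
    by_user = defaultdict(list)
    for row in rows:
        by_user[row["userId"]].append(row)

    out = []
    for user_rows in by_user.values():
        previous = None  # the first row of each user has no earlier value
        for row in sorted(user_rows, key=lambda r: r["month_num"]):
            out.append({**row, column + "_lastMonth": previous})
            previous = row[column]
    return out


# Hypothetical monthly statistics for two users
rows = [
    {"userId": "a", "month_num": 0, "numSongsPlayed": 120},
    {"userId": "a", "month_num": 1, "numSongsPlayed": 45},
    {"userId": "b", "month_num": 1, "numSongsPlayed": 10},
]

for row in add_lag(rows, "numSongsPlayed"):
    print(row)
```

User "a" in month 1 is paired with the 120 songs played in month 0, so the model only sees information that was available before the month being predicted. Note that, like a row-based lag in Spark, this takes the previous available row for the user, which is the previous calendar month only when the user has no missing months.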

from pyspark.sql import Window
from pyspark.sql.functions import lag

# one window per user, ordered by month
windowlag = Window.partitionBy('userId').orderBy('month_num')

# generate 1 month lag features
all_data = all_data.withColumn('numRegister_lastMonth', lag(all_data['numRegister']).over(windowlag))
all_data = all_data.withColumn('numCancelVisits_lastMonth', lag(all_data['numCancelVisits']).over(windowlag))
all_data = all_data.withColumn('numUpgradeVisits_lastMonth', lag(all_data['numUpgradeVisits']).over(windowlag))
all_data = all_data.withColumn('numDowngradeVisits_lastMonth', lag(all_data['numDowngradeVisits']).over(windowlag))
all_data = all_data.withColumn('numHomeVisits_lastMonth', lag(all_data['numHomeVisits']).over(windowlag))
all_data = all_data.withColumn('numAboutVisits_lastMonth', lag(all_data['numAboutVisits']).over(windowlag))
all_data = all_data.withColumn('numSettingsVisits_lastMonth', lag(all_data['numSettingsVisits']).over(windowlag))
all_data = all_data.withColumn('numSaveSettings_lastMonth', lag(all_data['numSaveSettings']).over(windowlag))
all_data = all_data.withColumn('numLogins_lastMonth', lag(all_data['numLogins']).over(windowlag))
all_data = all_data.withColumn('numLogouts_lastMonth', lag(all_data['numLogouts']).over(windowlag))
all_data = all_data.withColumn('numAddPlaylists_lastMonth', lag(all_data['numAddPlaylists']).over(windowlag))
all_data = all_data.withColumn('numFriends_lastMonth', lag(all_data['numFriends']).over(windowlag))
all_data = all_data.withColumn('numThumbsUp_lastMonth', lag(all_data['numThumbsUp']).over(windowlag))
all_data = all_data.withColumn('numThumbsDown_lastMonth', lag(all_data['numThumbsDown']).over(windowlag))
all_data = all_data.withColumn('numAdverts_lastMonth', lag(all_data['numAdverts']).over(windowlag))
all_data = all_data.withColumn('numSongsPlayed_lastMonth', lag(all_data['numSongsPlayed']).over(windowlag))
all_data = all_data.withColumn('totalListenTime_lastMonth', lag(all_data['totalListenTime']).over(windowlag))
all_data = all_data.withColumn('avgSongsPerSession_lastMonth', lag(all_data['avgSongsPerSession']).over(windowlag))
all_data = all_data.withColumn('numSingersPlayed_lastMonth', lag(all_data['numSingersPlayed']).over(windowlag))
all_data = all_data.withColumn('avgSingersPerSession_lastMonth', lag(all_data['avgSingersPerSession']).over(windowlag))
all_data = all_data.withColumn('avgSessionMins_lastMonth', lag(all_data['avgSessionMins']).over(windowlag))
all_data = all_data.withColumn('numSessions_lastMonth', lag(all_data['numSessions']).over(windowlag))
all_data = all_data.withColumn('level_lastMonth', lag(all_data['level']).over(windowlag))

The generated features that will be used for our predictive model are:

numRegister_lastMonth : Number of times user registered last month
numCancelVisits_lastMonth : Number of times user visited cancellation page last month
numUpgradeVisits_lastMonth : Number of times user visited upgrade page last month
numDowngradeVisits_lastMonth : Number of times user visited downgrade page last month
numHomeVisits_lastMonth : Number of times user visited home page last month
numAboutVisits_lastMonth : Number of times user visited about page last month
numSettingsVisits_lastMonth : Number of times user visited settings page last month
numSaveSettings_lastMonth : Number of times user saved changes to settings last month
numLogins_lastMonth : Number of times user logged in last month
numLogouts_lastMonth : Number of times user logged out last month
numAddPlaylists_lastMonth : Number of songs user added to playlist last month
numFriends_lastMonth : Number of friends user added last month
numThumbsUp_lastMonth : Number of thumbs up user gave last month
numThumbsDown_lastMonth : Number of thumbs down user gave last month
numAdverts_lastMonth : Number of advertisements played to user last month
numSongsPlayed_lastMonth : Number of songs user played last month
totalListenTime_lastMonth : Total listening time for user last month
avgSongsPerSession_lastMonth : Average number of songs played per session by user last month
numSingersPlayed_lastMonth : Number of singers played by user last month
avgSingersPerSession_lastMonth : Average number of singers played per session by user last month
avgSessionMins_lastMonth : Average number of minutes per session by user last month
numSessions_lastMonth : Number of sessions by user last month
daysSinceReg : Number of days from registration to the first day of the current month for each user
level_lastMonth : Tracks if a user was a paid or free user last month. If a user was a paid user at any time last month, we assume for simplicity that he/she was a paid user for the entire month.

All missing values are imputed with 0, since missing values generally indicate no page visits, no songs played, etc.
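To make the lagging step concrete, here is a minimal pure-Python sketch (no Spark required) of what a 1-month lag per user does. The toy rows and the helper name add_lag are hypothetical, made up purely for illustration. The sketch mimics a row-based lag within each user's rows ordered by month, and imputes 0 where no earlier row exists.

```python
from collections import defaultdict

def add_lag(rows, column, fill=0):
    """Return copies of rows with a '<column>_lastMonth' field holding the
    previous row's value for the same user (rows ordered by month_num)."""
    by_user = defaultdict(list)
    for row in rows:
        by_user[row['userId']].append(row)

    lagged = []
    for user_rows in by_user.values():
        user_rows.sort(key=lambda r: r['month_num'])
        previous = None
        for row in user_rows:
            new_row = dict(row)
            # No earlier row for this user: impute with `fill`
            new_row[column + '_lastMonth'] = fill if previous is None else previous[column]
            lagged.append(new_row)
            previous = row
    return lagged

# Hypothetical toy data: two users, two months each
rows = [
    {'userId': 'a', 'month_num': 0, 'numSongsPlayed': 120},
    {'userId': 'a', 'month_num': 1, 'numSongsPlayed': 80},
    {'userId': 'b', 'month_num': 0, 'numSongsPlayed': 15},
    {'userId': 'b', 'month_num': 1, 'numSongsPlayed': 0},
]

lagged_rows = add_lag(rows, 'numSongsPlayed')
for row in lagged_rows:
    print(row)
```

Each month-1 row now carries the month-0 value as its input feature, so the model never sees statistics from the month it is trying to predict.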

Exploratory Data Analysis

The following exploratory data analysis was performed on the medium-sized instance of the dataset.

Similar trends should be observed in both the small and large instances of the dataset.

Shown below are the boxplots of various user statistics from the first month grouped by whether these users churned during the second month.

We are interested in seeing whether there are any discernible patterns in past user behavior that could indicate whether a user will churn in the near future.

A surprising observation is that users who churned in the second month seem to have been far more active on the digital music service during their first month than users who did not churn.

Churned customers consistently had more page visits, played more songs, listened to more singers, gave more thumbs up and thumbs down, etc.

Perhaps this suggests that users who are likely to churn next month often try to make full use of the service in the current month.

In addition, we note that the number of registrations, the number of visits to the cancel page and the number of logins are missing.

When a user registers or logs in, the system does not yet know the user’s ID.

Hence, no information on these activities was recorded for the customers.

As for the missing statistic regarding visits to the cancel page, I note that none of the users staying with the service in the second month visited the cancel page during the first month.

Any user who visited the cancel page during the first month most likely churned during that month, so none of them were still with the service during the second month.

Premium users are far more likely to churn

Premium users seem far more likely to churn than free users.

Note that churn in this case could mean either downgrading from premium to free tier or cancelling their account altogether.

Male users are slightly less likely to churn

In addition, it seems that male users are slightly less likely to churn than female users.

During the second month, the number of users who churned vs number of users who did not churn

The number of users who churned in the second month is far lower than the number of users who did not churn.

Hence, we will need to use the appropriate metrics to evaluate our model to ensure that our results are not deceiving.

We could also employ techniques to handle imbalanced datasets, such as SMOTE, to improve our model’s ability to predict users who will churn.
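To see why accuracy alone can be deceiving on imbalanced labels, consider a hypothetical baseline that predicts "no churn" for every user. The sketch below borrows the class counts of the test set reported in the Results section (2,830 non-churners and 907 churners, taken from the full dataset rather than the medium-sized one used here). The baseline itself is an illustrative assumption and not one of the models in this article.

```python
# Class counts taken from the test set confusion matrices in the Results section
num_not_churned = 2830
num_churned = 907

# Hypothetical baseline: predict "no churn" (label 0) for every user
true_negatives = num_not_churned   # every non-churner is labelled correctly
false_negatives = num_churned      # every churner is missed
true_positives = 0
false_positives = 0

accuracy = (true_positives + true_negatives) / (num_not_churned + num_churned)
recall = true_positives / (true_positives + false_negatives)

print(f"Baseline accuracy: {accuracy:.3f}")
print(f"Baseline recall:   {recall:.3f}")
```

This baseline is right for roughly three out of four users, yet it never finds a single churner, which is why metrics focused on the churn class are needed.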

In conclusion, it is likely that most churn resulted from users downgrading from premium to free tier.

I hypothesized that there was a lower introductory premium pricing plan for new premium users that did not require any long-term commitment.

Many of these new premium users were likely unwilling to make a long-term commitment to the music service, and wanted to maximize their usage of the service during the month when they had premium service.

Alternatively, it could also be that the premium users who churned found the service unsatisfactory after prolonged usage, and thus wanted to downgrade their subscription.

Modelling

From the exploratory data analysis, we note that past customer activities can provide a relatively good indication of whether they will churn in the near future.

We will now build predictive models, with the aid of the Spark ML library, to predict when these customers will potentially churn.

The Spark ML library requires input features to be numeric, so we have to convert all string variables to integers or floats.

# convert userId, gender, level, level_lastMonth to numeric
convert_numeric = ['userId', 'level', 'gender', 'level_lastMonth']
for feat in convert_numeric:
    featName = feat + "_n"
    all_data = all_data.withColumn(featName, all_data[feat].cast("float"))
    all_data = all_data.drop(feat)

You can check the datatypes of the features by running all_data.dtypes.

In addition, ensure that there are no null or ‘NaN’ values in your input features.

Next, we drop entries from month 0, since we do not have previous months’ statistics to work with.

We will drop the current month statistics and keep last month’s lagged features and the churn labels.

# first month is month 0
model_data = all_data \
    .filter(all_data.month_num > 0) \
    .select('userId_n', 'month_num',
            'numUpgradeVisits_lastMonth', 'numDowngradeVisits_lastMonth',
            'numHomeVisits_lastMonth', 'numAboutVisits_lastMonth',
            'numSettingsVisits_lastMonth', 'numSaveSettings_lastMonth',
            'numLogouts_lastMonth', 'numAddPlaylists_lastMonth',
            'numFriends_lastMonth', 'numThumbsUp_lastMonth',
            'numThumbsDown_lastMonth', 'numAdverts_lastMonth',
            'numSongsPlayed_lastMonth', 'totalListenTime_lastMonth',
            'avgSongsPerSession_lastMonth', 'numSingersPlayed_lastMonth',
            'avgSingersPerSession_lastMonth', 'avgSessionMins_lastMonth',
            'numSessions_lastMonth', 'level_lastMonth_n', 'gender_n',
            'daysSinceReg', 'churn') \
    .withColumnRenamed('churn', 'label')

We will store the input features and the labels in a data frame model_data.

As a recap, we will be predicting customer churn during the second month using their activity data from the first month.

Hence, we will perform a simple train-test split on the second month’s data to obtain the training and test datasets.

train, test = model_data.randomSplit([0.8, 0.2], seed=50)

Note that all input features for a data point have to be placed in a single vector called features.

It is recommended to scale the input features as well.

The column churn should be renamed as label.

By default, Spark ML estimators look for columns named features and label, so use this naming convention.
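As a rough illustration of what scaling does, the following pure-Python sketch standardises a single feature to zero mean and unit standard deviation. The listening-time values are hypothetical, and the use of the sample standard deviation (n - 1 denominator) is an assumption about the convention. In the actual pipeline this work is done by Spark ML on the assembled feature vector.

```python
from statistics import mean, stdev

def standardise(values):
    """Centre values on zero and scale them to unit sample standard deviation."""
    mu = mean(values)
    sigma = stdev(values)  # sample standard deviation (n - 1 denominator)
    if sigma == 0:
        # A constant feature carries no information; map it to zeros
        return [0.0 for _ in values]
    return [(v - mu) / sigma for v in values]

# Hypothetical total listening times (in minutes) for five users
total_listen_time = [30.0, 45.0, 600.0, 1200.0, 90.0]
scaled = standardise(total_listen_time)
print([round(v, 3) for v in scaled])
```

Without scaling, a feature with a large range, such as total listening time, can dominate features with small ranges, such as the number of settings page visits, in models like logistic regression and linear support vector classifiers.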

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, LinearSVC, GBTClassifier

inputColumns = ['userId_n', 'month_num',
                'numUpgradeVisits_lastMonth', 'numDowngradeVisits_lastMonth',
                'numHomeVisits_lastMonth', 'numAboutVisits_lastMonth',
                'numSettingsVisits_lastMonth', 'numSaveSettings_lastMonth',
                'numLogouts_lastMonth', 'numAddPlaylists_lastMonth',
                'numFriends_lastMonth', 'numThumbsUp_lastMonth',
                'numThumbsDown_lastMonth', 'numAdverts_lastMonth',
                'numSongsPlayed_lastMonth', 'totalListenTime_lastMonth',
                'avgSongsPerSession_lastMonth', 'numSingersPlayed_lastMonth',
                'avgSingersPerSession_lastMonth', 'avgSessionMins_lastMonth',
                'numSessions_lastMonth', 'level_lastMonth_n', 'gender_n',
                'daysSinceReg']

# combine all input columns into one vector, then standardise it
assembler = VectorAssembler(inputCols=inputColumns, outputCol="FeaturesVec")
scaler = StandardScaler(inputCol="FeaturesVec", outputCol="features", withMean=True, withStd=True)

We will be trying 3 types of models: logistic regression, linear support vector classifier and gradient-boosted trees (GBT).

# set maxIter to 10 to reduce computation time

# Logistic Regression
lr = LogisticRegression(maxIter=10)
pipeline_lr = Pipeline(stages=[assembler, scaler, lr])

# Support Vector Machine Classifier
svc = LinearSVC(maxIter=10)
pipeline_svc = Pipeline(stages=[assembler, scaler, svc])

# Gradient Boosted Trees
gbt = GBTClassifier(maxIter=10, seed=42)
pipeline_gbt = Pipeline(stages=[assembler, scaler, gbt])

The test accuracy, f1-score, precision and recall will be reported.

F1-score will be preferred since the labels are imbalanced.
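For reference, the four metrics can be written in terms of the confusion matrix counts using their standard definitions:

```latex
\begin{align*}
\text{Accuracy}  &= \frac{TP + TN}{TP + TN + FP + FN} \\
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall}    &= \frac{TP}{TP + FN} \\
F_1 &= 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
\end{align*}
```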

where:

TN (True Negatives): Customers we predict will not churn and who did not churn in reality.

FP (False Positives): Customers predicted to churn but did not churn in reality.

FN (False Negatives): Customers we predict will not churn but churned in reality.

TP (True Positives): Customers predicted to churn and churned in reality.

Precision tells us how precise our predictions are (of those predicted to churn, how many did churn), while recall indicates how well our model recalled the churners (how many of the customers who actually churned did the model manage to find).

The f1-score is the harmonic mean of precision and recall.

Since the number of customers who churned is smaller than the number of customers who did not churn, f1-score will provide a more accurate representation of the model’s performance, as opposed to accuracy.

A word of caution: when PySpark’s MulticlassClassificationEvaluator is used to report the f1-score for a binary classification task, it treats labels 0 and 1 as separate classes and returns the weighted f1-score of both classes.

This will produce an overly optimistic score that can be deceiving.

Hence, I opted to define my own function to calculate the f1-score.

More details can be found in the code.
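The difference between the two scores can be sketched in plain Python. This is not the evaluate_model function from the accompanying code: the helper names below are hypothetical and the confusion matrix counts are made up for illustration. The weighted score averages the per-class f1-scores, weighting each class by its number of true instances.

```python
def f1_for_class(tp, fp, fn):
    """F1-score of one class from its true positives, false positives and false negatives."""
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0

def binary_and_weighted_f1(tn, fp, fn, tp):
    # Churn (label 1) is the positive class
    f1_churn = f1_for_class(tp, fp, fn)
    # For label 0 the roles flip: TN are its "true positives",
    # FN are its false positives and FP are its false negatives
    f1_no_churn = f1_for_class(tn, fn, fp)

    support_churn = tp + fn        # users who actually churned
    support_no_churn = tn + fp     # users who actually stayed
    total = support_churn + support_no_churn
    weighted = (support_churn * f1_churn + support_no_churn * f1_no_churn) / total
    return f1_churn, weighted

# Hypothetical counts for an imbalanced test set
f1_churn, f1_weighted = binary_and_weighted_f1(tn=900, fp=50, fn=80, tp=20)
print(f"F1 for the churn class: {f1_churn:.3f}")
print(f"Weighted F1 over both classes: {f1_weighted:.3f}")
```

Because the majority class is easy to predict and carries most of the weight, the weighted score ends up far higher than the f1-score of the churn class, which is the one we actually care about.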

Results

We will be performing k-fold cross-validation with 5 folds, and the models will be optimized using the area under the PR curve.

Using the area under the PR curve is preferred over using the area under the ROC curve, since the class imbalance in the dataset means that the area under the ROC curve would give an overly optimistic picture.
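A small pure-Python sketch can illustrate the gap between the two metrics. The scores and labels below are hypothetical, and average precision is used here as a simple summary of the PR curve, which may differ slightly from how Spark computes areaUnderPR.

```python
def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in positives for n in negatives)
    return wins / (len(positives) * len(negatives))

def average_precision(labels, scores):
    """Mean of the precision values at the rank of each positive example."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)

# Hypothetical ranking of 20 users: only 2 churners, at ranks 2 and 5
labels = [0, 1, 0, 0, 1] + [0] * 15
scores = [1.0 - 0.05 * i for i in range(20)]  # strictly decreasing scores

print(f"ROC AUC: {roc_auc(labels, scores):.3f}")
print(f"Average precision: {average_precision(labels, scores):.3f}")
```

In this toy ranking the ROC AUC is about 0.89, because the many non-churners ranked below the churners dominate the count, while the average precision is only 0.45, reflecting that most of the top-ranked users are not churners.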

The following results are obtained from training our model on the full dataset.

Logistic Regression

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# the tuned parameter's name is missing from the source; regParam is assumed
paramGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.0, 0.05, 0.1]) \
    .build()

cv_lr = CrossValidator(estimator=pipeline_lr,
                       estimatorParamMaps=paramGrid,
                       evaluator=BinaryClassificationEvaluator(metricName="areaUnderPR"),
                       numFolds=5,
                       seed=42)
cvModel_lr = cv_lr.fit(train)
lr_results = cvModel_lr.transform(test)
evaluate_model(lr_results)

Accuracy: 0.790473641959
F1-Score: 0.387802971071
Precision: 0.666666666667
Recall: 0.273428886439
Confusion Matrix:
TN:2706.0 | FP:124.0
FN:659.0 | TP: 248.0

Linear Support Vector Classifier

# the tuned parameter's name is missing from the source; regParam is assumed
paramGrid = ParamGridBuilder() \
    .addGrid(svc.regParam, [0.0, 0.05, 0.1]) \
    .build()

cv_svc = CrossValidator(estimator=pipeline_svc,
                        estimatorParamMaps=paramGrid,
                        evaluator=BinaryClassificationEvaluator(metricName="areaUnderPR"),
                        numFolds=5,
                        seed=42)
cvModel_svc = cv_svc.fit(train)
svc_results = cvModel_svc.transform(test)
evaluate_model(svc_results)

Accuracy: 0.773882793685
F1-Score: 0.189837008629
Precision: 0.727941176471
Recall: 0.109151047409
Confusion Matrix:
TN:2793.0 | FP:37.0
FN:808.0 | TP: 99.0

Gradient-Boosted Trees

# the third tuned parameter's name is missing from the source; subsamplingRate is assumed
paramGrid = ParamGridBuilder() \
    .addGrid(gbt.minInstancesPerNode, [5]) \
    .addGrid(gbt.maxDepth, [7]) \
    .addGrid(gbt.subsamplingRate, [0.75]) \
    .build()

cv_gbt = CrossValidator(estimator=pipeline_gbt,
                        estimatorParamMaps=paramGrid,
                        evaluator=BinaryClassificationEvaluator(metricName="areaUnderPR"),
                        numFolds=5,
                        seed=42)
cvModel_gbt = cv_gbt.fit(train)
gbt_results = cvModel_gbt.transform(test)
evaluate_model(gbt_results)

Accuracy: 0.776826331282
F1-Score: 0.409348441926
Precision: 0.572277227723
Recall: 0.318632855568
Confusion Matrix:
TN:2614.0 | FP:216.0
FN:618.0 | TP: 289.0

The GBT model achieved the highest f1-score and the highest recall among the 3 models for the test dataset.

However, it also has the lowest precision.

Nonetheless, the GBT model maintained the best balance between precision and recall.

On the other hand, the linear support vector classifier has the highest precision but the lowest f1-score, mostly because of its low recall.
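As a sanity check, the reported scores can be reproduced from the confusion matrices above with a few lines of plain Python. The helper name metrics_from_counts is hypothetical; it is not the evaluate_model function from the accompanying code.

```python
def metrics_from_counts(tn, fp, fn, tp):
    """Accuracy, F1, precision and recall for the positive (churn) class."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall}

# Confusion matrix counts reported for each model on the test set
confusion = {
    'Logistic Regression': dict(tn=2706, fp=124, fn=659, tp=248),
    'Linear SVC': dict(tn=2793, fp=37, fn=808, tp=99),
    'Gradient-Boosted Trees': dict(tn=2614, fp=216, fn=618, tp=289),
}

results = {name: metrics_from_counts(**counts) for name, counts in confusion.items()}
for name, scores in results.items():
    print(name, {metric: round(value, 4) for metric, value in scores.items()})
```

Laying the three models side by side this way makes the trade-off explicit: the linear support vector classifier raises few false alarms (37) but misses most churners (808), while the GBT model finds the most churners (289) at the cost of more false alarms (216).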

Ultimately, the model chosen for the task of predicting customer churn will depend on your company’s needs.

If you prioritize customer retention, and do not mind spending a little more on discounts and incentives to retain customers, you can opt for the GBT model since it is the best at recalling customers who will churn.

In addition, the GBT model will also be ideal if you want to maintain a better balance between customer retention and the amount spent on discounts and incentives to retain them, since it also has the highest f1-score.

However, if you want to minimize the amount spent on discounts and incentives to retain customers, the linear support vector classifier will be ideal since its predictions are the most precise.

Conclusion

In this article, we implemented a single model to predict when customers churn (downgrading from premium to free tier or canceling their accounts).

With this single model, we managed to achieve an f1-score of approximately 0.4, a respectable score.

An alternative approach that could improve f1-scores would be to build separate models to predict the two events.

There is the possibility that both events have different signals, and hence the use of separate models would lead to better results.

The code accompanying this article can be found here.

Thank you for reading this article! If you have any thoughts or feedback, leave a comment below or send me an email at leexinjie@gmail.com.

I’d love to hear from you!
