All posts by Algorithmia

Machine learning use cases

A brain with two segments, one like a neural network drawing

Machine learning is a vast field, composed of many model types, subsets, and use cases. In our forthcoming 2020 State of Enterprise Machine Learning report, we dig into the use cases businesses rely on most often today; and as new advances are made in ML every day, the number and complexity of ML use cases grow as well. This post will walk through some common machine learning use cases and how they enable businesses to leverage their data in novel ways.

What is machine learning?

Machine learning is the subset of artificial intelligence that involves the study and use of algorithms and statistical models for computer systems to perform specific tasks without human interaction. Machine learning models rely on patterns and inference instead of manual human instruction. Most any task that can be completed with a data-defined pattern or set of rules can be done with machine learning. This allows companies to automate processes that were previously only possible for humans to perform—think responding to customer service calls, bookkeeping, and reviewing resumes. 

To extract value from machine learning, a model must be trained to react to certain data in certain ways, which requires a lot of clean training data. Once the model successfully works through the training data and is able to understand the nuances of the patterns it's learning, it will be able to perform the task on real data. We'll walk through some use cases of machine learning to help you understand the value of this technology.

What is machine learning used for?

Machine learning has many potential uses, including external (client-facing) applications like customer service, product recommendation, and pricing forecasts, but it is also being used internally to help speed up processes or improve products that were previously manual and time-consuming. You’ll notice these two types throughout our list of machine learning use cases below.

Seven machine learning use cases

1. Voice assistants

This consumer-based use for machine learning applies mostly to smartphones and smart home devices. The voice assistants on these devices use machine learning to understand what you say and craft a response. The machine learning models behind voice assistants were trained on human languages and variations in the human voice, because they have to translate what they hear into words and then make an intelligent, on-topic response. 

Millions of consumers use this technology, often without realizing the complexity behind the tool. The concept of training machine learning models to follow rules is fairly simple, but when you consider training a model to understand the human voice, interpret meaning, and craft a response, that is a heavy task. 

2. Dynamic pricing

This machine-based pricing strategy is best known in the travel industry. Flights, hotels, and other travel bookings usually have a dynamic pricing strategy behind them. Consumers know that the sooner they book their trip the better, but they may not realize that the actual price changes are made via machine learning. 

Travel companies set rules for how much the price should increase as the travel date gets closer, how much it should increase as seat availability decreases, and how high it should be relative to competitors. Then, they let the machine learning model run with competitor prices, time, and availability data feeding into it. 
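As a toy illustration of such rules, a pricing function might combine urgency and scarcity into a markup. The weights (0.5 and 0.4) and the 90-day window here are invented for this sketch, not drawn from any real airline's model.

```python
# Toy dynamic-pricing rule; the 0.5/0.4 weights and 90-day window are invented.
def dynamic_price(base_price, days_until_travel, seats_remaining, total_seats):
    """Raise the fare as departure nears and as seats sell out."""
    urgency = max(0.0, 1.0 - days_until_travel / 90.0)   # 0 far in advance, 1 at departure
    scarcity = 1.0 - seats_remaining / total_seats       # 0 for an empty plane, 1 when sold out
    return round(base_price * (1.0 + 0.5 * urgency + 0.4 * scarcity), 2)

early_fare = dynamic_price(200.0, days_until_travel=90, seats_remaining=180, total_seats=180)
late_fare = dynamic_price(200.0, days_until_travel=5, seats_remaining=20, total_seats=180)
print(early_fare, late_fare)  # the fare rises as the date nears and seats fill
```

A real system would learn these weights from competitor prices, time, and availability data rather than hand-coding them.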

3. Email filtering

This is a classic use of machine learning. Most email providers automatically filter unwanted spam emails into a separate spam folder. But how do they know when an email is spam? They have trained a model to identify spam emails based on characteristics they have in common, including the content of the email itself, the subject line, and the sender. If you've ever looked at your spam folder, you know that it wouldn't be very hard to pick out spam emails because they look very different from real emails.
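A minimal sketch of this idea, using scikit-learn's bag-of-words features and a naive Bayes classifier. The six "emails" are invented for illustration; real filters train on millions of labeled messages.

```python
# A minimal spam-filter sketch; the tiny labeled corpus below is invented.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

emails = [
    "win a free prize now", "claim your free money", "cheap meds limited offer",
    "meeting notes attached", "lunch tomorrow?", "quarterly report draft",
]
labels = ["spam", "spam", "spam", "ham", "ham", "ham"]

# Count word occurrences, then model which words co-occur with spam.
spam_filter = make_pipeline(CountVectorizer(), MultinomialNB())
spam_filter.fit(emails, labels)

verdict = spam_filter.predict(["free prize offer"])[0]
print(verdict)  # "spam": every word in the message appeared only in spam examples
```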

4. Product recommendations

Amazon and other online retailers often list “recommended products” for each consumer individually. These recommendations are based on past purchases, browsing history, and any other behavioral information they have about consumers. Often the recommendations are helpful in finding related items that you need to complement your purchase (think batteries for a new electronic gadget). 

However, most consumers probably don’t realize that their recommended products are a machine learning model’s analysis of their behavioral data. This is a great way for online retailers to provide extra value or upsells to their customers using machine learning.

5. Personalized marketing

Marketing is becoming more personal as technologies like machine learning gain more ground in the enterprise. Now that much of marketing is online, marketers can use characteristic and behavioral data to segment the market. Digital ad platforms allow marketers to choose characteristics of the audience they want to market to, but many of these platforms take it a step further and continuously optimize the audience based on who clicks and/or converts on the ads. The marketer may have listed 4 attributes they want their audience to have, but the platform may find 5 other attributes that make users more likely to respond to the ads. 

6. Process automation

There are many processes in the enterprise that are much more efficient when done using machine learning. These include analyses such as risk assessments, demand forecasting, customer churn prediction, and others. These processes require a lot of time (possibly months) to do manually, but the insights gained are crucial for business intelligence. But if it takes months to get insights from the data, the insights may already be outdated by the time they are acted upon. Machine learning for process automation alleviates the timeliness issue for enterprises. 

Industries are getting more and more competitive now that technology has sped up these processes. Companies can get up-to-date analyses on their competition in real time. This high level of competition makes customer loyalty even more crucial, and machine learning can even help with customer loyalty analyses like sentiment analysis. Suites of ML tools make it possible to run this type of analysis quickly and deliver results in a consumable format.

7. Fraud detection

Banks use machine learning for fraud detection to keep their consumers safe, but this can also be valuable to companies that handle credit card transactions. Fraud detection can save money on disputes and chargebacks, and machine learning models can be trained to flag transactions that appear fraudulent based on certain characteristics. 
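As a toy illustration of flagging by characteristics, the rule below marks a transaction that sits far above a customer's typical spend. The z-score threshold and the spending history are invented; real systems train models over many features rather than a single rule.

```python
# A toy, rule-of-thumb fraud flag; real detectors use trained models, not one threshold.
from statistics import mean, stdev

def flag_suspicious(history, new_amount, z_threshold=3.0):
    """Flag new_amount if it is more than z_threshold standard deviations above the mean."""
    mu, sigma = mean(history), stdev(history)
    return (new_amount - mu) / sigma > z_threshold

history = [25.0, 40.0, 31.0, 22.0, 38.0, 27.0, 35.0, 30.0]
ordinary = flag_suspicious(history, 32.0)
suspicious = flag_suspicious(history, 900.0)
print(ordinary, suspicious)  # False True
```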

Machine learning can provide value to consumers as well as to enterprises. An enterprise can gain insights into its competitive landscape and customer loyalty and forecast sales or demand in real time with machine learning. 

If you’re already implementing machine learning in your enterprise or you’d like to start...

Continue learning

Harnessing Machine Learning for Data Science Projects

Data science and machine learning: their differences and how they are used in the enterprise

Improving Customer Retention Analytics With Machine Learning

Using machine learning for sentiment analysis: a deep dive

Sentiment analysis depicted as four faces with different facial expressions

Sentiment analysis invites us to consider the sentence "You're so smart!" and discern what's behind it. It sounds like quite a compliment, right? Clearly the speaker is raining praise on someone with next-level intelligence. However, consider the same sentence in the following context.

Wow, did you think of that all by yourself, Sherlock? You’re so smart!

Now we’re dealing with the same words except they’re surrounded by additional information that changes the tone of the overall message from positive to sarcastic. 

This is one of the reasons why detecting sentiment with natural language processing (NLP) is a surprisingly complex task. Any machine learning model that hopes to achieve suitable accuracy needs to be able to determine what textual information is relevant to the prediction at hand; have an understanding of negation, human patterns of speech, idioms, and metaphors; and be able to assimilate all of this knowledge into a rational judgment about a quantity as nebulous as "sentiment." 

In fact, when presented with a piece of text, sometimes even humans disagree about its tonality, especially if there isn't a great deal of informative context to help rule out incorrect interpretations. With that said, recent advances in deep learning have allowed models to improve to a point that is quickly approaching human-level accuracy on this difficult task.

Sentiment analysis datasets

The first step in developing any model is gathering a suitable source of training data, and sentiment analysis is no exception. There are a few standard datasets in the field that are often used to benchmark models and compare accuracies, but new datasets are being developed every day as labeled data continues to become available. 

The first of these datasets is the Stanford Sentiment Treebank. It’s notable for the fact that it contains over 11,000 sentences, which were extracted from movie reviews and accurately parsed into labeled parse trees. This allows recursive models to train on each level in the tree, allowing them to predict the sentiment first for sub-phrases in the sentence and then for the sentence as a whole.

The Amazon Product Reviews Dataset provides over 142 million Amazon product reviews with their associated metadata, allowing machine learning practitioners to train sentiment models using product ratings as a proxy for the sentiment label.

The IMDB Movie Reviews Dataset provides 50,000 highly polarized movie reviews with a 50-50 train/test split.

The Sentiment140 Dataset provides valuable data for training sentiment models to work with social media posts and other informal text. It provides 1.6 million training points, which have been classified as positive, negative, or neutral.

Sentiment analysis, a baseline method

Whenever you test a machine learning method, it’s helpful to have a baseline method and accuracy level against which to measure improvements. In the field of sentiment analysis, one model works particularly well and is easy to set up, making it the ideal baseline for comparison.

To introduce this method, we can define something called a tf-idf score. This stands for term frequency-inverse document frequency, which gives a measure of the relative importance of each word in a set of documents. In simple terms, it computes the relative count of each word in a document, reweighted by that word's prevalence over all documents in the set. (We use the term "document" loosely; it could be anything from a sentence to a paragraph to a longer-form collection of text.) Analytically, we define the tf-idf of a term t as seen in document d, a member of a set of documents D, as:

tfidf(t, d, D) = tf(t, d) * idf(t, D)

Where tf is the term frequency and idf is the inverse document frequency. These are defined to be:

tf(t, d) = count(t) in document d


idf(t, D) = -log(P(t | D))

Where P(t | D) is the probability that term t appears in a document drawn from D.

From here, we can create a vector for each document where each entry in the vector corresponds to a term's tf-idf score. We place these vectors into a matrix representing the entire set D and train a logistic regression classifier on labeled examples to predict the sentiment of each document. 
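The definitions above can be implemented directly. This sketch assumes the natural log and raw counts for term frequency; the three "documents" are invented.

```python
# A from-scratch version of the tf-idf definitions in the text.
import math

def tf(term, doc):
    return doc.count(term)

def idf(term, docs):
    # -log P(t | D): the fraction of documents in D that contain the term
    containing = sum(1 for d in docs if term in d)
    return -math.log(containing / len(docs))

def tfidf(term, doc, docs):
    return tf(term, doc) * idf(term, docs)

docs = [
    "the movie was great".split(),
    "the plot was dull".split(),
    "the acting was great fun".split(),
]

common = tfidf("the", docs[0], docs)   # "the" appears in every document -> idf is 0
rare = tfidf("great", docs[0], docs)   # "great" appears in 2 of 3 -> -log(2/3) ~ 0.41
print(common, rare)
```

Note how a word that appears everywhere contributes nothing, while rarer words carry more weight.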

Sentiment analysis models

The idea here is that if you have a bunch of training examples, such as I’m so happy today!, Stay happy San Diego, Coffee makes my heart happy, etc., then terms such as “happy” will have a relatively high tf-idf score when compared with other terms. 

From this, the model should be able to pick up on the fact that the word “happy” is correlated with text having a positive sentiment and use this to predict on future unlabeled examples. Logistic regression is a good model because it trains quickly even on large datasets and provides very robust results. 

Other good model choices include SVMs, Random Forests, and Naive Bayes. These models can be further improved by training on not only individual tokens, but also bigrams or tri-grams. This allows the classifier to pick up on negations and short phrases, which might carry sentiment information that individual tokens do not. Of course, the process of creating and training on n-grams increases the complexity of the model, so care must be taken to ensure that training time does not become prohibitive.
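The baseline described above can be assembled in a few lines with scikit-learn. The six training sentences are invented, echoing the "happy" examples in the text, and `ngram_range=(1, 2)` adds the bigrams discussed here.

```python
# Sketch of the baseline: tf-idf features (with bigrams) into logistic regression.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = [
    "I'm so happy today!", "Stay happy San Diego", "Coffee makes my heart happy",
    "This was a terrible film", "I am not happy with this", "What an awful experience",
]
labels = [1, 1, 1, 0, 0, 0]  # 1 = positive, 0 = negative

# ngram_range=(1, 2) adds bigrams such as "not happy", which capture negation.
baseline = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2)),
    LogisticRegression(),
)
baseline.fit(texts, labels)
pred = baseline.predict(["so happy today"])[0]
print(pred)
```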

More advanced models

The advent of deep learning has provided a new standard by which to measure sentiment analysis models and has introduced many common model architectures that can be quickly prototyped and adapted to particular datasets to quickly achieve high accuracy.

Most advanced sentiment models start by transforming the input text into an embedded representation. These embeddings are sometimes trained jointly with the model, but usually additional accuracy can be attained by using pre-trained embeddings such as Word2Vec, GloVe, BERT, or FastText.

Next, a deep learning model is constructed using these embeddings as the first layer inputs:

Convolutional neural networks
Surprisingly, one model that performs particularly well on sentiment analysis tasks is the convolutional neural network, which is more commonly used in computer vision models. The idea is that instead of performing convolutions on image pixels, the model can instead perform those convolutions in the embedded feature space of the words in a sentence. Since convolutions occur on adjacent words, the model can pick up on negations or n-grams that carry novel sentiment information.

LSTMs and other recurrent neural networks
RNNs are probably the most commonly used deep learning models for NLP and with good reason. Because these networks are recurrent, they are ideal for working with sequential data such as text. In sentiment analysis, they can be used to repeatedly predict the sentiment as each token in a piece of text is ingested. Once the model is fully trained, the sentiment prediction is just the model’s output after seeing all n tokens in a sentence. 

RNNs can also be greatly improved by the incorporation of an attention mechanism, which is a separately trained component of the model. Attention helps a model to determine on which tokens in a sequence of text to apply its focus, thus allowing the model to consolidate more information over more timesteps. 

Recursive neural networks
Although similarly named to recurrent neural nets, recursive neural networks work in a fundamentally different way. Popularized by Stanford researcher Richard Socher, these models take a tree-based representation of an input text and create a vectorized representation for each node in the tree. Typically, the sentence’s parse tree is used. As a sentence is read in, it is parsed on the fly and the model generates a sentiment prediction for each element of the tree. This gives a very interpretable result in the sense that a piece of text’s overall sentiment can be broken down by the sentiments of its constituent phrases and their relative weightings. The SPINN model from Stanford is another example of a neural network that takes this approach.

Multi-task learning
Another promising approach that has emerged recently in NLP is that of multi-task learning. Within this paradigm, a single model is trained jointly across multiple tasks with the goal of achieving state-of-the-art accuracy in as many domains as possible. The idea here is that a model’s performance on task x can be bolstered by its knowledge of related tasks y and z, along with their associated data. Being able to access a shared memory and set of weights across tasks allows for new state-of-the-art accuracies to be reached. Two popular MTL models that have achieved high performance on sentiment analysis tasks are the Dynamic Memory Network and the Neural Semantic Encoder.

Sentiment analysis and unsupervised models

One encouraging aspect of the sentiment analysis task is that it seems to be quite approachable even for unsupervised models that are trained without any labeled sentiment data, only unlabeled text. The key to training unsupervised models with high accuracy is using huge volumes of data. 

One model developed by OpenAI trains on 82 million Amazon reviews that it takes over a month to process! It uses an advanced RNN architecture called a multiplicative LSTM to continually predict the next character in a sequence. In this way, the model learns not only token-level information, but also subword features, such as prefixes and suffixes. Ultimately, it incorporates some supervision into the model, but it is able to acquire the same or better accuracy as other state-of-the-art models with 30-100x less labeled data. It also uncovers a single sentiment “neuron” (or feature) in the model, which turns out to be predictive of the sentiment of a piece of text.

Moving from sentiment to a nuanced spectrum of emotion

Sometimes simply understanding just the sentiment of text is not enough. For acquiring actionable business insights, it can be necessary to tease out further nuances in the emotion that the text conveys. A text having negative sentiment might be expressing any of anger, sadness, grief, fear, or disgust. Likewise, a text having positive sentiment could be communicating any of happiness, joy, surprise, satisfaction, or excitement. Obviously, there’s quite a bit of overlap in the way these different emotions are defined, and the differences between them can be quite subtle. 

This makes the emotion analysis task much more difficult than that of sentiment analysis, but also much more informative. Luckily, more and more data with human annotations of emotional content is being compiled. Some common datasets include the SemEval 2007 Task 14, EmoBank, WASSA 2017, The Emotion in Text Dataset, and the Affect Dataset. Another approach to gathering even larger quantities of data is to use emojis as a proxy for an emotion label. 🙂 

When training on emotion analysis data, any of the aforementioned sentiment analysis models should work well. The only caveat is that they must be adapted to classify inputs into one of n emotional categories rather than a binary positive or negative. 
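The adaptation is straightforward in practice: swap the binary target for n emotion labels. This sketch reuses the tf-idf baseline with invented sentences and labels; logistic regression handles the multi-class case with a softmax over the n categories.

```python
# Sketch of adapting the sentiment baseline to n emotion classes (toy data).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = [
    "I can't stop smiling today", "what a wonderful surprise",
    "I'm furious about this delay", "this makes my blood boil",
    "I miss her so much", "such a heartbreaking ending",
]
emotions = ["joy", "joy", "anger", "anger", "sadness", "sadness"]

clf = make_pipeline(TfidfVectorizer(), LogisticRegression())
clf.fit(texts, emotions)
pred = clf.predict(["a wonderful, happy surprise"])[0]
print(pred)
```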

To see some NLP models in action, explore the many sentiment analysis models Algorithmia hosts.

Further reading

MonkeyLearn – A guide to sentiment analysis functions and resources.

Stanford  – Reading Emotions From Speech Using Deep Neural Networks, a publication

Coursera – Applied Text Mining in Python video demonstration

Machine learning methods with R

iris virginica in a field


R is an excellent language for machine learning. R primarily excels in the following three areas:

  1. Data manipulation
  2. Plotting
  3. Built-in libraries

R manipulates data using a native data structure called a data frame. Let's load some data now and see how R can help us work with it.

The data

Luckily for us, R comes with some built-in datasets that we can simply load and have ready to go. We'll use one such dataset, called iris, to test some of R's machine learning capabilities. It contains 150 flowers, 50 from each of three species: iris setosa, iris versicolor, and iris virginica. Each data point carries 4 measurement attributes we can use for classification: sepal length, sepal width, petal length, and petal width. Let's load the data and get started.


data(iris)
head(iris)

The R head command lets us inspect the first few elements of a dataset. You'll see that the first elements all have the same class label, given by the Species column. When training a machine learning model, we want our data to be randomized, so let's fix this.

Shuffling the data

The elements of our data frame are the rows, so to shuffle the data we need to permute the rows. To do this we can use the built-in R functions nrow and sample. nrow(iris) returns a range the size of the iris dataset, and applying sample to it returns a permutation of that range. We can then index into the dataset using the shuffled row indices as shown below.

shuffled <- iris[sample(nrow(iris)),]

Now, when we take a look at the head of the dataset, we see that we’re getting a mix of different flowers, which is great!

Train and validation splits

Now that we’ve got the data shuffled, let’s partition it into a training and validation set. We’ll hold out 20 percent of the data points for the validation set. We can get the two partitions as follows:

n <- nrow(shuffled)
split <- round(0.8 * n)
train <- shuffled[1:split,]
validate <- shuffled[(split + 1):n,]

Defining a model

Once we have our data divided into train and validation splits, we can start training a model. We'll first explore using boosted decision trees for classification with an R package called xgboost. Let's install and load it first.

install.packages("xgboost")
library(xgboost)

Now, we can construct the model. The XGBoost interface requires that we section the training data into the actual X data and the species labels. To do this, we just select the corresponding columns from the data frame. We also need to convert the Species labels into the numerical values 0, 1, and 2, as XGBoost only handles numeric class labels. We'll do the same for the validation set.

# Training set
dcols <- c(1:4)
lcols <- c(5)
train$Species <- as.character(train$Species)
train$Species[train$Species == "setosa"] <- 0
train$Species[train$Species == "virginica"] <- 1
train$Species[train$Species == "versicolor"] <- 2
X <- train[,dcols]
labels <- as.numeric(train[,lcols])

# Validation Set
validate$Species <- as.character(validate$Species)
validate$Species[validate$Species == "setosa"] <- 0
validate$Species[validate$Species == "virginica"] <- 1
validate$Species[validate$Species == "versicolor"] <- 2
Xv <- validate[,dcols]
labelsV <- as.numeric(validate[,lcols])

Training the model

Now that we’ve gotten that out of the way, we can begin training the model. The interface for xgboost is super simple; we can train the model with a single call.

booster <- xgboost(data = as.matrix(X), label = labels, max.depth = 2, eta = 1, nthread = 2, nrounds = 2, objective = "multi:softmax", num_class = 3)
[1] train-merror:0.000000 
[2] train-merror:0.000000 

The parameters we used deserve some explanation.

  1. We set max.depth = 2 to use decision trees that have a maximum depth of 2. Deeper trees can give higher accuracy but are also higher in variance. Depth 2 should be a good balance for our purposes.
  2. eta specifies the learning rate for our model. This is a hyperparameter that requires some experimentation and tuning to get right. We’ll use 1 for this example.
  3. nthread specifies the number of CPU threads to use while training the model. This is a small dataset so we don’t need many threads, but we would for larger datasets.
  4. nrounds specifies the number of training epochs (passes) to perform over the data. We’ll start with 2.

Evaluating model performance

Now that we’ve trained our boosted tree model, we can evaluate its performance on the validation set. To do this, we can use the predict function that comes packaged with xgboost in order to generate our predictions.

preds <- predict(booster, as.matrix(Xv))
[1] 1 1 0 0 1 0

You can see that our boosted model is predicting nicely. All that is left is to calculate its accuracy. The error rate is the mean of all entries of preds that are not equal to the corresponding entries of labelsV.

err <- mean(preds != labelsV)
print(paste("validation-error=", err))
[1] "validation-error= 0.0495867768595041"
print(paste("validation-accuracy=", 1 - err))
[1] "validation-accuracy= 0.950413223140496"

Nice! The validation accuracy on the dataset is equal to about 95 percent, which is a great performance from the model.

Visualizing our results

To finish off the project, we'll visualize our predictions. First, we'll plot the original data using a library called ggvis. Before that, though, we must install the package.

install.packages("ggvis")
library(ggvis)

Now, since we’re visually limited to two dimensions, we will choose to plot our classes vs. the Sepal Length and Width attributes. This can be done as follows.

data <- iris
data %>% ggvis(~Sepal.Length, ~Sepal.Width, fill = ~Species) %>% layer_points()

To visualize how our model predicts in comparison, we’ll quickly run it across all data points, not just the validation set. To combine the training and validation data, we use the rbind function, which just joins the two data frames vertically. We also undo the work we did before, converting from numeric labels back to species names.

all_data <- rbind(X, Xv)
preds <- predict(booster, as.matrix(all_data))
all_data["Species"] <- preds
all_data$Species[all_data$Species == 0] <- "setosa"
all_data$Species[all_data$Species == 1] <- "virginica"
all_data$Species[all_data$Species == 2] <- "versicolor"

Everything looks good! Let’s plot the final result, using the same code as before.

all_data %>% ggvis(~Sepal.Length, ~Sepal.Width, fill = ~Species) %>% layer_points()

We can visually compare the true plot with the plot of our predictions. There should only be a few points misclassified between the two!

From here, there are many more machine learning models in R to explore.

5 machine learning models you should know

k-means clustering
(Depiction of a clustering model, Medium)

Getting started with machine learning starts with understanding the how and why behind employing particular methods. We’ve chosen five of the most commonly used machine learning models on which to base the discussion.

AI taxonomy 

Before diving too deep, we thought we’d define some important terms that are often confused when discussing machine learning. 

  • Algorithm – A set of predefined rules used to solve a problem. For example, simple linear regression is a prediction algorithm used to find a target value (y) based on an independent variable (x). 
  • Model – The actual equation or computation that is developed by applying sample data to the parameters of the algorithm. To continue the simple linear regression example, the model is the equation of the line of best fit of the x and y values in the sample set plotted against each other.
  • Neural network  – A multilayered algorithm that consists of an input layer, an output layer, and one or more hidden layers in between. The hidden layers are stacked computations that transform the input until the network produces a final output. Neural networks are sometimes referred to as “black box” algorithms because humans don’t have a clear and structured idea of how the computer is making its decisions. 
  • Deep learning – Machine learning methods based on neural network architectures. “Deep” refers to the large number of hidden layers (some networks have more than 100).
  • Data science – A discipline that combines math, computer science, and business/domain knowledge. 

Machine learning methods

Machine learning methods are often broken down into two broad categories: supervised learning and unsupervised learning.

Supervised learning – Supervised learning methods are used to find a specific target, which must also exist in the data. The main categories of supervised learning include classification and regression. 

  • Classification – Classification models often have a binary target sometimes phrased as a “yes” or “no.” A variation on this model is probability estimation in which the target is how likely a new observation is to fall into a particular category. 
  • Regression – Regression models always have a numeric target. They model the relationship between a dependent variable and one or more independent variables. 

Unsupervised learning – Unsupervised learning methods are used when there is no specific target to find. Their purpose is to form groupings within the dataset or make observations about similarities. Further interpretation would be needed to make any decisions on these results. 

  • Clustering – Clustering models look for subgroups within a dataset that share similarities. These natural groupings are similar to each other, but different than other groups. They may or may not have any actual significance. 
  • Dimension reduction – These models reduce the number of variables in a dataset by grouping similar or correlated attributes.

It’s important to note that individual models are not necessarily used in isolation. It often takes a combination of supervised and unsupervised methods to solve a data science problem. For example, one might use a dimension-reduction method on a large dataset and then use the new variables in a regression model. 

To that end, model pipelining splits machine learning workflows into modular, reusable parts that can be coupled with other model applications to build more powerful software over time. 
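A minimal sketch of that combination: an unsupervised dimension-reduction step feeding a regression step, chained as one reusable object. The data is synthetic (five observed columns driven by one hidden factor, plus five noise columns), invented for this example.

```python
# Pipelining sketch: PCA (unsupervised) followed by linear regression (supervised).
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
z = rng.normal(size=200)                                               # hidden factor
signal = np.column_stack([z + rng.normal(scale=0.3, size=200) for _ in range(5)])
noise = rng.normal(size=(200, 5))
X = np.column_stack([signal, noise])                                   # 10 observed variables
y = 2.0 * z + rng.normal(scale=0.1, size=200)

# Reduce 10 variables to 3 components, then regress on the new variables.
pipe = make_pipeline(PCA(n_components=3), LinearRegression())
pipe.fit(X, y)
r2 = pipe.score(X, y)
print(round(r2, 2))  # close to 1: the reduced variables still predict y well
```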

What are the most popular machine learning algorithms? 

Below we’ve detailed some of the most common machine learning algorithms. They’re often mentioned in introductory data science courses and books and are a good place to begin. We’ve also provided some examples of how these algorithms are used in a business context. 

Linear regression

Linear regression is a method in which you predict an output variable using one or more input variables. This is represented in the form of a line: y=bx+c. The Boston Housing Dataset is one of the most commonly used resources for learning to model using linear regression. With it, you can predict the median value of a home in the Boston area based on 14 attributes, including crime rate per town, student/teacher ratio per town, and the number of rooms in the house. 
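As a sketch of the y = bx + c form above, we can fit a line to synthetic points around a known slope and intercept (the Boston dataset has been removed from recent scikit-learn releases, so we generate illustrative data instead).

```python
# Simple linear regression on synthetic points around the line y = 3x + 5.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=(200, 1))
y = 3.0 * x[:, 0] + 5.0 + rng.normal(scale=0.5, size=200)  # true b = 3, c = 5

model = LinearRegression().fit(x, y)
b_hat, c_hat = model.coef_[0], model.intercept_
print(round(b_hat, 1), round(c_hat, 1))  # recovered slope and intercept, near 3 and 5
```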

K-means clustering

K-means clustering is a method that forms groups of observations around geometric centers called centroids. The “k” refers to the number of clusters, which is determined by the individual conducting the analysis. Clustering is often used as a market segmentation approach to uncover similarity among customers or uncover an entirely new segment altogether. 
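A minimal sketch with scikit-learn on two well-separated groups of 2-D points; the coordinates are invented (think of customers plotted by two spending attributes).

```python
# k-means with k=2 on two obvious groups of toy points.
import numpy as np
from sklearn.cluster import KMeans

X = np.array([[1, 2], [1, 1], [2, 2],      # group A
              [9, 9], [10, 8], [9, 10]])   # group B

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
labels = km.labels_
print(labels)                # the first three points share one label, the last three the other
print(km.cluster_centers_)   # centroids near (1.3, 1.7) and (9.3, 9.0)
```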

k-means clustering


Principal component analysis (PCA)

PCA is a dimension-reduction technique used to reduce the number of variables in a dataset by grouping together variables that are measured on the same scale and are highly correlated. Its purpose is to distill the dataset down to a new set of variables that can still explain most of its variability. 

A common application of PCA is aiding in the interpretation of surveys that have a large number of questions or attributes. For example, global surveys about culture, behavior or well-being are often broken down into principal components that are easy to explain in a final report. In the Oxford Internet Survey, researchers found that their 14 survey questions could be distilled down to four independent factors. 
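In the survey spirit described above, this sketch builds six synthetic "questions" driven by two hidden factors and distills them down to two components; the factor structure is invented for illustration.

```python
# PCA on six correlated synthetic variables driven by two latent factors.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
f1, f2 = rng.normal(size=(2, 300))                     # two hidden factors
noise = rng.normal(scale=0.1, size=(300, 6))
X = np.column_stack([f1, f1, f1, f2, f2, f2]) + noise  # three questions load on each factor

pca = PCA(n_components=2).fit(X)
explained = pca.explained_variance_ratio_.sum()
print(round(explained, 3))  # ~0.99: two components capture nearly all the variability
```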

K-nearest neighbors (k-NN)

Nearest-neighbor reasoning can be used for classification or prediction depending on the variables involved. It is a comparison of distance (often Euclidean) between a new observation and those already in a dataset. The “k” is the number of neighbors to compare and is usually chosen to minimize the chance of overfitting or underfitting the data. 

In a classification scenario, how closely the new observation is to the majority of the neighbors of a particular class determines which class it is in. For this reason, k is often an odd number to prevent ties. For a prediction model, an average of the targeted attribute of the neighbors predicts the value for the new observation. 
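A classification sketch of this: the new point takes the majority label among its k = 3 nearest neighbors by Euclidean distance. Points and labels are invented.

```python
# k-NN classification with k=3 on toy 2-D points.
from sklearn.neighbors import KNeighborsClassifier

X = [[1, 1], [2, 1], [1, 2],      # class "a"
     [8, 8], [9, 8], [8, 9]]      # class "b"
y = ["a", "a", "a", "b", "b", "b"]

knn = KNeighborsClassifier(n_neighbors=3)  # an odd k helps prevent ties
knn.fit(X, y)
preds = knn.predict([[2, 2], [9, 9]])
print(preds)  # ['a' 'b']
```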

k-nearest neighbors depicted


Classification and regression trees (CART)

Decision trees are a transparent way to separate observations and place them into subgroups. CART is a well-known version of a decision tree that can be used for classification or regression. You choose a response variable and make partitions through the predictor variables. The algorithm typically limits the number of partitions, often through pruning, to prevent underfitting or overfitting the model. CART is useful in situations where “black box” algorithms are frowned upon for their inexplicability, because interested parties need to see the entire process behind a decision. 
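The heart of CART is choosing the best binary split on a predictor. A full tree repeats this recursively (typically scoring candidate splits with Gini impurity); the pure-Python sketch below just finds the single split that minimizes misclassifications, using a made-up age-vs.-outcome example in the spirit of the Titanic survivor dataset.

```python
# A sketch of the core CART step: the single best binary split on one
# numeric predictor, scored by misclassification count.

def best_split(values, labels):
    candidates = sorted(set(values))
    best = None
    for i in range(len(candidates) - 1):
        # Candidate threshold: midpoint between adjacent observed values.
        t = (candidates[i] + candidates[i + 1]) / 2
        left = [lab for v, lab in zip(values, labels) if v <= t]
        right = [lab for v, lab in zip(values, labels) if v > t]
        # Each side predicts its majority class; count the mistakes.
        errors = sum(
            sum(1 for lab in side if lab != max(set(side), key=side.count))
            for side in (left, right)
        )
        if best is None or errors < best[1]:
            best = (t, errors)
    return best

# Made-up data: passenger age vs. a binary survival-style label.
ages = [4, 8, 10, 30, 40, 52, 60]
labels = ["yes", "yes", "yes", "no", "no", "no", "yes"]
print(best_split(ages, labels))   # (20.0, 1): split at age 20, one mistake
```

The transparency the text describes comes from exactly this structure: every decision is a readable threshold on a named variable.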

Titanic survivor dataset

How do I choose the best model for machine learning?

The model you choose for machine learning depends greatly on the question you are trying to answer or the problem you are trying to solve. Additional factors to consider include the type of data you are analyzing (categorical, numerical, or maybe a mixture of both) and how you plan on presenting your results to a larger audience. 

The five model types discussed herein do not represent the full collection of model types out there, but are commonly used for most business use cases we see today. Using the above methods, companies can conduct complex analysis (predict, forecast, find patterns, classify, etc.) to automate workflows.


Customer churn prediction with machine learning

Illustration of revolving door with customers leaving

Why is churn prediction important? 

Defined loosely, churn is the process by which customers cease doing business with a company. Preventing lost profits is one clear motivation for reducing churn, but subtler ones exist as well. Most strikingly, the cost of acquiring a new customer usually far outweighs the cost of retaining an existing one, so stamping out churn also makes sense from a longer-term financial perspective. 

While churn presents an obvious difficulty to businesses, its remedy is not always immediately clear. In many cases, and without descriptive data, companies are at a loss as to what drives it. Luckily, machine learning provides effective methods for identifying churn’s underlying factors and prescriptive tools for addressing it.

Methods for solving high churn rate

As with any machine learning task, the first, and often the most crucial step, is gathering data. Typical datasets used in customer churn prediction tasks will often curate customer data such as time spent on a company website, links clicked, products purchased, demographic information of users, text analysis of product reviews, tenure of the customer-business relationship, etc. The key here is that the data be high quality, reliable, and plentiful. 

Good results can often still be obtained with sparse data, but more data is usually better. Once the data has been chosen, the problem must be formulated and the data featurization chosen. It’s important that this stage be undertaken with attention to detail, as churn can mean different things to different enterprises. 

Types of churn

For some, the problem is best characterized as predicting the ratio of churn to retention. For others, predicting the percentage risk that an individual customer will churn is desired. And for many more, identifying churn might constitute taking a global perspective and predicting the future rate at which customers might exit the business relationship. All of these are valid formulations, but whichever is chosen must be applied consistently across the customer churn prediction pipeline.

Once the data has been chosen, prepped, and cleaned, modeling can begin. While identifying the most suitable prediction model can be more of an art than a science, we’re usually dealing with a classification problem (predicting whether a given individual will churn) for which certain models are standards of practice. 

For classification problems such as this, both decision trees and logistic regression are desirable for their ease of use, training and inference speed, and interpretable outputs. These should be the go-to methods in any practitioner’s toolbox for establishing a baseline accuracy before moving onto more complex modeling choices. 
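As a concrete baseline, here is a pure-Python sketch of logistic regression trained by plain gradient descent on a single made-up feature (months since last purchase), predicting the probability that a customer churns. A real pipeline would use a library implementation with many features and regularization; the point is how little is needed to establish a first benchmark.

```python
# A minimal logistic-regression churn baseline, trained by gradient descent.
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_logreg(xs, ys, lr=0.1, epochs=5000):
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        # Gradient of the log-loss with respect to w and b.
        grad_w = sum((sigmoid(w * x + b) - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum((sigmoid(w * x + b) - y) for x, y in zip(xs, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

months = [1, 2, 2, 3, 8, 9, 10, 12]    # months since last purchase (made-up)
churned = [0, 0, 0, 0, 1, 1, 1, 1]     # 1 = customer churned

w, b = train_logreg(months, churned)
print(sigmoid(w * 1 + b))    # low churn probability for a recent buyer
print(sigmoid(w * 11 + b))   # high churn probability for a lapsed one
```

Because the learned weights are directly interpretable (here, churn risk rises with months of inactivity), this kind of model doubles as a first diagnostic of what drives churn.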

For decision trees, the model can be further improved by experimenting with random forests, bagging, and boosting. Beyond these two choices, Convolutional Neural Networks, Support Vector Machines, Linear Discriminant Analysis, and Quadratic Discriminant Analysis can all serve as viable prediction models to try. 

Defining metrics with customer data 

Once a model has been chosen, it needs to be evaluated against a consistent and measurable benchmark. One way to do this is to examine the model’s ROC (Receiver Operating Characteristic) curve when applied to a test set. Such a curve plots the True Positive rate against the False Positive rate. By looking to maximize the AUC (area under the curve), one can tune a model’s performance or assess tradeoffs between different models. 
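AUC has a convenient equivalent definition that makes it easy to compute directly: it is the probability that a randomly chosen positive example is scored above a randomly chosen negative one (with ties counting half). The scores and labels below are made-up.

```python
# AUC via the pairwise (Mann-Whitney) interpretation: the fraction of
# positive/negative pairs the model ranks correctly, ties counting 0.5.

def auc(scores, labels):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(
        1.0 if p > q else 0.5 if p == q else 0.0
        for p in pos for q in neg
    )
    return wins / (len(pos) * len(neg))

# Predicted churn probabilities vs. actual outcomes (1 = churned).
scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 0]
print(auc(scores, labels))   # 8/9, about 0.889
```

A perfect ranking yields 1.0 and random scoring yields about 0.5, which is what makes AUC a useful single-number benchmark when tuning or comparing models.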

Another useful metric is the Precision-Recall Curve, which, you guessed it, plots precision vs. recall. It’s useful in problems where one class is more qualitatively interesting than the other, as is the case with churn: we’re more interested in the small proportion of customers looking to leave than in those who aren’t (although we do care about them as well). 
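At any single decision threshold, precision and recall reduce to counts of true positives, false positives, and false negatives; sweeping the threshold traces out the full curve. This sketch reuses the same style of made-up churn scores.

```python
# Precision and recall at one decision threshold on predicted churn scores.

def precision_recall(scores, labels, threshold):
    predicted = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for p, y in zip(predicted, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(predicted, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(predicted, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 0]
print(precision_recall(scores, labels, threshold=0.5))   # precision 2/3, recall 2/3
```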

In this case, a business would hope to identify potential churners with high precision so as to target interventions at them. For example, one such intervention might involve an email blast offering coupons or discounts to those most likely to churn. By carefully selecting which customers to target, businesses can contain the cost of these retention measures and increase their effectiveness.

Sifting through insights from model output 

Once the selected model has been tuned, a post-hoc analysis can be conducted. An examination of which input data features were most informative to the model’s success could suggest areas to target and improve. The total pool of customers can even be divided into segments, perhaps by using a clustering algorithm such as k-means. 

This allows businesses to home in on the particular markets where they may be struggling and tailor their churn prevention approaches to meet those markets’ individual needs. They can also tap into the high interpretability of their prediction model (if such a model was selected) and use it to identify the decisions that led those customers to churn.

Combating churn with machine learning

While churn prediction can look like a daunting task, it’s actually not all that different from any machine learning problem. When looked at generally, the overall workflow looks much the same. However, special care must be given to the feature selection, model interpretation, and post-hoc analysis phases so that appropriate measures can be taken to alleviate churn. 

In this way, the key skill in adapting machine learning to churn prediction lies not in any model specialized to the task, but in the practitioner’s domain knowledge and ability to make sound business decisions from the output of what may be a black-box model.