
Multiclass classification in machine learning

[Image: ice cubes sorted by opacity, with clear cubes in one tray and opaque cubes in another]

What is multiclass classification? 

Multiclass classification is a classification task with more than two classes (e.g., using a model to identify animal types in images from an encyclopedia). In multiclass classification, each sample belongs to exactly one class (e.g., an elephant is only an elephant; it is not also a lemur).

Outside of regression, multiclass classification is probably the most common machine learning task. In classification, we are presented with a number of training examples divided into K separate classes, and we build a machine learning model to predict which of those classes previously unseen data belongs to (e.g., the animal types from the example above). From the training data, the model learns patterns specific to each class and uses those patterns to predict the membership of future data.
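
To make this concrete, here is a minimal sketch using scikit-learn and its built-in iris dataset, whose K = 3 flower classes stand in for the animal types above (any classifier would do in place of logistic regression):

```python
# Train a multiclass classifier on labeled examples, then predict
# the classes of previously unseen data.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)            # 150 samples, K = 3 classes
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)                  # learn per-class patterns
print(model.predict(X_test[:5]))             # predict classes of unseen samples
print(model.score(X_test, y_test))           # accuracy on held-out data
```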

Multiclass classification use cases

For example, a cybersecurity company might want to monitor a user’s email inbox and classify incoming emails as potential phishing attempts or not. To do so, it might train a classification model on email text and sender addresses, learning to predict which sorts of URLs threatening emails tend to originate from.

As another example, a marketing company might serve an online ad and want to predict whether a given customer will click on it. (This is a binary classification problem.)

How machine learning classifiers work

Hundreds of models exist for classification. In fact, it’s often possible to take a model that works for regression and turn it into a classification model. This is essentially how logistic regression works. We model a linear response WX + b to an input and turn it into a probability value between 0 and 1 by feeding that response into a sigmoid function. We then predict that an input belongs to class 1 if the model outputs a probability greater than 0.5, and to class 0 otherwise.
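
Here is a minimal NumPy sketch of those mechanics, assuming the weight vector W and bias b have already been learned:

```python
import numpy as np

def sigmoid(z):
    # squash the linear response into a probability between 0 and 1
    return 1.0 / (1.0 + np.exp(-z))

def predict(X, W, b):
    probs = sigmoid(X @ W + b)           # linear response WX + b -> probability
    return (probs > 0.5).astype(int)     # class 1 if p > 0.5, class 0 otherwise
```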

Another common model for classification is the support vector machine (SVM). An SVM works by projecting the data into a higher-dimensional space and separating the classes with one or more hyperplanes. A single SVM does binary classification and can differentiate between two classes. To differentiate between K classes, the one-vs-rest strategy trains K SVMs, each predicting membership in one of the K classes; the one-vs-one strategy instead trains one SVM per pair of classes, K(K - 1)/2 in total.
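
In practice, a library usually handles the multiclass strategy for you. A brief sketch with scikit-learn, whose SVC trains one binary SVM per pair of classes (one-vs-one) under the hood:

```python
from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)        # K = 3 classes
clf = SVC(kernel="rbf")                  # kernel choice here is an illustrative default
clf.fit(X, y)                            # binary SVMs are combined internally
print(clf.predict(X[:5]))                # multiclass predictions
```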

Naive Bayes in ML classifiers

Within the realm of natural language processing and text classification, the Naive Bayes model is quite popular. Its popularity arises in large part from its simplicity and how quickly it trains. In the Naive Bayes classifier, we use Bayes’ Theorem to break down the joint probability of membership in a class into a series of conditional probabilities.

The model makes the naive assumption (hence Naive Bayes) that all the input features are mutually independent. While this is rarely true in practice, it’s often a good enough approximation to get the results we want. The probability of class membership then breaks down into a product of probabilities, P(k | x) ∝ P(k) · P(x_1 | k) · … · P(x_n | k), and we simply classify an input x as class k if k maximizes this product.
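
A brief sketch of a Naive Bayes text classifier with scikit-learn; the tiny corpus and its spam/not-spam labels are invented purely for illustration:

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

texts = ["free prize click now", "meeting agenda attached",
         "claim your free reward", "quarterly report attached"]
labels = [1, 0, 1, 0]                            # 1 = spam, 0 = not spam

vec = CountVectorizer()
X = vec.fit_transform(texts)                     # word-count features
clf = MultinomialNB().fit(X, labels)             # learns per-class word probabilities
print(clf.predict(vec.transform(["free reward now"])))  # likely [1]
```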

Deep learning classification examples

There also exist plenty of deep learning models for classification. Almost any neural network can be made into a classifier by simply tacking a softmax function onto the last layer. The softmax function creates a probability distribution over the K classes, producing an output vector of length K. Each element of the vector is the probability that the input belongs to the corresponding class. The most likely class is chosen by selecting the index of the vector element with the highest probability.
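
The softmax itself is only a few lines; a minimal NumPy sketch:

```python
import numpy as np

def softmax(logits):
    z = logits - np.max(logits)          # subtract the max for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()           # probability distribution over K classes

scores = np.array([2.0, 1.0, 0.1])       # raw outputs of a network's last layer (K = 3)
probs = softmax(scores)                  # roughly [0.66, 0.24, 0.10]
predicted_class = int(np.argmax(probs))  # index with the highest probability
```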

While many neural network architectures can be used, some work better than others. Convolutional Neural Networks (CNNs) typically fare very well on classification tasks, especially for images and text. A CNN extracts useful features from data, particularly ones that are invariant to scaling, translation, and rotation. This helps it recognize images that may be rotated, shrunken, or off-center, allowing it to achieve higher accuracy in image classification tasks.
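
As a sketch, here is a small CNN image classifier in Keras; the layer sizes, the 28x28 grayscale input shape, and the 10 output classes are illustrative assumptions:

```python
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
    layers.MaxPooling2D(),                   # downsampling adds translation tolerance
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(10, activation="softmax"),  # probability distribution over K = 10 classes
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```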

Unsupervised classification

While nearly all typical classification models are supervised, you can think of unsupervised classification as a clustering problem. In this setting, we want to assign data into one of K groups without having labeled examples ahead of time (just as in unsupervised learning). Classic clustering algorithms such as k-means, k-medoids, or hierarchical clustering perform well at this task.
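
A minimal sketch with scikit-learn’s k-means, using random unlabeled points for illustration:

```python
import numpy as np
from sklearn.cluster import KMeans

X = np.random.rand(200, 2)                     # 200 unlabeled points in 2-D
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
print(kmeans.labels_[:10])                     # group assigned to each point
print(kmeans.predict([[0.5, 0.5]]))            # assign a new point to one of K = 3 groups
```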

Keep learning

A guide to reinforcement learning

What is sentiment analysis?

How do microservices work?