October 15, 2021

[Algorithms] - Multi class classification

From Techopedia: Linear multiclass classification is a specific kind of targeted algorithm philosophy in machine learning and the field of structured prediction that uses both linear and multiclass methods. A multiclass classification is used to classify more than two classes – in contrast to a binary classification.

A linear classification uses an object’s characteristics to classify it by basing a decision on the value of a linear combination of characteristics.
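
As a minimal sketch of that idea (not from the course, with made-up weights), a linear binary classifier scores a point with a weighted sum of its features plus an intercept and classifies by the sign of that score:

import numpy as np

# Hypothetical weights and intercept for a two-feature linear classifier
w = np.array([0.5, -1.2])   # one weight per feature
b = 0.3                     # intercept (bias)

x = np.array([2.0, 1.0])    # a new object's characteristics
score = np.dot(w, x) + b    # linear combination of characteristics

# Positive score -> class 1, negative score -> class 0
predicted_class = int(score > 0)
print(score, predicted_class)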

How do we use scikit-learn for this?

Fortunately, scikit-learn makes it very easy to learn multiclass classification models. Essentially, it converts a multiclass classification problem into a series of binary problems: when you pass in a dataset whose target is a categorical variable, scikit-learn detects this automatically and, for each class to be predicted, creates one binary classifier that predicts that class against all the other classes. For example, the fruit dataset has four categories of fruit, so scikit-learn learns four different binary classifiers. To predict the label of a new data instance, it runs the instance against each of the binary classifiers in turn, and the class of the classifier with the highest score becomes the predicted value.
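
To make that one-vs-rest decomposition explicit, here is a minimal sketch (not from the course, using synthetic stand-in data) with scikit-learn's OneVsRestClassifier, which wraps a binary estimator and fits one copy per class:

from sklearn.datasets import make_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# Synthetic stand-in for the fruit data: 4 classes, 2 features
X, y = make_classification(n_samples=200, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

# Explicitly wrap a binary classifier; one LinearSVC is fit per class
ovr = OneVsRestClassifier(LinearSVC()).fit(X, y)
print(len(ovr.estimators_))  # 4 -- one binary classifier per class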

Here, we simply pass in the normal dataset, which has values from one to four as the category of fruit to be predicted, and we fit the model exactly as we would for a binary problem. In general, if we are just fitting and then predicting, all of this is completely transparent: scikit-learn simply does the right thing, learning and predicting multiple classes without us having to do anything else. However, we can see what is happening under the hood by looking at the coefficients and intercepts of the linear models that result from fitting to the training data.

from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split

# X_fruits_2d / y_fruits_2d: the two-feature (height, width) fruit dataset
X_train, X_test, y_train, y_test = train_test_split(X_fruits_2d, y_fruits_2d, random_state=0)

clf = LinearSVC(C=5, random_state=67).fit(X_train, y_train)
print('Coefficients:\n', clf.coef_)
print('Intercepts:\n', clf.intercept_)
Coefficients:
 [[-0.26  0.71]
 [-1.63  1.16]
 [ 0.03  0.29]
 [ 1.24 -1.64]]
Intercepts:
 [-3.29  1.2  -2.72  1.16]

What we're doing here is fitting a linear support vector machine to the fruit training data. Looking at the coefficient values, instead of just one pair of coefficients for a single linear classifier, we actually get four pairs (and four intercepts), one for each of the four classes of fruit in the training set. So what scikit-learn has done here is create four binary classifiers, one per class.
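
As a hedged illustration of how these per-class scores drive prediction (reusing clf and X_test from the snippet above), each row of coef_ together with its intercept produces one score, and the predicted class is the one with the highest score:

import numpy as np

# Per-class scores for the test points: one column per binary classifier
scores = np.asarray(X_test) @ clf.coef_.T + clf.intercept_

# The predicted class is the one whose binary classifier scores highest
manual_pred = clf.classes_[np.argmax(scores, axis=1)]

# This matches what scikit-learn's predict does internally
assert (manual_pred == clf.predict(X_test)).all()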

Now let's plot this:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

plt.figure(figsize=(6, 6))
colors = ['r', 'g', 'b', 'y']
cmap_fruits = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#FFFF00'])

# Scatter the training points, colored by fruit class
plt.scatter(X_fruits_2d[['height']], X_fruits_2d[['width']],
            c=y_fruits_2d, cmap=cmap_fruits, edgecolor='black', alpha=.7)

x_0_range = np.linspace(-10, 15)

# One decision boundary per binary classifier
for w, b, color in zip(clf.coef_, clf.intercept_, colors):
    # Since class prediction with a linear model uses the formula y = w_0 x_0 + w_1 x_1 + b,
    # and the decision boundary is defined as all points with y = 0, to plot x_1 as a
    # function of x_0 we just solve w_0 x_0 + w_1 x_1 + b = 0 for x_1:
    plt.plot(x_0_range, -(x_0_range * w[0] + b) / w[1], c=color, alpha=.8)

# target_names_fruits: the list of fruit class names from the course dataset
plt.legend(target_names_fruits)
plt.xlabel('height')
plt.ylabel('width')
plt.xlim(-16, 16)
plt.ylim(-16, 16)
plt.show()

In this case the first pair of coefficients corresponds to the classifier that separates apples from the rest of the fruit, and so on for the remaining pairs of coefficients and intercepts.

When a new object needs to be classified, say one with height 5 and width 10, the data point falls just above the blue line. When predicting, scikit-learn runs the point through all four binary classifiers, and the one with the highest score determines the predicted fruit.
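
As a minimal sketch continuing from the fitted clf above, we can inspect that point's per-class scores with decision_function (one score per binary classifier) and confirm that predict picks the highest:

import numpy as np

# The new object: height 5, width 10
new_point = np.array([[5.0, 10.0]])

# One score per binary classifier (one per fruit class)
scores = clf.decision_function(new_point)
print('Scores per class:', scores)

# predict returns the class whose binary classifier scored highest
print('Predicted fruit label:', clf.predict(new_point))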

Some of these notes were taken from the Coursera course Applied Machine Learning in Python. The material is presented by Kevyn Collins-Thompson, PhD, an associate professor of Information and Computer Science at the University of Michigan.