Machine Learning 101 - Part 2- A Tutorial on Simple Linear Regression
In my previous article Machine Learning 101 - Part 1- A Tutorial for Data Preprocessing, we learned how to obtain a dataset, fill in missing data, transform categorical data, and perform feature scaling. From this article on, we will dive right into the pool of machine learning algorithms. The first algorithm I am going to introduce is simple linear regression. In this model there is only one independent variable, and it assumes that the dependent variable has a linear relationship with it. If x is the independent variable and y is the dependent variable, the simple linear regression model can be expressed mathematically as y = b + a * x, where b is the intercept and a is the coefficient. Let's take a look at part of the data we are working with:
Assume we have surveyed 100 publicly traded companies and collected their annual revenue and R&D expense data. We would like to find out whether there is a linear relationship between R&D expense and annual revenue. The data provided here is very clean: there is no missing data, and there are neither currency signs in front of the numbers nor commas separating the digits. If your real-world data contains dollar signs or commas, you can refer to the Python script in my Github repo, Simple Linear Regression, for sanitizing the data.
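To give a feel for that kind of cleanup, here is a minimal sketch of stripping dollar signs and thousands separators with pandas. The column names and values are made up for illustration; the script in the repo is the authoritative version:

```python
import pandas as pd

# Hypothetical messy survey data (column names are assumptions,
# not the actual columns in revenue_rdexpense.csv)
raw = pd.DataFrame({
    'Revenue': ['$1,200', '$850', '$2,430'],
    'RD_Expense': ['$120', '$85', '$240'],
})

# Remove '$' and ',' from every column, then convert to numbers
clean = raw.apply(
    lambda col: pd.to_numeric(col.str.replace('[$,]', '', regex=True)))
```

After this step the columns are plain numeric types, ready for the workflow below.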
# Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Importing the dataset
dataset = pd.read_csv('revenue_rdexpense.csv')
X = dataset.iloc[:, :-1].values
y = dataset.iloc[:, 1].values
First, we import the dataset from the revenue_rdexpense.csv file and assign revenue as the independent variable and R&D expense as the dependent variable. Since we have a limited amount of data and want to use all of it for our simple linear regression model, we don't split the dataset this time. Also, linear regression does not require feature scaling: the fitted coefficient and intercept simply absorb the scale of the data, so we don't need to scale it explicitly.
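One detail worth checking at this point is the shape of the arrays: scikit-learn expects X to be 2-dimensional (samples by features) and y to be 1-dimensional, which is exactly what the slicing above produces. A small sketch with a stand-in DataFrame (the column names are assumptions):

```python
import pandas as pd

# Stand-in for revenue_rdexpense.csv (column names are assumptions)
dataset = pd.DataFrame({'Revenue': [100, 200, 300],
                        'RD_Expense': [12, 22, 31]})

X = dataset.iloc[:, :-1].values  # 2-D array, shape (n_samples, 1)
y = dataset.iloc[:, 1].values    # 1-D array, shape (n_samples,)
```

Using `iloc[:, :-1]` rather than `iloc[:, 0]` is what keeps X 2-dimensional; passing a 1-D X to `fit` would raise an error.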
# Fitting Simple Linear Regression to the data set
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X, y)
The actual linear regression fitting takes only three lines of code. To perform simple linear regression, we import the LinearRegression class from the sklearn.linear_model module. LinearRegression can handle not only simple linear regression with one independent variable but also multiple linear regression with several independent variables (the topic of the next article). We first create an instance of LinearRegression called regressor, then call its fit method on X and y. That is it! The fitting is done.
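Once fitted, the regressor can predict the dependent variable for new inputs via its predict method. A self-contained sketch with small synthetic numbers standing in for the revenue and R&D figures:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic stand-in data: revenue (X) vs. R&D expense (y)
X = np.array([[100.0], [200.0], [300.0], [400.0]])
y = np.array([12.0, 22.0, 31.0, 42.0])

regressor = LinearRegression()
regressor.fit(X, y)

# Predict R&D expense for a hypothetical company with revenue 250
pred = regressor.predict(np.array([[250.0]]))
```

Note that predict, like fit, expects a 2-D array, even for a single sample.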
The next step is to see how well the fitted line matches the original values. Note that in the beginning we imported matplotlib.pyplot as plt; it is a very useful tool for visualizing regression results. Here is how it is used:
# Visualising the Regression results
plt.scatter(X, y, color = 'red')
plt.plot(X, regressor.predict(X), color = 'blue')
plt.title('Simple Linear Regression Model')
plt.xlabel('Revenue(Million Dollars)')
plt.ylabel('R&D Expense(Million Dollars)')
plt.show()
We use the scatter method of plt to plot the original y values against X, then use plt.plot to draw the predicted y values against X. We also specify the title and the x- and y-axis labels on the graph. Here is what the graph looks like:
Visually, the scattered points lie near the fitted line, which indicates that R&D expense does have a linear relationship with revenue to some extent. The regressor also provides other attributes that help us gain more insight into the fit, such as the coefficients and intercept of the linear regression:
# Print the coefficients and intercept
print('Coefficients: \n', regressor.coef_)
print('Intercept: \n', regressor.intercept_)
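Beyond eyeballing the plot, the fit can be quantified: LinearRegression's score method returns the coefficient of determination R², which is 1.0 for a perfect fit. A minimal sketch with synthetic, perfectly linear data:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Perfectly linear synthetic data: y = 2 * x
X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([2.0, 4.0, 6.0, 8.0])

regressor = LinearRegression()
regressor.fit(X, y)

# R^2 coefficient of determination; 1.0 means a perfect fit
r2 = regressor.score(X, y)
```

On real, noisy data the R² value will fall below 1.0, and gives a single-number summary of how much variance the line explains.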
Conclusion
In this article, we have shown how to use the LinearRegression class to perform simple linear regression with just one independent variable. It is easy to implement and quite powerful. But in reality, things are more complicated: often an outcome depends on multiple factors. How can we model that? That is the problem multiple linear regression tries to solve. In my next article, Machine Learning 101 - Part 3- A Tutorial for Multiple Linear Regression, I am going to introduce the more powerful multiple linear regression algorithm, as well as the backward elimination procedure for obtaining the best-fitting model. Stay tuned.
References
You can download the Python script and dataset for this tutorial from my Github repo: Simple Linear Regression.