How To Build An Artificial Neural Network in Java
Let me give you a brief note on the purpose of this article before we get to the technical part. I see plenty of Python-based tutorials on how to start coding deep learning models, train them and deploy them into production. One way or another, my research always ended up pointing to Python for machine learning work. Most of the time this was due to simplicity and the fact that the majority of machine learning libraries are released in Python. Python is interesting, but it was never my primary choice. I was born and brought up in Java, and it's always painful to see a technology constraint standing in the way of learning something new. So, should I become a "Python developer" to learn machine learning at its best? That question was all over my mind. Well, some folks would argue that as a software engineer you're expected to pick up any technology stack without hesitation. And who said Python is ugly? It is indeed beautiful; it does the job with less coding effort. Mathematical computations that take plenty of Java code can be done in a few lines of Python. Most organizations look for all-rounders with a taste for multiple technologies. But it is also fair to think about how to leverage your existing technology stack to implement what you dream about.
Searching for Java-based deep learning frameworks led me to DL4J. There may be other Java frameworks/libraries that do the job, but DL4J is the only promising commercial-grade deep learning framework in Java to date. The founders released DL4J for exactly this concern: for the people of Java. Having said that, do not restrict yourself to one particular machine learning library; the right solution varies with your model, and you will have to research what works best in your case. Now let's get our hands dirty! I will be using the DeepLearning4j (DL4J) framework throughout this article, as my aim is to provide a Java-oriented solution.
I'm using the example below to demonstrate the implementation of a neural network. An international banking company wants to increase its customer retention rate by predicting which customers are most likely to leave the bank.
The customer data set is a CSV file (Churn_Modelling.csv) with one row per customer.
1) Data Pre-Processing
Data in human-readable format may not make sense for machine learning purposes; we will be dealing only with numbers while we train the neural network. We also need to take note of categorical inputs and transform them properly before feeding them to the neural network.
As you can observe, there are 14 fields in the customer data set. The last field, "Exited", tells whether the customer left the bank or not: '1' indicates that the customer has left. So this is going to be our output label. Check the data set and inspect which fields could plausibly influence the output. The first three fields (RowNumber, CustomerId & Surname) can safely be ignored, since they are not deciding factors. That leaves 10 fields for consideration apart from the output label. If you inspect further, you will see two fields, Geography & Gender, whose values are not numeric. We need to transform them into numbers in a meaningful manner before passing them to the neural network. The 'Gender' field can be mapped to a binary value (0 or 1) depending on male/female. The 'Geography' field, on the other hand, has multiple values, so we can use one-hot encoding: each country becomes its own binary column (e.g. France → [1,0,0], Spain → [0,1,0], Germany → [0,0,1]).
In DL4J, we can define a schema for the data set and then feed this schema into a transform process, where all the encoding and transformation steps are applied.
Schema schema = new Schema.Builder()
.addColumnString("RowNumber")
.addColumnInteger("CustomerId")
.addColumnString("Surname")
.addColumnInteger("CreditScore")
.addColumnCategorical("Geography",Arrays.asList("France","Spain","Germany")) //Define categorical variable
.addColumnCategorical("Gender",Arrays.asList("Male","Female")) //Define categorical variable
.addColumnsInteger("Age","Tenure")
.addColumnDouble("Balance") //Balance and EstimatedSalary hold decimal values in the CSV, so they are defined as doubles
.addColumnsInteger("NumOfProducts","HasCrCard","IsActiveMember")
.addColumnDouble("EstimatedSalary")
.addColumnInteger("Exited")
.build();
After this encoding, the 'Geography' field is converted into multiple columns with binary values: with 3 countries in the data set, it maps to three columns, each representing one country. We also have to take care of the dummy variable trap by removing one of the categorical columns; the removed column becomes the base category against which the other categories are measured. Here we remove "Geography[France]" and keep France as the base for the other country values.
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.removeColumns("RowNumber","Surname","CustomerId") //variables to be ignored
.categoricalToInteger("Gender") //Transforming categorical label into integers
.categoricalToOneHot("Geography")//Applying one-hot encoding
.removeColumns("Geography[France]")//Removing one categorical field to avoid dummy variable trap
.build();
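As a quick sanity check, you can print the schema produced by the transform process; getFinalSchema() is part of the DataVec TransformProcess API:
Schema finalSchema = transformProcess.getFinalSchema();
System.out.println(finalSchema); //should list 12 columns, with "Exited" as the last one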
Now it's time to split the data set into training/test sets. The training set is used to train the neural network, and the test set measures how well the network has been trained; mean error and success rate are calculated at the end of each epoch. Use CSVRecordReader to read the CSV file and pass it to TransformProcessRecordReader to apply the transformation we defined above. Both record readers are implementations of the RecordReader interface.
RecordReader reader = new CSVRecordReader(1,','); /* skip the header line; comma separated */
reader.initialize(new FileSplit(new ClassPathResource("Churn_Modelling.csv").getFile()));
RecordReader transformProcessRecordReader = new TransformProcessRecordReader(reader,transformProcess); //Passing transformation process to convert the csv file
Now let's define which columns are inputs and which is the output. The data set after applying the transformation has 12 columns, indexed 0 to 11: CreditScore, Geography[Spain], Geography[Germany], Gender, Age, Tenure, Balance, NumOfProducts, HasCrCard, IsActiveMember, EstimatedSalary and Exited. The last column (index 11) is the expected output; all the other columns are input features.
Remember to define a batch size for your data set. The batch size defines how many records are transferred from the data set to the neural network at a time. We have 10000 entries in our data set; we could use a batch size of 8000 (the training set) so that the whole set is transferred in a single chunk, but keep in mind that a larger batch size means fewer weight updates are performed per epoch. So let's define a smaller batch size and use a DataSetIterator to retrieve data sets from the file.
int labelIndex = 11; // the "Exited" column is at index 11; inputs occupy indices 0-10
int numClasses = 1; // number of output labels
int batchSize = 10;
DataSetIterator iterator = new RecordReaderDataSetIterator(transformProcessRecordReader,batchSize,labelIndex,numClasses);
One approach is to load the whole data set into a single DataSet object and split it (note that this requires a batch size large enough to cover every record, e.g. 10000):
DataSet allData = iterator.next(); //pulls the entire data set in one batch (assumes batchSize >= number of records)
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); //80% training, 20% test
DataSet trainSet = testAndTrain.getTrain();
DataSet testSet = testAndTrain.getTest();
A second, and better, approach is to make use of DataSetIteratorSplitter.
DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(iterator,1000,0.8); /* the second argument is the number of batches: 1000 batches of size 10 cover the 10000 records; 80% go to training */
We can now get training/test set iterators from the splitter to pass into the neural network model once it's ready to be trained.
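Both accessors below are part of the DataSetIteratorSplitter API; a minimal sketch:
DataSetIterator trainIterator = splitter.getTrainIterator();
DataSetIterator testIterator = splitter.getTestIterator();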
Now, can we go ahead and feed this data to the neural network? Absolutely not! If you inspect the data, you will see that it is not scaled properly. The features we feed to a neural network should be comparable to each other in magnitude. The magnitudes of 'Balance' and 'EstimatedSalary' are way higher than those of most other fields, so if we process them as-is, the computation could depend too heavily on these fields and potentially hide the effect of the other features when predicting the output. So, we need to apply feature scaling here.
DataNormalization dataNormalization = new NormalizerStandardize();
dataNormalization.fit(iterator);
iterator.setPreProcessor(dataNormalization); /* every DataSet the iterator produces is now normalized automatically */
splitter = new DataSetIteratorSplitter(iterator,1000,0.8); //recreate the splitter now that the normalizer is attached to the iterator
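If you want to reuse the fitted normalizer at prediction time (as we do at the end of this article), you can also persist it. A sketch assuming ND4J's NormalizerSerializer, with "normalizer.bin" as an arbitrary file name:
NormalizerSerializer.getDefault().write(dataNormalization, new File("normalizer.bin")); //save the fitted statistics
DataNormalization restored = NormalizerSerializer.getDefault().restore(new File("normalizer.bin")); //reload them later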
Remember that data pre-processing is crucial for avoiding incorrect outputs and errors, and the right steps depend entirely on the data you possess. Finally, we have data that can be fed to the neural network. Let's see how we can design the neural network model.
2) Define Neural Network Shape
First, we define the neural network configuration. We specify how many neurons should be present in the input layer, the hidden layer structure and its connections, the output layer, the activation functions for each layer, the loss function for the output layer and the optimizer function.
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.UNIFORM)//assign uniform weights across the layers
.updater(new Adam()) //Adam updater, an adaptive variant of stochastic gradient descent
.list()
.layer(new DenseLayer.Builder().nIn(11).nOut(6).dropOut(0.1).build()) //input layer
.layer(new DenseLayer.Builder().nIn(6).nOut(6).dropOut(0.1).build()) //hidden layer
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(6).nOut(1).activation(Activation.SIGMOID).build()) //output layer: sigmoid activation + binary cross-entropy loss
.backprop(true).pretrain(false)
.build();
As you can see, we added dropout between the layers. This helps optimize the neural network by avoiding over-fitting; note that we drop only a small portion of neurons (10%) so as not to under-fit at the same time. There are 11 input features and one output, so we configured nIn/nOut accordingly. How do we decide the number of neurons in the hidden layers? A common rule of thumb is the average of the input and output neuron counts: (11 + 1) / 2 = 6. Our expected output indicates whether the customer will leave the bank or not, i.e. the probability that a customer leaves. This scenario therefore looks like logistic regression, hence the sigmoid activation function in the output layer. We also need to specify the loss function with which the error rate is calculated: for a binary classification problem with a sigmoid output, the appropriate choice is binary cross-entropy (XENT).
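To make the loss function concrete, here is a tiny plain-Java sketch (an illustration only, not DL4J code) of what binary cross-entropy computes for a single prediction:
//yTrue is the actual label (0 or 1); yPred is the sigmoid output in (0, 1)
static double binaryCrossEntropy(double yTrue, double yPred) {
    return -(yTrue * Math.log(yPred) + (1 - yTrue) * Math.log(1 - yPred));
}
//binaryCrossEntropy(1, 0.9) ≈ 0.105 -> confident correct prediction, low loss
//binaryCrossEntropy(1, 0.1) ≈ 2.303 -> confident wrong prediction, high loss
The further the predicted probability drifts from the actual label, the faster the loss grows, which is exactly the behaviour we want when training a probabilistic classifier.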
3) Train the model and predict results
Once the model has been configured, let's build and initialize it with the code below:
MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(configuration);
multiLayerNetwork.init();
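Optionally, you can attach a listener to monitor the loss score while the network trains; ScoreIterationListener ships with DL4J (in org.deeplearning4j.optimize.listeners), and the interval of 100 iterations below is just an arbitrary choice:
multiLayerNetwork.setListeners(new ScoreIterationListener(100)); //prints the current loss score every 100 iterations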
The purpose of k-fold cross-validation is to reduce the variance of the computed accuracy. With the default 10-fold setting, the training set is split into 10 folds, and an update pass is performed for each fold. After that, we can start training our neural network model.
DataSetIterator kFoldIterator = new KFoldIterator(trainSet);
multiLayerNetwork.fit(kFoldIterator,100); /* 100 epochs; with 10 folds over the 8000 training records, each fold holds 800 records */
We specified 100 epochs in the above example and we have 8000 training records, so it will take a while to train the neural network. Once done, we can evaluate our results using the code below:
Evaluation evaluation = new Evaluation(1); // number of output categories -> 1
INDArray output = multiLayerNetwork.output(testSet.getFeatureMatrix()); //getFeatureMatrix() returns the input features of the test set
output = output.cond(new AbsValueGreaterThan(0.50)); /* the output holds probabilities; convert them to true/false. True = customer leaves the bank */
evaluation.eval(testSet.getLabels(),output); //compare predictions with the actual outputs
System.out.println(evaluation.stats()); //display evaluation metrics / confusion matrix
Again, the better approach is to pass the iterators from the splitter instead of a single data set, and print the evaluation metrics straight away:
multiLayerNetwork.fit(splitter.getTrainIterator(),100);
Evaluation evaluation = multiLayerNetwork.evaluate(splitter.getTestIterator());
System.out.println(evaluation.stats());
Confusion matrix will be in the form:
=========================Confusion Matrix=========================
0 1
-----------
1556 54 | 0 = 0
297 93 | 1 = 1
0 identified as 0 -> 1556 (customers who stayed, predicted correctly)
1 identified as 1 -> 93 (customers who left, predicted correctly)
The remaining predictions are wrong, which gives an accuracy of (1556 + 93) / 2000 ≈ 82%.
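If you prefer the individual metrics over the printed stats, DL4J's Evaluation class exposes them directly:
System.out.println("Accuracy: " + evaluation.accuracy());
System.out.println("Precision: " + evaluation.precision());
System.out.println("Recall: " + evaluation.recall());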
Now let's see the neural network in action. Let us predict if the following customer will leave the bank:
Geography: France
Credit Score: 600
Gender: Male
Age: 40
Tenure: 3
Balance: 60000
Number of Products: 2
Has Credit Card: Yes
Is Active Member: Yes
Estimated Salary: 50000
We could apply the same one-hot encoding and transformation programmatically before feeding the data to the neural network (see the sketch after the code below). For now, though, we feed the details directly as a row vector after applying the encoding manually.
double[] data = {0.0, 0, 600, 1, 40, 3, 60000, 2, 1, 1, 50000}; //manually encoded input; the order and encoding must match the columns produced by the transform process
INDArray array = Nd4j.create(data); //converting the array to a row vector
dataNormalization.transform(array);//feature scaling before passing to neural network
int[] result = multiLayerNetwork.predict(array);
System.out.println("Result = " + Arrays.toString(result)+ ""); //Prints '0' in our case indicating that customer will not leave the bank
Note that we apply feature scaling before sending the data to the neural network. Congrats! We have just developed a standard neural network with roughly 82% accuracy. Feel free to explore concrete DL4J examples here: https://github.com/rahul-raj/Deeplearning4J. In the near future, I will be coming up with examples for more complex structures, including CNNs (Convolutional Neural Networks) & RNNs (Recurrent Neural Networks). Enjoy deep learning, and thanks to SuperDataScience for the use-case! Feel free to message me on LinkedIn with any queries or clarifications. Wondering how to deploy machine learning models to production? Check out this post.