Training Graph Neural Networks: Part 12 of my Graph series of blogs



1. Introduction:


This is the continuation of my series of blogs on Graphs and is the twelfth article in the series. In this article, I will talk about the ingredients of the training pipeline of Graph Convolutional Neural Networks (GCNs).


Parts 8 and 9 of this series explained how Graph Neural Networks compute node embeddings, but they did not discuss how the final predictions are obtained, or the steps in the training pipeline that follow once the node embeddings are available – that is, the second half of the figure shown below:



Figure 1: The content of Part 8 and Part 9 of this Graph Series (shown marked in dotted lines)



With this objective, this article is organized as follows:


  • Section 2 revisits the definition of a “computation graph” in a Graph Neural Network and the concepts of message passing and layers in a GNN, and emphasizes that each node in a graph has its own neural network architecture.


  • Section 3 talks about prediction heads – the prediction head produces the final output of the model. Node level, edge level and graph level prediction heads are discussed in this section in sufficient detail.


  • Section 4 provides some insight into supervised vs. unsupervised learning in graphs, with examples.


  • Section 5 and Section 6 go into the details of loss functions and evaluation metrics for regression and classification problems in graphs.


2. Review of Parts 8 and 9 of this Series on Graph Convolutional Neural Networks (GCNs):


Parts 8 and 9 of this series of blogs went into the details of generating node embeddings. Computing node embeddings through Graph Neural Networks involved understanding the computation graph of a node and the calculations carried out at every layer of that computation graph.


We considered an example graph, as shown in Figure 2 below, and constructed the computation graph of node A, as shown in Figure 3 below.



Figure 2: Example Graph for explaining the construction of a Computation Graph



Figure 3: Computation Graph of Node A


It was explained through the above articles that the architecture corresponding to the computation graph of node A is justified as follows:


  • Node A, as shown in Figures 2 and 3 above, takes information from its neighbours in the network – the neighbours of A being B, C and D. This is one hop of information, and it can be unfolded to multiple hops.


  • If we want to consider another hop, then node D takes information from its neighbour A, node C takes information from A, B, E and F (that is, the neighbours of C), and similarly node B takes information from its neighbours A and C.


I had further highlighted that the messages (embeddings) from nodes A and C to node B, from nodes A, B, E and F to node C, and from node A to node D will have to be transformed and aggregated before being passed on. Similarly, the message from B will have to be transformed/aggregated before passing it to A, and the same holds for the other nodes. The transformation is parametrised, and the parameters have to be learnt.


It must be understood that the messages correspond to the embeddings and that the embeddings are computed at every layer of the network – the layers are marked in Figure 3. The main difference between Graph Neural Networks and vanilla neural networks is that in Graph Neural Networks each node has its own neural network architecture: every node in the graph gets to define its own computation graph based on the network around its neighbourhood, as shown in the figure below:



Figure 4: Computation Graph for each node in the network


It must be underscored that the computation graph for every node can have an arbitrary number of layers, depending upon the number of hops considered. The embedding corresponding to Layer 0 of node “u” may be taken as the input feature x_u of that node. Thus, the Layer k embedding gets information from nodes that are k hops away.
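To make the layer-wise computation concrete, here is a minimal sketch of one round of message passing in plain NumPy. It assumes mean aggregation over neighbours followed by learnable linear transforms and a ReLU; the function and variable names (gnn_layer, W_self, W_neigh) and the toy adjacency matrix (loosely based on Figure 2, with nodes E and F omitted) are illustrative assumptions, not code from a particular library.

```python
import numpy as np

def relu(x):
    return np.maximum(0, x)

def gnn_layer(H, A, W_self, W_neigh):
    """One round of message passing.

    H       : (num_nodes, d_in)  node embeddings from the previous layer
    A       : (num_nodes, num_nodes) adjacency matrix (1 if edge, else 0)
    W_self  : (d_in, d_out) learnable transform of the node's own embedding
    W_neigh : (d_in, d_out) learnable transform of the aggregated messages
    """
    deg = A.sum(axis=1, keepdims=True).clip(min=1)   # avoid division by zero
    neigh_mean = (A @ H) / deg                       # mean of neighbour embeddings
    return relu(H @ W_self + neigh_mean @ W_neigh)   # transform + non-linearity

# Layer 0 uses the input features x_u of each node; stacking k layers gathers
# information from nodes up to k hops away.
rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 1],      # toy graph: A connected to B, C, D
              [1, 0, 1, 0],      # B connected to A, C
              [1, 1, 0, 0],      # C connected to A, B
              [1, 0, 0, 0]],     # D connected to A
             dtype=float)
H0 = rng.normal(size=(4, 8))                          # input features, d = 8
H1 = gnn_layer(H0, A, rng.normal(size=(8, 8)), rng.normal(size=(8, 8)))
H2 = gnn_layer(H1, A, rng.normal(size=(8, 8)), rng.normal(size=(8, 8)))
print(H2.shape)   # (4, 8) – layer-2 embeddings see 2-hop neighbourhoods
```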



The articles corresponding to Parts 8 and 9 of the series did not go into the details of getting from the node embeddings to the actual predictions, of how we evaluate those predictions against the ground truth, and of how one defines the loss – the discrepancy between predicted and true labels. This is illustrated in Figure 5 below:



Figure 5: Content of Parts 8 and 9 vs. content of the current article


Through Parts 8 and 9 of this series, it was explained how Graph Neural Networks produce node embeddings through the final layer “L” of the network, as illustrated in the equation below:
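In symbols, the output of the final layer is the set of node embeddings:

$$\left\{ \mathbf{h}_v^{(L)}, \;\; \forall \, v \in G \right\}$$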




In the above equation, h_v^(L) corresponds to the embedding of node v at layer L, where node v belongs to a graph G.


Now, the question is how to arrive at the final predictions from the node embeddings – the block marked as Prediction Head in Figure 5.


3. Graph Neural Networks – Training Pipeline: Prediction Heads


Let us first talk about prediction heads – the prediction head produces the final output of the model. Therefore, in the case of graphs, the final prediction may be based on:


a) Node level tasks

b) Edge level tasks

c) Graph level tasks


We will be going through each of the above prediction heads as illustrated and discussed below:



Figure 6: Prediction Head in a GCN pipeline



Figure 7: Different prediction heads (that is – possible final predictions) in a Graph



3.1 Node Level Prediction Heads:


For node level prediction heads, we can directly make use of the node embeddings output by the GCN computation, as detailed in Part 9 of this series. That is – once we have the d-dimensional embedding for every node in the network and we want to make a “k”-way prediction – that is, classify each node into one of “k” different classes – the idea is quite simple and intuitive.


In such cases, for node level prediction, the output for every node is going to be a “weight matrix times the embedding of the node” as illustrated in the equation below:
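Writing the learnable weight matrix of the prediction head as W^(H) (a k × d matrix – the symbol is a notational choice here), this is:

$$\hat{y}_v = \mathbf{W}^{(H)} \, \mathbf{h}_v^{(L)}, \qquad \mathbf{W}^{(H)} \in \mathbb{R}^{k \times d}$$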



Equation: Mapping of node embeddings from embedding space to prediction space


This means that we will be mapping the node embeddings from the embedding space to the prediction space. We will have a “k” dimensional output because we are interested in a “k” way prediction – that is, there might be “k” possible classification labels. The y_hat denotes the predicted value and “y” is the actual value.


Now, with y_hat and y, we can compute the discrepancy between the prediction and the ground truth – that is, the loss.


3.2 Edge Level Prediction Heads:


Let us now see what the options for the edge level prediction heads are. For an edge level prediction task, we have to make the prediction on a pair of node embeddings. Let us say we have to make a k-way prediction on a pair of node embeddings. When we say a k-way prediction, it means we have to predict a link for a relationship type, and there may be “k” relation types.



Figure 8: Link Prediction task in a Graph


Let us see what options are available – the pair of node embeddings here is (h_u^(L), h_v^(L)), as shown in Figure 8 above.


Options for Edge Level Prediction Heads:


a) One option for the edge level prediction head is to concatenate the pair of node embeddings, as shown in the figure below, and pass the result through a linear layer:


Figure 9: Concatenating Node Embeddings for Edge Level Prediction


So that:
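Writing the concatenation and the learnable linear layer as Concat and Linear respectively, this can be expressed as:

$$\hat{y}_{uv} = \mathrm{Linear}\!\left( \mathrm{Concat}\!\left( \mathbf{h}_u^{(L)}, \mathbf{h}_v^{(L)} \right) \right)$$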


Equation: Mapping of concatenated node embeddings from 2 x d embedding space to prediction space


We could then apply a non-linearity such as a sigmoid or Softmax layer. The idea, therefore, is to apply a linear transformation and then a non-linearity to get a k-dimensional output. Thus, we map from the 2 x d dimensional concatenated embedding into a k-dimensional output.


b) Another option for the edge level prediction head is to carry out a scalar dot product. For example, when we have only one relation type, whether a link exists between two nodes or not can be ascertained by taking a dot product of the embeddings, which returns a scalar:
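In symbols:

$$\hat{y}_{uv} = \left( \mathbf{h}_u^{(L)} \right)^{\top} \mathbf{h}_v^{(L)}$$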



Equation: Link prediction for one relation type – using dot product


c) Now, if we have a k-way prediction – that is, if we want to predict the type of link across different relationship types – then we have a multi-headed prediction problem. Here we would have a different weight matrix for each of the k classes, as illustrated in the equations below. Multiplying the two embeddings with a weight matrix gives a scalar, and we then concatenate each of the y_hat_uv^(k) values, as shown in the equations below, to get a k-way prediction. Each weight matrix has to be learnable.
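Writing the learnable weight matrix for relation type r as W^(r) (a notational choice), the per-type scores and the final k-way prediction are:

$$\hat{y}_{uv}^{(1)} = \left( \mathbf{h}_u^{(L)} \right)^{\top} \mathbf{W}^{(1)} \mathbf{h}_v^{(L)}, \;\; \ldots, \;\; \hat{y}_{uv}^{(k)} = \left( \mathbf{h}_u^{(L)} \right)^{\top} \mathbf{W}^{(k)} \mathbf{h}_v^{(L)}$$

$$\hat{y}_{uv} = \mathrm{Concat}\!\left( \hat{y}_{uv}^{(1)}, \ldots, \hat{y}_{uv}^{(k)} \right) \in \mathbb{R}^{k}$$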



Equation: Link prediction – separate weight matrix for each class


Thus, every class learns its own transformation matrix that will rotate/translate/shrink or stretch the vector. Once we have the prediction for each of the classes, we can concatenate them to get the final prediction.
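As a rough sketch of the three options above in plain NumPy (the function names, the weight shapes and the toy numbers are illustrative assumptions rather than a reference implementation):

```python
import numpy as np

d, k = 16, 4                      # embedding size and number of relation types
rng = np.random.default_rng(0)
h_u, h_v = rng.normal(size=d), rng.normal(size=d)   # final-layer embeddings of u and v

# (a) Concatenate the two embeddings and apply a linear layer: 2d -> k
W_lin, b_lin = rng.normal(size=(k, 2 * d)), np.zeros(k)
def concat_head(h_u, h_v):
    return W_lin @ np.concatenate([h_u, h_v]) + b_lin     # k-dimensional scores

# (b) Single relation type: a plain dot product gives one scalar score
def dot_product_head(h_u, h_v):
    return h_u @ h_v

# (c) k-way prediction: one learnable matrix per relation type,
#     giving one scalar per type, concatenated into a k-vector
W_rel = rng.normal(size=(k, d, d))
def kway_dot_product_head(h_u, h_v):
    return np.array([h_u @ W_rel[r] @ h_v for r in range(k)])

print(concat_head(h_u, h_v).shape)            # (4,)
print(dot_product_head(h_u, h_v))             # a single scalar
print(kway_dot_product_head(h_u, h_v).shape)  # (4,)
```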


3.3 Graph Level Prediction Heads:


Next, let us discuss the graph level prediction head – where we have to make the prediction on the entire graph. For graph level prediction, we want to predict using all the node embeddings in the graph, because the prediction involves the entire graph.


Thus, the prediction head will make the prediction involving the entire graph, as shown in the figure below. This involves aggregating the node embeddings of the entire graph in some form, as explained in the paragraphs on Global Mean Pooling, Global Max Pooling and Summation Based Pooling below.



Figure 10: Aggregation of all the Node Embeddings for Graph Level Predictions


Let us see how we can define the Graph prediction head. There are many ways to do this:


a) Global Mean Pooling:


In Global Mean Pooling we take the average of all the node embeddings in the Graph as shown in the equation below:
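Writing the aggregation as a Mean operator over the set of node embeddings:

$$\hat{y}_G = \mathrm{Mean}\!\left( \left\{ \mathbf{h}_v^{(L)}, \forall v \in G \right\} \right)$$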



Equation – Global Mean Pooling – involves Averaging of the embeddings of all nodes


b) Global Max Pooling:


In Global Max Pooling, we take the coordinate-wise maximum of the embeddings of all the nodes.
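Writing this with a coordinate-wise Max operator:

$$\hat{y}_G = \mathrm{Max}\!\left( \left\{ \mathbf{h}_v^{(L)}, \forall v \in G \right\} \right)$$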



Equation – Global Max Pooling – involves taking coordinate-wise max of all node embeddings in the network


c) Summation Based Pooling:

In summation-based pooling, we take the sum of the embeddings of all the nodes in the Graph.
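Writing this with a Sum operator:

$$\hat{y}_G = \mathrm{Sum}\!\left( \left\{ \mathbf{h}_v^{(L)}, \forall v \in G \right\} \right) = \sum_{v \in G} \mathbf{h}_v^{(L)}$$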



Equation – Global Sum Pooling – involves taking summation of all node embeddings in the network


3.3.1 Choosing the Right Pooling Strategy for Different Graph Level Prediction Scenarios:


Sum based pooling is a better option when we want to capture the structure of the graph or how many nodes there are in the graph. Mean pooling is useful when the number of nodes does not play a role and when we are comparing graphs of very different sizes.


The options described above for graph level predictions work well for small graphs. For large graphs, global pooling may result in loss of information, so we may resort to more advanced forms of pooling. One such advanced form, Hierarchical Pooling, is discussed in section 3.3.3, but first let us discuss the issues with global pooling.


3.3.2 Issues with Global Pooling:


Let us outline the problems with global pooling through a very simplistic example, as illustrated below.


Example 1:

Let us consider two graphs with one-dimensional node embeddings, as shown below:

  • Node embeddings for Graph G1: {-1, -2, 0, 1, 2}

  • Node embeddings for Graph G2: {-10, -20, 0, 10, 20}

For simplicity, we have taken each embedding as a single number – a one-dimensional node embedding.

Clearly, G1 and G2 have very different node embeddings – the structures of graphs G1 and G2 may be very different – but if we take global sum pooling, then we get the following for graphs G1 and G2:
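$$\mathrm{Sum}(G_1) = -1 - 2 + 0 + 1 + 2 = 0, \qquad \mathrm{Sum}(G_2) = -10 - 20 + 0 + 10 + 20 = 0$$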



Equation: Sum based pooling for Example 1 graphs with 1-dimensional embeddings


Thus, if we take the sum (or the average) for both graphs, we get the same value. Therefore, we will not be able to classify the graphs into different classes, as they have the same representation. This is an edge case example of how global pooling can lead to unsatisfactory results, especially when graphs are larger. A solution to this is hierarchical pooling. Let us now understand hierarchical pooling.


3.3.3 Hierarchical Pooling:


In hierarchical pooling, we do not aggregate everything together at the same time; instead, we take smaller groups and aggregate them. For example, let us say we have (once again) graphs with one-dimensional embeddings, as illustrated below:


Graph G1 node embeddings are: {-1, -2, 0, 1, 2}


Let us assume we aggregate using ReLU as the non-linearity and summation as the aggregation function.


To aggregate hierarchically means that we aggregate the first two nodes, then aggregate the last three nodes, and then aggregate the aggregates.


Thus, for Graph G1:


Graph G1 node embeddings are: {-1, -2, 0, 1, 2}


So, for the Round 1 of aggregate of Graph G1: Aggregating the first two nodes and then the last 3 nodes:
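With ReLU(Sum(·)) as the aggregator:

$$\mathrm{ReLU}\big(\mathrm{Sum}(\{-1, -2\})\big) = \mathrm{ReLU}(-3) = 0, \qquad \mathrm{ReLU}\big(\mathrm{Sum}(\{0, 1, 2\})\big) = \mathrm{ReLU}(3) = 3$$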



Equation: Graph G1 – Hierarchical Pooling – Aggregation Round 1


Round 2: Aggregating the Aggregates:
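$$\mathrm{ReLU}\big(\mathrm{Sum}(\{0, 3\})\big) = 3$$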


Equation: Graph G1 – Hierarchical Pooling – Aggregation Round 2


Now, aggregating for Graph G2 in a similar fashion, we have:

Graph G2 node embeddings: {-10, -20, 0, 10, 20}

Round 1 of aggregate of Graph G2: Aggregating the first two nodes and then the last 3 nodes:
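$$\mathrm{ReLU}\big(\mathrm{Sum}(\{-10, -20\})\big) = \mathrm{ReLU}(-30) = 0, \qquad \mathrm{ReLU}\big(\mathrm{Sum}(\{0, 10, 20\})\big) = \mathrm{ReLU}(30) = 30$$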


Equation: Graph G2 – Hierarchical Pooling – Aggregation Round 1


Round 2: Aggregating the Aggregates:
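$$\mathrm{ReLU}\big(\mathrm{Sum}(\{0, 30\})\big) = 30$$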


Equation: Graph G2 – Hierarchical Pooling – Aggregation Round 2


As seen through the operations above, the two graphs now have different embeddings and do not overlap in the embedding space. This is an illustration of how hierarchical pooling helps, compared to global mean pooling or global sum pooling.


Hierarchical pooling in practice:


The question now arises: how do we decide the order of aggregation in hierarchical pooling? That is – how do we decide which nodes to aggregate first and which nodes to aggregate in subsequent rounds?


The insight which really helps in deciding the order of aggregation is that graphs tend to have a community structure. For example, if we take a social network graph, there are tightly knit communities within it.


So, the idea is to detect these communities ahead of time, aggregate the nodes inside each community to form a community embedding, and then further aggregate these community embeddings based on how the different communities are linked together. This is illustrated in the figure below:



Figure 11: Hierarchical Pooling in practice – pooling of embeddings based on communities (after applying a community detection algorithm)


The strategy for hierarchical pooling will thus be to split the graph into different clusters using a community detection algorithm and then aggregate within the respective community/cluster. This is followed by aggregation based on how the communities are linked together, continuing until we get to a single representation for the prediction head.


Research Paper on Differentiable Pooling


The paper on Differentiable Pooling (DiffPool) [https://arxiv.org/abs/1806.08804] describes how to strategize this. In a nutshell, we leverage two independent GNNs: one Graph Neural Network computes the node embeddings, and the second Graph Neural Network determines the cluster to which each node belongs, so that the nodes can be aggregated as described above. The two GNNs at each level can be executed in parallel.
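A very rough sketch of the idea in plain NumPy (not the authors' implementation – the paper adds further details such as auxiliary losses; all names here are placeholders): one GNN produces node embeddings Z, a second GNN produces a soft cluster-assignment matrix S, and the community embeddings and the coarsened adjacency are obtained by pooling with S.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gnn(H, A, W):
    """Placeholder one-layer GNN: aggregate neighbours (with self-loops) and transform."""
    return np.maximum(0, (A + np.eye(A.shape[0])) @ H @ W)

def diffpool_style_layer(H, A, W_embed, W_assign):
    Z = gnn(H, A, W_embed)                    # GNN 1: node embeddings
    S = softmax(gnn(H, A, W_assign), axis=1)  # GNN 2: soft assignment of nodes to clusters
    H_coarse = S.T @ Z                        # cluster (community) embeddings
    A_coarse = S.T @ A @ S                    # how the clusters are linked together
    return H_coarse, A_coarse

rng = np.random.default_rng(0)
n, d, n_clusters = 6, 8, 2
A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1); A = A + A.T                # random undirected toy graph
H = rng.normal(size=(n, d))
H2, A2 = diffpool_style_layer(H, A, rng.normal(size=(d, d)), rng.normal(size=(d, n_clusters)))
print(H2.shape, A2.shape)   # (2, 8) (2, 2) – a coarser graph, ready for further pooling
```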


4. Graph Neural Networks – Training Pipeline: Predictions and the Labels


Having talked about the Prediction heads, let’s talk about the predictions and the labels as highlighted in the figure below:



Figure 12: Graph Neural Network Training Pipeline – Discussion on Predictions and Labels


Supervised Learning vs Unsupervised Learning in Graphs:


Supervised Learning:

We can broadly distinguish, in graphs, between supervised learning and unsupervised learning. Supervised learning in graphs is where labels come from external sources, e.g. nodes belonging to different classes, users in a social network being interested in different topics, or molecules being classified as toxic/non-toxic.


Supervised labels in Graphs come from specific use cases. For example:


a) In the case of node labels in a citation network, the subject area a research paper belongs to is the external label. This is defined for every node.

b) In a link prediction task on a transaction network, we could have a label “y” for every transaction that tells us whether the transaction is fraudulent or not.

c) For a graph level task, drug toxicity can be an externally defined label.

It is often advised to reduce the task to a node/edge/graph level prediction task, since these are easy to work with.


Unsupervised Learning:

We also have unsupervised learning in graphs, where the signal comes from the graph itself. For example, in link prediction we want to predict whether a pair of nodes is connected; here we do not need any external information – all we need is pairs of nodes that may or may not be connected.


Unsupervised learning may also be termed self-supervised learning, because the input data itself provides the supervision to the model, as in the case of link prediction.


The idea in unsupervised learning on graphs is that sometimes one does not have external labels as described above. The solution then is to define self-supervised learning tasks, where the supervision signal comes from the graph itself – some examples include the following:


a) For node level tasks, this may be the prediction of node clustering coefficients; for a molecular graph, it may be predicting the type of atom at a given node: hydrogen, carbon, etc.

b) For a link prediction task, it may be to hide some edges and predict whether the corresponding pairs of nodes are connected or not.

c) For a graph level task, it might be to predict whether two graphs are isomorphic.

In the types of tasks highlighted above, we do not require any ground-truth labels; we only use graph structure information.


5. Loss Functions used while training GNNs


Loss functions measure the discrepancy between the predictions and the labels, so that we can optimize the loss and backpropagate all the way down to the parameters of the model. This section elaborates on the loss functions, as illustrated in the figure below that shows the overall training pipeline of GNNs.



Figure 13: Graph Neural Network Training Pipeline – Discussion on Loss Functions


Settings for GNN training:


Based on the discussions so far, we have the following setting in GNNs:

  • We have “N” data points.

  • Each data point can be an individual node, an individual edge or an individual graph.

  • For node level prediction, each node has a label y_v^(i) and a prediction y_hat_v^(i) – the subscript “v” denotes the node and “i” the data point.



Node labels – predicted and actual


  • Each edge has a label y_uv^(i) and a prediction y_hat_uv^(i). The edge may, for example, represent a transaction between two nodes, and the edge label may denote whether the transaction is fraudulent or not, or the type of the edge.


Edge labels – predicted and actual


  • Similarly, for graph level prediction, we again have a prediction y_hat_G^(i) and an actual label y_G^(i).


Graph labels – predicted and actual


y_hat and y are the notations used for the predicted and actual values respectively, “i” denotes the data point/training example, and the subscripts v, uv and G denote the node, the node pair and the graph respectively.


Classification or Regression:

An important distinction whilst solving the prediction problem in graphs (or even with conventional machine learning) is whether we are doing classification or regression. In classification problems, the labels “y” have discrete categorical values – like what topic or what kind of movies the user likes to see in their newsfeed on social media.


In regression we’re predicting continuous values such as displacement at a node of interest as shown in the figure below:



Figure 14: Node Regression problem to predict the displacement at a node on the body of a car


Or, in the case of biomedicine: in classification one may predict whether a particular drug is toxic or not, while in regression one may predict the toxicity level.



Figure 15: Graph Neural Networks for toxicity prediction in biomedicine


GNNs can be applied to both settings: classification as well as regression. The loss function and the evaluation metric will, however, be different for classification and regression. Let us discuss the loss functions and the evaluation metrics for classification and regression in the paragraphs below:


Classification Loss Function:

The most common loss function for classification is the Cross Entropy loss function. I have discussed the Cross Entropy loss function in section 6 of my blog on the “Foundational Principles of Deep Learning” here.


In Cross Entropy Loss, if we’re doing a k-way prediction for the ith data point, the cross-entropy loss between the true label y and predicted label y_hat is given by:
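In its standard form, with j indexing the k classes:

$$\mathrm{CE}\!\left( y^{(i)}, \hat{y}^{(i)} \right) = - \sum_{j=1}^{k} y_j^{(i)} \, \log \hat{y}_j^{(i)}$$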



Equation: Cross Entropy Loss for a Classification Problem


It is the sum, over the k different classes, of the actual y^(i) values times the log of the predicted values, with a negative sign in front.
y_hat may be interpreted as a probability. The way it works is:
y_hat may be interpreted as the probability. The way it works is:


  • Imagine “y” is as illustrated below:
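For example, with k = 5 classes (an illustrative choice) and the node belonging to class number 3:

$$y^{(i)} = [\, 0, \; 0, \; 1, \; 0, \; 0 \,]$$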



Equation – Illustration of actual label of the ith training example in a classification problem


It is a one-hot label encoding – a binary vector (of 0s and 1s) that tells us that this particular node belongs to class number 3.


  • Then, y_hat will be a distribution as below where, after the application of Softmax, all entries sum to 1. y_hat can be interpreted as a probability distribution, as shown in the equation below:
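Continuing the illustrative 5-class example:

$$\hat{y}^{(i)} = [\, 0.1, \; 0.2, \; 0.4, \; 0.2, \; 0.1 \,]$$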



Equation – Illustration of predicted label of the ith training example in a classification problem


  • From the above, we can see that the prediction assigns the highest probability to class number 3.


  • The idea behind the equation of the cross-entropy loss function is again illustrated below:


Equation: Cross Entropy Loss for a Classification Problem


It can be observed through the example that, for the true class (where y^(i) = 1), a predicted value y_hat close to 1 makes log(y_hat) approach 0, resulting in a low loss – which is what we want in this scenario.

And when y_hat for the true class is small, -log(y_hat) is large, and 1 multiplied by a large number gives a very large number, resulting in a high loss – which is again what we want, since the wrong prediction gets penalized.


This covers the classification loss, and the interpretation is no different from other ML problems. The total loss is the sum of the losses over all the data points, with the cross-entropy loss over an individual data point defined as above.



Equation – Total Cross Entropy Loss over N data points/training examples


Regression Loss:


For regression, the standard loss is the Mean Square Error, or L2 loss. Essentially, in a regression task with a k-way prediction, we are targeting k real values for a given data point i – which is actually the case in the displacement prediction problem for the car – and the Mean Square Error loss for a data point (i) is represented as:
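In symbols, summing the squared differences over the k predicted values:

$$\mathrm{MSE}\!\left( y^{(i)}, \hat{y}^{(i)} \right) = \sum_{j=1}^{k} \left( y_j^{(i)} - \hat{y}_j^{(i)} \right)^2$$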



Equation: Mean Square Error Loss for a Regression Problem


The reason we take the quadratic loss is that it is smooth and continuous, it is easy to take the derivative of, and it is always positive. The above equation represents the mean square error loss over one data point/one training example. If we have “N” training examples, we simply sum up the losses as below:



Equation – Total Mean Square Error Loss over N data points/training examples


This is in terms of classification and regression loss.
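As a small self-contained sketch of the two losses for a single data point, in plain NumPy (the variable names and numbers are illustrative; in practice one would typically use a framework's built-in loss functions):

```python
import numpy as np

def cross_entropy(y, y_hat, eps=1e-12):
    """y: one-hot true label (k,); y_hat: predicted probabilities (k,)."""
    return -np.sum(y * np.log(y_hat + eps))

def mse(y, y_hat):
    """y, y_hat: real-valued targets and predictions (k,)."""
    return np.sum((y - y_hat) ** 2)

# Classification: the true class is class 3, the prediction puts 0.4 there
y_cls     = np.array([0, 0, 1, 0, 0])
y_hat_cls = np.array([0.1, 0.2, 0.4, 0.2, 0.1])
print(cross_entropy(y_cls, y_hat_cls))   # -log(0.4) ≈ 0.916

# Regression: e.g. predicted vs. measured displacement values
y_reg     = np.array([0.10, -0.30])
y_hat_reg = np.array([0.12, -0.25])
print(mse(y_reg, y_hat_reg))             # 0.0004 + 0.0025 = 0.0029

# The total loss over N data points is the sum of the per-example losses.
```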


6. Evaluation Metrics used while training GNNs


Having discussed loss functions, let us next talk about evaluation metrics, as illustrated in the GNN training pipeline in the figure below.



Figure 16: Graph Neural Network Training Pipeline – Discussion on Evaluation Metrics


Let us now see how to measure the success of a Graph Convolutional Neural Network and talk about:


  • Accuracy
  • ROC (Receiver Operating Characteristic)


Regression tasks:

For regression tasks, we evaluate the root mean square error, as shown below:


Root Mean Square Error Loss (RMSE):
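In its standard form, over N data points:

$$\mathrm{RMSE} = \sqrt{ \frac{1}{N} \sum_{i=1}^{N} \left( y^{(i)} - \hat{y}^{(i)} \right)^2 }$$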


Equation: The Root Mean Square Error Loss in a Regression Problem


The above is analogous to the mean square error loss – the mean square error is what we aim to optimize during training, while the above is how we report the performance of the model.

Mean Absolute Error Loss (MAE):


One could also do Mean Absolute Error Loss where you take the absolute differences as shown in the equation below:
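In its standard form:

$$\mathrm{MAE} = \frac{1}{N} \sum_{i=1}^{N} \left| y^{(i)} - \hat{y}^{(i)} \right|$$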



Equation: The Mean Absolute Error Loss in a Regression Problem


The scikit-learn Python library implements many different metrics; the above are two of the most common.
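For example, using scikit-learn's mean_squared_error and mean_absolute_error (a minimal sketch with made-up numbers; the RMSE is obtained here by taking the square root of the MSE):

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

y_true = np.array([0.10, -0.30, 0.05])   # e.g. measured nodal displacements
y_pred = np.array([0.12, -0.25, 0.00])   # model predictions

rmse = np.sqrt(mean_squared_error(y_true, y_pred))
mae = mean_absolute_error(y_true, y_pred)
print(f"RMSE = {rmse:.4f}, MAE = {mae:.4f}")
```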


Classification Metrics:

For classification, we have a standard way to compute the classification accuracy for multi-class classification problems. The classification accuracy is computed as:
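With 1[·] denoting the indicator function (1 when the predicted class equals the true class, 0 otherwise):

$$\mathrm{Accuracy} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\!\left[ \hat{y}^{(i)} = y^{(i)} \right]$$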



Equation: Classification Accuracy

This denotes the fraction of data points for which the predicted class matched the true class – that is, the fraction of times the class was predicted correctly.


This is a good metric when the classes are balanced. However, when we have an imbalanced class distribution, it is not right to rely on the above classification accuracy alone. Such cases call for other means of measuring performance, such as Precision, Recall and F1-Score. These have been discussed in my blog on Skewed Datasets and Error Metrics here and are therefore not included in this blog.
