WorkFlow for Neural Layer Splitting
Subramaniyam Venkata Pooni
Distinguished Technologist | AI & Cloud-Native Innovator | 5G & Edge Computing Expert
A step-by-step explanation of how neural network layer splitting works across multiple machines, from high-level design to compiled code:
Workflow
1. High-Level Design: Neural Network Definition
1.1 Model Architecture:
At a high level, the neural network is typically defined in a framework like TensorFlow, PyTorch, or JAX.
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)  # sized for 8x8 spatial inputs
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
1.2 Splitting Plan:
Decide how to split the model across machines:
Layer Split: Assign conv1 to Machine A, conv2 to Machine B, etc. (see the sketch after this list).
Intra-Layer Split: Divide large operations (e.g., large feature maps) across machines.
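A minimal sketch of a layer split, reusing the SimpleCNN layers above; the StageA/StageB/StageC names are illustrative, not part of any framework API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class StageA(nn.Module):
    """Layers intended for Machine A: conv1 + ReLU."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
    def forward(self, x):
        return F.relu(self.conv1(x))

class StageB(nn.Module):
    """Layers intended for Machine B: conv2 + ReLU."""
    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
    def forward(self, x):
        return F.relu(self.conv2(x))

class StageC(nn.Module):
    """Layers intended for Machine C: fc1 + ReLU + fc2."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        return self.fc2(F.relu(self.fc1(x)))

# Run locally, the three stages compose back into the original model:
x = torch.randn(4, 3, 8, 8)            # batch of 4 images, 8x8 to match fc1 sizing
out = StageC()(StageB()(StageA()(x)))
print(out.shape)                        # torch.Size([4, 10])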
2. Intermediate Representation (IR) of the Model
2.1 Building the Computational Graph:
The framework converts the model into a computational graph.
Example: conv1 → ReLU → conv2 → ReLU → fc1 → ReLU → fc2
Nodes represent operations, and edges represent data flow (tensors).
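In PyTorch, one way to inspect this graph is torch.fx; a minimal sketch, assuming the SimpleCNN class from section 1.1 is in scope:

import torch.fx

model = SimpleCNN()
traced = torch.fx.symbolic_trace(model)  # build the computational graph

# The graph lists one node per operation: the input placeholder, the conv/linear
# modules, the relu and view calls, and the output.
print(traced.graph)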
2.2 Graph Partitioning:
Horizontal Partitioning: Large tensors are split into chunks for parallel processing.
Vertical Partitioning: Layers are assigned to different devices (a device-placement sketch follows the example).
Example:
Machine A: conv1 → ReLU
Machine B: conv2 → ReLU
Machine C: fc1 → ReLU → fc2
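As a single-process analogue of this assignment (assuming the illustrative StageA/StageB/StageC modules from section 1.2 and at least two CUDA devices), the stages can be pinned to devices and activations moved explicitly at the boundaries:

import torch

stage_a = StageA().to("cuda:0")   # plays the role of Machine A
stage_b = StageB().to("cuda:1")   # plays the role of Machine B
stage_c = StageC().to("cuda:1")   # plays the role of Machine C (sharing GPU 1 here)

x = torch.randn(4, 3, 8, 8, device="cuda:0")
h1 = stage_a(x)                   # conv1 -> ReLU on cuda:0
h2 = stage_b(h1.to("cuda:1"))     # explicit tensor transfer, then conv2 -> ReLU
out = stage_c(h2)                 # fc1 -> ReLU -> fc2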
2.3 IR Generation:
The graph is transformed into Intermediate Representation (IR) for further optimization.
Example:
%conv1 = Conv2D(%input, %weights1)
%relu1 = Relu(%conv1)
%conv2 = Conv2D(%relu1, %weights2)
3. Optimization and Compilation
3.1 Graph Optimization:
Operator Fusion: Combine operations like Conv2D and ReLU into a single kernel.
Memory Optimization: Minimize tensor storage by reusing memory.
Example:
FusedOp = Conv2D + ReLU
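Conceptually, fusion replaces two graph nodes with a single kernel launch. A minimal sketch of the idea (FusedConvReLU is an illustrative name, not a framework API; compilers such as torch.compile or XLA perform this fusion automatically):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedConvReLU(nn.Module):
    """Illustrative fused op: convolution and ReLU treated as one logical kernel."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
    def forward(self, x):
        # A real compiler emits a single fused kernel; here the fusion is only logical.
        return F.relu(self.conv(x))

y = FusedConvReLU(3, 16)(torch.randn(1, 3, 8, 8))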
3.2 Partitioning for Devices:
The IR is split into subgraphs, each assigned to a specific machine or hardware.
Communication operations (Send/Receive) are added where machines exchange data (a torch.distributed sketch follows the example).
Machine A:
%output1 = Conv2D(%input, %weights1)
Send(%output1)
Machine B:
%input2 = Receive()
%output2 = Conv2D(%input2, %weights2)
Send(%output2)
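The same Send/Receive pattern can be sketched with torch.distributed point-to-point calls (assuming two processes launched by a tool such as torchrun, an already-initialized process group, and the illustrative StageA/StageB modules from section 1.2):

import torch
import torch.distributed as dist

def pipeline_step(rank):
    # Assumes dist.init_process_group() has already been called for this process.
    if rank == 0:                               # plays the role of Machine A
        out1 = StageA()(torch.randn(4, 3, 8, 8))
        dist.send(out1, dst=1)                  # ship the activation to Machine B
    elif rank == 1:                             # plays the role of Machine B
        buf = torch.empty(4, 16, 8, 8)          # pre-allocated buffer matching conv1's output
        dist.recv(buf, src=0)
        out2 = StageB()(buf)
        # ...send out2 onward to the next stage in the same way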
3.3 Backend Compilation:
Each subgraph is compiled into device-specific code:
CUDA Kernels: For GPUs.
LLVM IR: For CPUs.
Example CUDA kernel:
__global__ void conv2d_kernel(float* input, float* weights, float* output) {
    // Perform convolution
}
4. Execution Plan and Scheduling
4.1 Scheduling:
Execution of subgraphs is coordinated to respect data dependencies.
Example:
Machine A computes conv1 and sends the result to Machine B.
Machine B waits for Machine A’s output, computes conv2, and sends it to Machine C.
4.2 Data Parallelism and Overlap:
For efficiency, computation and communication can overlap:
While Machine A processes batch 2, Machine B processes batch 1.
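A toy sketch of that staggering on a single process, reusing the illustrative StageA/StageB/StageC modules from section 1.2 (micro-batches move through the stages with a one-step lag; real overlap requires separate machines or CUDA streams):

import torch

stage_a, stage_b, stage_c = StageA(), StageB(), StageC()
micro_batches = torch.randn(8, 3, 8, 8).chunk(4)   # 4 micro-batches of 2 samples

in_flight = []   # activations queued between "Machine A" and the later stages
outputs = []
for mb in micro_batches:
    in_flight.append(stage_a(mb))                   # Machine A works on the newest micro-batch
    if len(in_flight) > 1:                          # Machines B/C lag one step behind
        outputs.append(stage_c(stage_b(in_flight.pop(0))))
while in_flight:                                    # drain the pipeline
    outputs.append(stage_c(stage_b(in_flight.pop(0))))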
5. Runtime Execution
5.1 Runtime Environment:
A framework-specific runtime handles execution:
TensorFlow uses the TF Runtime and XLA Compiler.
PyTorch uses its autograd engine and torch.distributed.
5.2 Communication Between Machines:
Machines exchange intermediate tensors via:
NCCL: For GPU-to-GPU communication.
gRPC/MPI: For machine-to-machine communication.
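A minimal setup sketch for choosing the backend in PyTorch (assuming a launcher such as torchrun provides MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the environment):

import torch
import torch.distributed as dist

# NCCL for GPU-to-GPU tensor exchange; Gloo (or MPI, if PyTorch was built with it)
# for CPU and generic machine-to-machine communication.
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)   # rank/world size come from the launcher's env vars

print(f"rank {dist.get_rank()} of {dist.get_world_size()} using backend {backend}")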
5.3 Monitoring and Profiling:
Tools such as TensorBoard, NVIDIA Nsight, or the PyTorch Profiler monitor execution to debug and optimize the distributed pipeline.
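For example, a short torch.profiler sketch (assuming the SimpleCNN model from section 1.1 is in scope):

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = SimpleCNN()
x = torch.randn(4, 3, 8, 8)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    with record_function("forward_pass"):
        model(x)

# Per-operator timings help locate layers worth splitting, fusing, or overlapping.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))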