Scaling Giant Models with Google’s GShard

I recently came across an interesting paper from Google (GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding), which presents their work on scaling giant language translation models (up to 600B parameters, trained on 2048 TPU v3 cores).

I liked this paper because it not only describes the system and model innovations for distributed training, but also discusses how to make it easy for users to develop distributed model training programs. I believe this is a critical but often overlooked problem; existing approaches usually require users to write different programs for the local and distributed training of the same model. On the other hand, my experience with scaling AI for distributed Big Data suggests that data scientists prefer not dealing with (or even thinking about) distributed computing; our projects (BigDL and Analytics Zoo) have adopted a simple but effective semantics, which always takes a single-node program and scales it out in a data-parallel fashion.

In contrast, GShard provides a more flexible semantics through annotation: the user writes a single-node version of the model and then adds annotations (on a subset of critical tensors) to specify the parallel execution policy. The underlying compiler (based on XLA) can then automatically partition the computation graph and add the appropriate communication operations, so as to train the original model in a distributed fashion.

GShard annotations for parallel execution include the following (a short sketch of both styles follows the list):

  • Replication: replicating the tensor across different partitions. This can be used to replicate the weight for data-parallel training.
  • Sharding (or splitting): partitioning the tensor into different shards placed on different devices. This can support some form of model-parallelism (i.e., partitioning a node or layer in the computation graph across different devices).
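
To make the two annotation styles concrete, here is a minimal sketch of annotation-driven partitioning. It uses JAX's GSPMD-based sharding API (a descendant of the same XLA compiler work) rather than GShard's own TensorFlow API, so the names below (Mesh, NamedSharding, PartitionSpec) and the mesh axis name "data" are my illustrative choices, not the paper's:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A one-dimensional device mesh whose single axis is named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

x = jnp.ones((16, 512))    # a batch of activations
w = jnp.ones((512, 512))   # a layer weight

# Sharding annotation: split the batch dimension of x across the "data" axis.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
# Replication annotation: keep a full copy of w on every device
# (the data-parallel case in the first bullet above).
w = jax.device_put(w, NamedSharding(mesh, P()))

@jax.jit
def layer(x, w):
    # The body is ordinary single-device code; the compiler partitions the
    # computation and inserts communication based on the annotations on x and w.
    return jnp.tanh(x @ w)

y = layer(x, w)
print(y.sharding)  # the output's sharding is chosen by the compiler
```

Changing a single annotation (e.g., P(None, "data") on w) would shard the weight instead of replicating it, which corresponds to the model-parallel case in the second bullet.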

The flexibility of GShard annotations can be very useful for scenarios that are challenging in the simple data-parallel paradigm, such as extremely large models (which often hit bottlenecks in memory capacity and/or network bandwidth) and complex evaluation metrics (which often require manually rewriting the logic in a data-parallel fashion). On the other hand, it does sometimes require users to understand the partitioning details (e.g., which tensor to shard and along which dimension to partition it).

In addition, the paper describes innovations in neural network architecture design through conditional computation (i.e., activating a sub-network for each input). In the paper, the feed-forward layer of the Transformer is replaced by a Mixture-of-Experts layer in the giant language translation model; as a result, the compiler can efficiently partition this layer across multiple devices and achieve sub-linear scaling (i.e., computation demand grows more slowly than the number of parameters) on the sparsely scaled Transformer-based translation model.
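
As a rough illustration of the conditional-computation idea, here is a toy, single-device Mixture-of-Experts feed-forward layer with top-1 gating. This is only a sketch: the paper's MoE layer additionally uses top-2 gating, expert capacity limits, an auxiliary load-balancing loss, and sharding of the expert dimension across devices; all shapes and parameter names below are made up for the example.

```python
import jax
import jax.numpy as jnp

def moe_layer(params, x):
    # x: [tokens, d_model]
    # params["gate"]:  [d_model, num_experts]
    # params["w_in"]:  [num_experts, d_model, d_ff]
    # params["w_out"]: [num_experts, d_ff, d_model]

    # Gating: each token picks its single best expert (toy top-1 routing).
    logits = x @ params["gate"]                               # [tokens, experts]
    gates = jax.nn.softmax(logits, axis=-1)
    expert_idx = jnp.argmax(gates, axis=-1)                   # [tokens]
    dispatch = jax.nn.one_hot(expert_idx, logits.shape[-1])   # [tokens, experts]

    # Dispatch tokens to experts, run each expert's feed-forward network,
    # then combine the outputs weighted by the gate value.
    x_e = jnp.einsum("te,td->etd", dispatch, x)               # [experts, tokens, d_model]
    h = jax.nn.relu(jnp.einsum("etd,edf->etf", x_e, params["w_in"]))
    y_e = jnp.einsum("etf,efd->etd", h, params["w_out"])      # [experts, tokens, d_model]
    combine = dispatch * jnp.max(gates, axis=-1, keepdims=True)
    return jnp.einsum("te,etd->td", combine, y_e)             # [tokens, d_model]
```

In GShard, it is the leading expert dimension (the "e" axis above) that gets sharded across devices, so each device stores and runs only a few experts and tokens are routed to the device holding their expert, while the rest of the Transformer stays data-parallel.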

The paper also shows how to spatially shard a convolutional layer, which, however, involves much more complex transformations. And while it is not discussed in the paper, it would also be useful to understand how effective sharding is for models that lack this sparse structure (such as conventional Transformer or CNN models).
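
On the convolution point, here is a hedged sketch of what the spatial-sharding annotation looks like (again in JAX's GSPMD-style API rather than GShard's own): the image height dimension is split across devices, and it becomes the compiler's job to handle the extra work at slice boundaries (such as exchanging halo rows) that a convolution window spanning two slices requires. The mesh axis name "spatial" and all shapes are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("spatial",))

images = jnp.ones((8, 128, 128, 3))   # NHWC input
kernel = jnp.ones((3, 3, 3, 16))      # HWIO filter

# Spatial sharding annotation: split the height dimension across devices.
images = jax.device_put(images, NamedSharding(mesh, P(None, "spatial", None, None)))

@jax.jit
def conv(x, k):
    return jax.lax.conv_general_dilated(
        x, k, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

out = conv(images, kernel)
```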

If you are interested in distributed deep learning training, designing giant neural nets, or both, this paper will be a very good read.

Gertjan "GJ" De Wilde

Building the #1 Realtime Unified API

4 年

Thanks! Interesting read.

回复

要查看或添加评论,请登录

Jason (Jinquan) Dai的更多文章

社区洞察

其他会员也浏览了