
2412.14374


Scaling Deep Learning Training with MPMD Pipeline Parallelism

Authors: Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that lets users implement their own pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, each corresponding to a pipeline stage, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD (multiple-program, multiple-data) runtime for the asynchronous execution of SPMD (single-program, multiple-data) tasks. JaxPP's pipeline parallelism improves hardware utilization by up to 1.11× over the best-performing SPMD configuration.
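
To make "pipeline schedule for gradient accumulation" concrete, here is a minimal pure-Python sketch. It is not JaxPP's API. The functions gpipe_schedule and bubble_fraction are hypothetical names used only for illustration, and the sketch assumes a GPipe-style schedule (all forward passes first, then all backward passes) in which every forward or backward task takes one time step. The batch is split into microbatches, each stage runs a forward task "F" and a backward task "B" per microbatch, and idle slots (None) are the pipeline bubble. Gradients from all microbatches would be accumulated before a single optimizer step.

```python
from typing import List, Optional, Tuple

# A task is (phase, microbatch index); phase is "F" (forward) or "B" (backward).
Task = Tuple[str, int]


def gpipe_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[Task]]]:
    """Hypothetical helper: build a (time step x stage) grid for a GPipe-style
    schedule. None marks an idle slot, i.e. part of the pipeline bubble."""
    S, M = num_stages, num_microbatches
    fwd_steps = M + S - 1              # steps needed to push M microbatches through S stages
    grid: List[List[Optional[Task]]] = [[None] * S for _ in range(2 * fwd_steps)]
    for m in range(M):
        for s in range(S):
            # Forward: microbatch m reaches stage s at step m + s.
            grid[m + s][s] = ("F", m)
            # Backward starts after all forwards and flows from the last stage
            # back to the first, so stage s waits S - 1 - s steps behind stage S - 1.
            grid[fwd_steps + m + (S - 1 - s)][s] = ("B", m)
    return grid


def bubble_fraction(grid: List[List[Optional[Task]]]) -> float:
    """Fraction of (step, stage) slots in which a stage sits idle."""
    slots = sum(len(row) for row in grid)
    idle = sum(task is None for row in grid for task in row)
    return idle / slots


if __name__ == "__main__":
    schedule = gpipe_schedule(num_stages=2, num_microbatches=3)
    for step, row in enumerate(schedule):
        print(step, row)
    print("bubble fraction:", bubble_fraction(schedule))
```

With 2 stages and 3 microbatches, the grid has 16 slots, 4 of them idle, so the bubble fraction is 0.25, matching the usual GPipe estimate (S - 1) / (M + S - 1). Raising the number of microbatches shrinks the bubble, which is why gradient accumulation and pipeline schedules go hand in hand. Other schedules, such as 1F1B, interleave the same F and B tasks differently to reduce activation memory.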

Subjects: Distributed, Parallel, and Cluster Computing; Machine Learning; Programming Languages

Published: 2024-12-18 22:15:11 UTC