Distributed TensorFlow: Scaling Google’s Deep Learning Library on Spark

Written by Christopher Smith, Ushnish De and Christopher Nguyen


The charter of Arimo’s growing data science team includes researching and developing new methods and applications for machine learning and deep learning.

One theme we’re investigating is distributed deep learning. The value and variety of the patterns and predictions that deep learning can find compound when datasets and models are very large. However, training large models can be slow or difficult if the data or model does not fit in one machine’s memory. When Google open sourced their TensorFlow deep learning library, we were excited to see whether we could run TensorFlow in the distributed Spark environment.

This post is an overview of Arimo’s presentation “Distributed TensorFlow on Spark: Scaling Google’s Deep Learning Library,” presented at Spark Summit East 2016.

Watch for future posts in which we’ll investigate Google’s newly released distributed version of TensorFlow.


Late in 2015, Google open sourced their deep learning library TensorFlow to much fanfare. At Google, TensorFlow is used in a variety of production applications, from search to maps to translation, so the library has been extensively tested at scale. A Google whitepaper describes the various systems considerations that went into TensorFlow’s design. The open source library contains only the single-machine implementation, possibly because the full distributed version depends on Google’s infrastructure. (Note: Google has since open-sourced their distributed code.)

To investigate the capabilities of a distributed version of TensorFlow, we adapted the single machine version of TensorFlow to Apache Spark. Apache Spark brings in-memory computing and scalable distributed computing. Our goal is to answer: “Can TensorFlow and Apache Spark work together to deliver a powerful, scalable distributed deep-learning platform?”


So how did we do it?

TensorFlow has a Python API for building computational graphs, which are then executed in C++. Spark is written in Scala and also offers Java and Python APIs, so it was natural to choose Python as the implementation language.

The essential question for TensorFlow on Spark is how to distribute the training of neural networks on Spark. Spark is excellent for iterated map-reduce problems, but training neural nets is not a map-reduce problem (see the “13 dwarfs” of parallel computing). Taking a cue from Google, we imitated their DownpourSGD architecture. DownpourSGD is a data-parallel setup: each worker holds the entire model and operates on different data from the other workers, as opposed to a model-parallel setup, in which different parts of the model live on different machines.
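A data-parallel setup can be sketched in a few lines of plain Python. This is a toy illustration, not the tensorspark code: the linear model y = w*x + b, the round-robin shard helper, and the other function names are hypothetical. Each simulated worker holds the full model and computes a gradient on its own shard of the rows; for equal-sized shards, averaging the shard gradients reproduces the full-batch gradient.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # one (x, y) training pair


def mse_gradient(w: float, b: float, data: List[Example]) -> Tuple[float, float]:
    """Gradient of the mean squared error of y = w*x + b with respect to (w, b)."""
    n = len(data)
    gw = sum(2 * (w * x + b - y) * x for x, y in data) / n
    gb = sum(2 * (w * x + b - y) for x, y in data) / n
    return gw, gb


def shard(data: List[Example], num_workers: int) -> List[List[Example]]:
    """Round-robin split of the rows, standing in for Spark's partitioning."""
    return [data[i::num_workers] for i in range(num_workers)]


def data_parallel_gradient(w: float, b: float, data: List[Example],
                           num_workers: int) -> Tuple[float, float]:
    """Every worker sees the whole model but only its own shard of the data."""
    grads = [mse_gradient(w, b, part) for part in shard(data, num_workers)]
    k = len(grads)
    return sum(g[0] for g in grads) / k, sum(g[1] for g in grads) / k


data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
print(mse_gradient(0.0, 0.0, data))               # full-batch gradient
print(data_parallel_gradient(0.0, 0.0, data, 4))  # same value from 4 shards
```

With uneven shards, a plain average weights rows unequally; weighting each shard's gradient by its row count restores the equivalence.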


Figure 1: On the left is the diagram from Google’s DistBelief paper;
on the right is a schematic of our implementation

Essentially, we take gradient descent, split it into two steps, “Compute Gradients” followed by “Apply Gradients (Descent)”, and insert a network boundary between them.
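The split can be sketched as two functions with a serialization step standing in for the network boundary. The function names, the toy linear model, and the use of pickle as the wire format are assumptions for illustration; in our setup the gradient step runs on the Spark workers and the update runs on the driver.

```python
import pickle
from typing import List, Tuple

Params = Tuple[float, float]  # (w, b) of the toy model y = w*x + b


def compute_gradients(params: Params, batch: List[Tuple[float, float]]) -> bytes:
    """Worker side: MSE gradient on a minibatch, serialized for the network."""
    w, b = params
    n = len(batch)
    gw = sum(2 * (w * x + b - y) * x for x, y in batch) / n
    gb = sum(2 * (w * x + b - y) for x, y in batch) / n
    return pickle.dumps((gw, gb))  # these bytes are what crosses the boundary


def apply_gradients(params: Params, payload: bytes, lr: float) -> Params:
    """Driver side: deserialize the gradient and take one descent step."""
    gw, gb = pickle.loads(payload)
    w, b = params
    return (w - lr * gw, b - lr * gb)


batch = [(x / 10, 3.0 * x / 10 + 1.0) for x in range(10)]
params: Params = (0.0, 0.0)
for _ in range(2000):
    params = apply_gradients(params, compute_gradients(params, batch), lr=0.1)
print(params)  # approaches (3.0, 1.0)
```

Run back to back like this, the two halves are ordinary gradient descent; the distributed behavior comes from many workers calling compute_gradients concurrently.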


Figure 2: Schematic of the role played by driver and workers.
The driver and workers send and receive data via websockets.

The Spark workers compute gradients asynchronously and periodically send them back to the driver (the parameter server), which combines all the workers’ gradients and sends the resulting parameters back to the workers whenever they ask for them.
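A minimal asynchronous parameter server can be sketched with threads in a single Python process standing in for Spark workers and websockets. The class and method names are hypothetical, and "combining" gradients is modeled in the simplest way: the driver applies each gradient the moment it arrives, with no barrier between workers.

```python
import threading
from typing import List, Tuple

Batch = List[Tuple[float, float]]


class ParameterServer:
    """Driver side: owns the parameters and applies each gradient as it arrives."""

    def __init__(self, num_params: int, lr: float) -> None:
        self.params = [0.0] * num_params
        self.lr = lr
        self.lock = threading.Lock()

    def pull(self) -> List[float]:
        """Workers ask for the latest parameters whenever they want them."""
        with self.lock:
            return list(self.params)

    def push(self, grad: List[float]) -> None:
        """Apply one worker's gradient immediately, without waiting for others."""
        with self.lock:
            self.params = [p - self.lr * g for p, g in zip(self.params, grad)]


def gradient(params: List[float], batch: Batch) -> List[float]:
    """MSE gradient of the toy model y = w*x + b."""
    w, b = params
    n = len(batch)
    return [sum(2 * (w * x + b - y) * x for x, y in batch) / n,
            sum(2 * (w * x + b - y) for x, y in batch) / n]


def worker(ps: ParameterServer, shard: Batch, steps: int) -> None:
    for _ in range(steps):
        local = ps.pull()                # may already be stale when used
        ps.push(gradient(local, shard))  # no barrier: workers run out of step


data = [(x / 20, 3.0 * x / 20 + 1.0) for x in range(20)]
ps = ParameterServer(num_params=2, lr=0.1)
threads = [threading.Thread(target=worker, args=(ps, data[i::4], 1000))
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(ps.pull())  # close to [3.0, 1.0]
```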

The name “DownpourSGD” comes from the intuition that if we view gradient descent as a water droplet flowing down the error-hill, then the asynchronicity of the workers suggests several water droplets near each other all flowing downhill into the same valley along different paths.


Experimental Setup

We tested the scaling capabilities of the implementation across the following dimensions:

  1. Dataset Size
  2. Model Size
  3. Compute Size
  4. CPU vs GPU

Dataset Size

Dataset Name   Rows      Columns   Cells    Description
MNIST          60,000    784       >39.2M   60K 28×28 images
Molecular      150,000   2,871     430M     https://www.kaggle.com/c/MerckActivity

Notice that of the two larger datasets, the Higgs dataset is much larger in number of rows, while the Molecular set is larger in total number of data points (rows × columns). This leads to interesting results during model training.

Model Size

We used feed-forward networks for each of our models, mainly because this makes model size, i.e., the total number of weight and bias variables, simple to compute.

Dataset   Hidden Layers   Hidden Units   Total Parameters
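The total parameter count of a fully connected network is the sum of each layer's weights and biases. A small helper makes the count explicit; the layer sizes in the example (784 inputs, two hidden layers of 100 units, 10 outputs) are hypothetical and are not the configurations from our experiments.

```python
from typing import Sequence


def count_parameters(layer_sizes: Sequence[int]) -> int:
    """Weights plus biases of a fully connected feed-forward network.

    layer_sizes = [inputs, hidden_1, ..., hidden_k, outputs]; each layer
    contributes n_in * n_out weights and n_out biases.
    """
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


print(count_parameters([784, 100, 100, 10]))  # 78,500 + 10,100 + 1,010 = 89,610
```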

Compute size

We had a single-machine configuration with 12 CPU cores, as well as an AWS cluster of 24 nodes with 4 CPU cores each and no GPUs. We used Docker containers to ensure all the nodes had the same configuration. Each node was set up as follows:

  1. 4 cores per machine
  2. 1 executor per machine
  3. 10GB memory per executor
  4. Spark 1.6.0


Lastly, we had a cluster of 8 nodes with the same configuration as above, plus 1 GPU each. The 12-core single-machine configuration also had 4 GPUs that could be turned on as required; however, by default TensorFlow uses only one GPU.


Data parallelism has benefits, but network communication overhead quickly limited scalability.

Scaling the cluster

We demonstrated that TensorFlow on Spark can leverage a larger cluster, but increasing cluster size improves performance only up to a point. Network communication appears to become the bottleneck, and more work is needed to determine whether it can be overcome.

On a side note, we found that as long as the data is small enough to fit on one machine, a single machine with a GPU outperforms the other architectures, because it has no network overhead and exploits the GPU’s advantage in mathematical operations.

The following figures summarize key test results from the various models we ran. A few key trends are visible:


KEY TAKEAWAY: Synchronicity of workers matters: making the workers asynchronous helps, but if they get too far out of sync, it can slow or even prevent convergence.


KEY TAKEAWAY: Increased model size reduces gains made from data parallelism.


KEY TAKEAWAY: Network communication varies directly with model size. This gives a predictably proportional slowdown in training speed as model size increases.
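Because each exchange moves a full gradient to the driver and a full set of parameters back, the bytes on the wire grow linearly with the parameter count. A back-of-the-envelope helper, assuming dense 32-bit floats and ignoring serialization and protocol overhead:

```python
def bytes_per_exchange(num_params: int, bytes_per_value: int = 4) -> int:
    """Gradient up to the driver plus parameters back down, one worker, one step.

    Assumes dense float32 values; pickling and websocket framing are ignored.
    """
    return 2 * num_params * bytes_per_value


for num_params in (100_000, 1_000_000, 10_000_000):
    mb = bytes_per_exchange(num_params) / 1e6
    print(f"{num_params:>10,} params -> {mb:,.1f} MB per exchange")
```

Since every worker exchanges this payload with the single driver, the driver's traffic per round also grows with the number of workers.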


KEY TAKEAWAY: The system is compute bound until more than ~16 workers (64 cores here) are introduced, at which point it becomes network bound.

Overall, our implementation scaled only to a limited degree. We’re investigating methods to mitigate these bottlenecks so that training performance can scale with parallelism. This remains an interesting research direction because of the advantages of using commercial-grade Spark technology to distribute TensorFlow.

GPU Performance Testing


Figure 4: Training speed (rows/min) across architectures for each dataset

In Figure 4, we compare the speed of the various implementations, including our local and distributed GPU implementations. There are essentially two competing forces here: the local implementations have no data parallelism, but also no network overhead. Moving to the right, where the datasets have more rows, local GPU beats local CPU, and the cluster with GPUs eventually beats the cluster without them. This corroborates what we already know about GPUs speeding up mathematical operations. However, the local GPU implementation greatly eclipses all the others as the datasets grow. Since all the datasets are small enough to fit in memory, the network becomes the bottleneck for the distributed implementations, while local GPU has no such bottleneck and gets the best of both worlds.


Our research goal is to develop a scalable, distributed computing implementation of deep learning. The current project demonstrated that there are interesting possibilities with TensorFlow on Apache Spark, but we have not yet achieved our goal.

  • This implementation is best for large datasets and small models, since network overhead grows linearly with model size. Large datasets let us take full advantage of the system’s data parallelism.
  • For data that fits on a single machine, a single-machine GPU offers the best speed.
  • For large datasets, leverage both speed and scale by using a Spark cluster with GPUs.
  • This project is open sourced: https://github.com/adatao/tensorspark

Future Work

  • For larger models, we’re looking into model parallelism and model compression to improve performance.
  • We are looking into more clever ways of computing the parameter update step, e.g., using coordinate descent or conjugate gradient.
  • An important way to scale this architecture out more effectively is to use multiple parameter servers that split up the model, as described in both the TensorFlow and DistBelief white papers. In our current implementation, the single driver can easily become a bottleneck.


On this project we encountered many challenges that we overcame by research, trial and error and sheer persistence. We’re sharing what we found so you can recreate our work if you’d like.

GPU is faster, but it’s not a panacea

A GPU can easily deliver almost a 10x speedup in the local configuration, while GPUs on Spark deliver only a 2-3x speedup because of network overhead. The Spark configuration therefore starts to make sense once the data no longer fits in one machine’s memory.

However, a GPU has its own memory, which can become a bottleneck if matrices are too large to fit in the typically limited GPU memory. If your model doesn’t fit in GPU memory, model parallelism may make more sense than data parallelism.

We used AWS GPU instances (NVIDIA K680 cards with added memory: 1,536 CUDA cores, 4GB memory). These GPUs have CUDA compute capability 3.0, which TensorFlow does not support out of the box. To support the AWS GPUs, build TensorFlow from source and set the TF_UNOFFICIAL_SETTING flag in the configure step. Be warned that on some systems this build significantly increases the amount of virtual memory TensorFlow demands. We discovered this when we set --executor-cores=8 and got “out of virtual memory” errors. [https://github.com/adatao/tensorspark/blob/master/gpu_install.sh#L17]


  • Since TensorFlow only has Python and C APIs, we had to use PySpark to make the implementation work with Spark. This means objects are converted between Scala, Java, Python, and C++, which is not optimal for performance.
  • TensorFlow objects cannot be pickled, so we recommend serializing them as numpy arrays.
  • Parsing large json is slow, so avoid it.
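To illustrate the binary-versus-text point with only the standard library, the sketch below packs gradient values as raw float32 bytes with the array module (roughly what NumPy's ndarray.tobytes() and numpy.frombuffer() do) and compares the payload size with JSON. The random vector is a hypothetical stand-in for a real gradient, and float32 packing gives up precision relative to Python's 64-bit floats.

```python
import array
import json
import random
from typing import List


def to_wire(values: List[float]) -> bytes:
    """Pack floats as contiguous float32 bytes: no text formatting or parsing."""
    return array.array("f", values).tobytes()


def from_wire(payload: bytes) -> List[float]:
    """Unpack float32 bytes back into a list of Python floats."""
    arr = array.array("f")
    arr.frombytes(payload)
    return arr.tolist()


random.seed(0)
grad = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

binary = to_wire(grad)
text = json.dumps(grad)
print(len(binary), "bytes as float32;", len(text), "characters as JSON")
restored = from_wire(binary)
print(max(abs(a - b) for a, b in zip(grad, restored)))  # float32 rounding only
```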

Error Divergence

Our first implementation had each worker create a new TF model with random weights and biases and then start communicating with the driver. This quickly revealed a problem.


Figure 5: Mean Squared Error on test set vs number of training samples seen

We can see how test error blows up after only 14,000 samples. Since the error surface has multiple local optima, having each worker start training from a different starting point may lead to the parameters never converging and errors building up. Going back to our DownpourSGD analogy, think of rain falling on a mountain: some of the drops go down one side and others go down another (see Figure 6), so there is no convergence on the correct set of parameters.


Figure 6: Visualization of cusp between distinct local optima

We found we could mitigate this effect with a warm-up phase, wherein the driver trains on some of the data alone, effectively committing to one side of the mountain. The result is a decreased probability of diverging because of a mountain cusp, as there are finitely many such cusps and warm-up eliminates some of them. [https://en.wikipedia.org/wiki/Morse_theory]
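In the hypothetical sketch below, warm-up amounts to changing who initializes the parameters: the driver draws one initialization, trains on a slice of the data by itself, and hands every worker a copy of the result, instead of each worker drawing its own random start. The linear toy model has a convex error surface, so it only illustrates the control flow, not the cusp problem itself; all names are illustrative.

```python
import random
from typing import List, Tuple

Batch = List[Tuple[float, float]]


def gradient(params: List[float], batch: Batch) -> List[float]:
    """MSE gradient of the toy model y = w*x + b."""
    w, b = params
    n = len(batch)
    return [sum(2 * (w * x + b - y) * x for x, y in batch) / n,
            sum(2 * (w * x + b - y) for x, y in batch) / n]


def sgd(params: List[float], batch: Batch, lr: float, steps: int) -> List[float]:
    for _ in range(steps):
        params = [p - lr * g for p, g in zip(params, gradient(params, batch))]
    return params


def cold_starts(num_workers: int, rng: random.Random) -> List[List[float]]:
    """First implementation: every worker picks its own random starting point."""
    return [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(num_workers)]


def warm_starts(num_workers: int, warmup_batch: Batch, warmup_steps: int,
                rng: random.Random, lr: float = 0.1) -> List[List[float]]:
    """Warm-up: the driver trains alone first, then all workers share its result."""
    driver = sgd([rng.gauss(0, 1), rng.gauss(0, 1)], warmup_batch, lr, warmup_steps)
    return [list(driver) for _ in range(num_workers)]


rng = random.Random(42)
data = [(x / 10, 3.0 * x / 10 + 1.0) for x in range(10)]
print(cold_starts(4, rng))                 # four different starting points
print(warm_starts(4, data[:5], 500, rng))  # one shared, already-trained start
```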

Gradient Atavism


Even without mountains to make the gradients diverge, divergence is still possible because the gradients can get too far out of sync due to high curvature in the error surface. As an analogy, imagine a fast-paced meeting in which one person is always 10 minutes behind (or ahead of!) everyone else. This person keeps bringing up topics that everyone else thinks are irrelevant, and it throws everyone off for a while. Similarly, when individual workers return obsolete gradients to the driver, they can induce temporary spikes in the error and can even create rather beautiful oscillations. Looking at just the peaks or troughs of these oscillations sometimes gives a very nice exponential curve, suggesting that some kind of equation governs this behavior.

Another way to understand this phenomenon is to remember that gradients are vectors lying in the tangent plane to the error surface, and this tangent plane is a good approximation of the error surface only in a small neighborhood around the point where it touches, so taking too large a step produces unexpected results.
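A one-dimensional toy problem shows how stale gradients alone can produce this wobble. The sketch minimizes f(θ) = θ²/2, whose gradient is θ, with learning rate 0.5, but computes each update from the parameter value `staleness` steps in the past, as a lagging worker would. The function name and settings are illustrative assumptions, not values from our experiments.

```python
from collections import deque
from typing import List


def delayed_gd(lr: float, staleness: int, steps: int, theta0: float = 1.0) -> List[float]:
    """Gradient descent on f(theta) = theta**2 / 2 using a gradient that is
    `staleness` updates old. Returns the trajectory theta_0, theta_1, ..."""
    history = deque([theta0] * (staleness + 1), maxlen=staleness + 1)
    trajectory = [theta0]
    for _ in range(steps):
        stale_theta = history[0]                # what the slow worker saw
        theta = history[-1] - lr * stale_theta  # f'(theta) = theta
        history.append(theta)                   # oldest entry drops off
        trajectory.append(theta)
    return trajectory


for staleness in (0, 1, 3):
    traj = delayed_gd(lr=0.5, staleness=staleness, steps=200)
    print(staleness, min(traj), abs(traj[-1]))
```

With no staleness the iterate decays monotonically; with a lag of one step it overshoots the minimum and oscillates with shrinking amplitude (the underdamped case); with a lag of three steps the same learning rate makes the oscillation grow. For this linear recurrence, the standard stability condition is lr < 2cos(τπ/(2τ+1)) for a lag of τ steps, which is 1 for τ = 1 and about 0.445 for τ = 3, so lr = 0.5 falls between the two.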


Figure 7: Example of overdamped oscillating convergence


Figure 8: Example of underdamped oscillating convergence.
This is from actual data, but is well approximated by $y = (\sin x + 1)\,e^{x/20} + 1$.

So, you can reduce wobble by:

  • Reducing network overhead (by compressing or directly sending binary over network)
  • Reducing minibatch size (so each gradient is computed faster)
  • Reducing learning rate
  • Dropping gradients that take too long to transmit over the network
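The last item can be sketched as a driver that refuses late gradients. As an assumption, the sketch measures lateness in parameter versions (how many updates were applied since the worker pulled) rather than in seconds, which avoids comparing clocks on different machines; the class name and the max_staleness threshold are hypothetical.

```python
from typing import List, Tuple


class StalenessAwareServer:
    """Parameter server that drops gradients computed against parameters that
    are more than `max_staleness` updates old."""

    def __init__(self, params: List[float], lr: float, max_staleness: int) -> None:
        self.params = list(params)
        self.lr = lr
        self.max_staleness = max_staleness
        self.version = 0  # number of updates applied so far
        self.dropped = 0

    def pull(self) -> Tuple[List[float], int]:
        """Return a copy of the parameters and the version they correspond to."""
        return list(self.params), self.version

    def push(self, grad: List[float], version: int) -> bool:
        """Apply the gradient unless it is too stale; report whether it was used."""
        if self.version - version > self.max_staleness:
            self.dropped += 1
            return False
        self.params = [p - self.lr * g for p, g in zip(self.params, grad)]
        self.version += 1
        return True


# Toy objective f(theta) = theta**2 / 2, so the gradient equals theta.
server = StalenessAwareServer([1.0], lr=0.5, max_staleness=2)
slow_params, slow_version = server.pull()  # a slow worker pulls at version 0
for _ in range(3):                         # three fast updates land first
    params, version = server.pull()
    server.push([params[0]], version)
print(server.push([slow_params[0]], slow_version))  # False: 3 updates behind
print(server.params, server.dropped)                # [0.125] 1
```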


Lastly, we chose to have one executor and one “executor-core” per machine so that only one TensorFlow instance ran on each machine. The reason is that TensorFlow generates many threads. While you can tune the number of threads TensorFlow uses for intra- and inter-operation parallelism, TensorFlow has a fixed overhead of 10 additional threads, so in practice we found that multiple TensorFlow instances on one machine can cause a lot of thrashing. In the end, running one instance per machine was the most efficient.

Other Notes

For web sockets we used Tornado, which is easy to use and efficient.

We explored compressing the parameters and gradients before sending them over web sockets. While this reduces network overhead, it slows down the overall training process because compression itself takes time. We therefore ended up simply serializing each array to a string with the array’s dumps() method (a pickle of the array) before sending it over the web socket, and using pickle’s loads() to turn it back into an array on the other side.
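The compression trade-off is easy to measure on your own payloads. This sketch uses the standard library's zlib and array modules as stand-ins for whatever codec and serializer you use, and the Gaussian vector is a hypothetical gradient. Dense floating-point values tend to compress poorly because their low-order mantissa bits look random, so the saved bytes may not pay for the encoding time.

```python
import array
import random
import time
import zlib
from typing import List


def encode(values: List[float], compress: bool) -> bytes:
    raw = array.array("f", values).tobytes()  # float32 binary payload
    return zlib.compress(raw) if compress else raw


def decode(payload: bytes, compress: bool) -> List[float]:
    raw = zlib.decompress(payload) if compress else payload
    arr = array.array("f")
    arr.frombytes(raw)
    return arr.tolist()


random.seed(0)
grad = [random.gauss(0.0, 0.01) for _ in range(200_000)]

for compress in (False, True):
    start = time.perf_counter()
    payload = encode(grad, compress)
    elapsed_ms = (time.perf_counter() - start) * 1e3
    print(f"compress={compress}: {len(payload):,} bytes, {elapsed_ms:.1f} ms to encode")
```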

We noticed that by default each Spark worker creates a new instance of the TensorFlow model being used by the system, which causes memory leaks because memory is allocated to these models faster than it can be garbage collected. We therefore implemented the Borg design pattern, which allows all the workers to share the underlying TensorFlow model.
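A minimal sketch of the Borg (shared-state) pattern in Python follows; it is the textbook form of the pattern, not necessarily our exact code. The class name SharedModel and the build_model factory are hypothetical, and a dictionary stands in for a TensorFlow graph and session. The sharing is per Python process, so it helps when PySpark reuses its Python worker processes across tasks (the spark.python.worker.reuse setting, enabled by default).

```python
from typing import Any, Callable, Dict, Optional


class SharedModel:
    """Borg pattern: all instances share one attribute dictionary, so a model
    built through any instance is visible through every other instance."""

    _shared_state: Dict[str, Any] = {}

    def __init__(self, build_model: Optional[Callable[[], Any]] = None) -> None:
        self.__dict__ = self._shared_state
        if "model" not in self.__dict__ and build_model is not None:
            self.model = build_model()  # built once per Python process


builds = []


def build_model() -> Dict[str, Any]:
    """Hypothetical stand-in for constructing a TensorFlow graph and session."""
    builds.append(1)
    return {"weights": [0.0] * 10}


a = SharedModel(build_model)
b = SharedModel(build_model)
print(a.model is b.model, len(builds))  # True 1
```

As written this is not thread-safe; if several tasks in one process can construct the model concurrently, guard the build with a lock.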
