NonBlockingThreadPool.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
#define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H


namespace Eigen {

template <typename Environment>
class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
 public:
  typedef typename Environment::Task Task;
  typedef RunQueue<Task, 1024> Queue;
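  // Each worker thread owns one Queue: the owning thread pushes and pops at
  // the front, while other threads submit and steal at the back (see
  // Schedule and Steal below).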

  NonBlockingThreadPoolTempl(int num_threads, Environment env = Environment())
      : env_(env),
        threads_(num_threads),
        queues_(num_threads),
        coprimes_(num_threads),
        waiters_(num_threads),
        blocked_(0),
        spinning_(0),
        done_(false),
        ec_(waiters_) {
    waiters_.resize(num_threads);

    // Calculate coprimes of num_threads.
    // Coprimes are used for a random walk over all threads in Steal
    // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
    // a walk starting at thread index t and calculate num_threads - 1 subsequent
    // indices as (t + coprime) % num_threads, we will cover all threads without
    // repetitions (effectively getting a pseudo-random permutation of thread
    // indices).
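    // Example (illustrative): with num_threads = 4, the coprimes are {1, 3}.
    // Starting at t = 2 with increment 3 visits 2, 1, 0, 3 -- each of the
    // four thread indices exactly once.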
    for (int i = 1; i <= num_threads; i++) {
      unsigned a = i;
      unsigned b = num_threads;
      // If GCD(a, b) == 1, then a and b are coprime.
      while (b != 0) {
        unsigned tmp = a;
        a = b;
        b = tmp % b;
      }
      if (a == 1) {
        coprimes_.push_back(i);
      }
    }
    for (int i = 0; i < num_threads; i++) {
      queues_.push_back(new Queue());
    }
    for (int i = 0; i < num_threads; i++) {
      threads_.push_back(env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
  }

  ~NonBlockingThreadPoolTempl() {
    done_ = true;
    // Now if all threads block without work, they will start exiting.
    // But note that threads can continue to work arbitrarily long,
    // block, submit new work, unblock, and otherwise live a full life.
    ec_.Notify(true);

    // Join threads explicitly to avoid destruction order issues.
    for (size_t i = 0; i < threads_.size(); i++) delete threads_[i];
    for (size_t i = 0; i < queues_.size(); i++) delete queues_[i];
  }

  void Schedule(std::function<void()> fn) {
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's queue.
      Queue* q = queues_[pt->thread_id];
      t = q->PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue.
      Queue* q = queues_[Rand(&pt->rand) % queues_.size()];
      t = q->PushBack(std::move(t));
    }
    // Note: below we touch this after making t available to worker threads.
    // Strictly speaking, this can lead to a racy use-after-free. Consider that
    // Schedule is called from a thread that is neither the main thread nor a
    // worker thread of this pool. Then, execution of t directly or indirectly
    // completes the overall computation, which in turn leads to destruction of
    // this. We expect that such a scenario is prevented by the program; that
    // is, this is kept alive while any thread can potentially be in Schedule.
    if (!t.f)
      ec_.Notify(false);
    else
      env_.ExecuteTask(t);  // Push failed, execute directly.
  }

  int NumThreads() const final {
    return static_cast<int>(threads_.size());
  }

  int CurrentThreadId() const final {
    const PerThread* pt =
        const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
    if (pt->pool == this) {
      return pt->thread_id;
    } else {
      return -1;
    }
  }

 private:
  typedef typename Environment::EnvThread Thread;

  struct PerThread {
    constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { }
    NonBlockingThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
    uint64_t rand;  // Random generator state.
    int thread_id;  // Worker thread index in pool.
  };

  Environment env_;
  MaxSizeVector<Thread*> threads_;
  MaxSizeVector<Queue*> queues_;
  MaxSizeVector<unsigned> coprimes_;
  MaxSizeVector<EventCount::Waiter> waiters_;
  std::atomic<unsigned> blocked_;
  std::atomic<bool> spinning_;
  std::atomic<bool> done_;
  EventCount ec_;

  // Main worker thread loop.
  void WorkerLoop(int thread_id) {
    PerThread* pt = GetPerThread();
    pt->pool = this;
    pt->rand = std::hash<std::thread::id>()(std::this_thread::get_id());
    pt->thread_id = thread_id;
    Queue* q = queues_[thread_id];
    EventCount::Waiter* waiter = &waiters_[thread_id];
    for (;;) {
      Task t = q->PopFront();
      if (!t.f) {
        t = Steal();
        if (!t.f) {
          // Leave one thread spinning. This reduces latency.
          // TODO(dvyukov): 1000 iterations is based on a fair dice roll; tune it.
          // Also, the time it takes to attempt to steal work 1000 times depends
          // on the size of the thread pool. However, the speed at which the user
          // of the thread pool submits tasks is independent of the size of the
          // pool. Consider a time-based limit instead.
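          // The load-then-exchange below is a test-and-test-and-set: at most
          // one thread wins the exchange and becomes the designated spinner.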
          if (!spinning_ && !spinning_.exchange(true)) {
            for (int i = 0; i < 1000 && !t.f; i++) {
              t = Steal();
            }
            spinning_ = false;
          }
          if (!t.f) {
            if (!WaitForWork(waiter, &t)) {
              return;
            }
          }
        }
      }
      if (t.f) {
        env_.ExecuteTask(t);
      }
    }
  }

  // Steal tries to steal work from other worker threads in a best-effort manner.
  Task Steal() {
    PerThread* pt = GetPerThread();
    const size_t size = queues_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = coprimes_[r % coprimes_.size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      Task t = queues_[victim]->PopBack();
      if (t.f) {
        return t;
      }
      victim += inc;
      if (victim >= size) {
        victim -= size;
      }
    }
    return Task();
  }

  // WaitForWork blocks until new work is available (returns true) or until it
  // is time to exit (returns false). It can optionally return a task to
  // execute in t (in which case t.f != nullptr on return).
  bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
    eigen_assert(!t->f);
    // We already did a best-effort emptiness check in Steal, so prepare for
    // blocking.
    ec_.Prewait(waiter);
    // Now do a reliable emptiness check.
    int victim = NonEmptyQueueIndex();
    if (victim != -1) {
      ec_.CancelWait(waiter);
      *t = queues_[victim]->PopBack();
      return true;
    }
    // The number of blocked threads is used as the termination condition.
    // If we are shutting down and all worker threads are blocked without
    // work, we are done.
    blocked_++;
    if (done_ && blocked_ == threads_.size()) {
      ec_.CancelWait(waiter);
      // Almost done, but we need to re-check the queues.
      // Consider that all queues are empty and all worker threads are preempted
      // right after incrementing blocked_ above. Now a free-standing thread
      // submits work and calls the destructor (which sets done_). If we don't
      // re-check the queues, we will exit leaving the work unexecuted.
      if (NonEmptyQueueIndex() != -1) {
        // Note: we must not pop from the queues before we decrement blocked_,
        // otherwise the following scenario is possible. Consider that instead
        // of checking for emptiness we popped the only element from the queues.
        // Now other worker threads can start exiting, which is bad if the
        // work item submits other work. So we just check emptiness here,
        // which ensures that all worker threads exit at the same time.
        blocked_--;
        return true;
      }
      // Reached a stable termination state.
      ec_.Notify(true);
      return false;
    }
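    // No work found and we are not the last thread to block during shutdown:
    // commit the wait and sleep until ec_ is notified (new work in Schedule,
    // or shutdown in the destructor).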
    ec_.CommitWait(waiter);
    blocked_--;
    return true;
  }

  int NonEmptyQueueIndex() {
    PerThread* pt = GetPerThread();
    const size_t size = queues_.size();
    unsigned r = Rand(&pt->rand);
    unsigned inc = coprimes_[r % coprimes_.size()];
    unsigned victim = r % size;
    for (unsigned i = 0; i < size; i++) {
      if (!queues_[victim]->Empty()) {
        return victim;
      }
      victim += inc;
      if (victim >= size) {
        victim -= size;
      }
    }
    return -1;
  }

  static EIGEN_STRONG_INLINE PerThread* GetPerThread() {
    EIGEN_THREAD_LOCAL PerThread per_thread_;
    PerThread* pt = &per_thread_;
    return pt;
  }

  static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
    uint64_t current = *state;
    // Update the internal state.
    *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
    // Generate the random output (using the PCG-XSH-RS scheme).
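    // (current >> 61) takes the top three bits of the old state, selecting a
    // data-dependent extra shift of 0-7 bits on top of the fixed 22.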
    return static_cast<unsigned>((current ^ (current >> 22)) >> (22 + (current >> 61)));
  }
};

typedef NonBlockingThreadPoolTempl<StlThreadEnvironment> NonBlockingThreadPool;
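
// A minimal usage sketch (illustrative only; assumes this module's headers,
// <atomic>, and <thread> are available). The busy-wait is for brevity; real
// code would use a proper synchronization primitive:
//
//   Eigen::NonBlockingThreadPool pool(4);  // four worker threads
//   std::atomic<int> remaining(100);
//   for (int i = 0; i < 100; i++) {
//     pool.Schedule([&remaining]() { remaining.fetch_sub(1); });
//   }
//   while (remaining.load() != 0) std::this_thread::yield();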

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H