]> pd.if.org Git - nbds/blob - struct/skiplist.c
5e9066c280165b078ddc3cbd5a7613d442df4502
[nbds] / struct / skiplist.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Implementation of the lock-free skiplist data-structure created by Maurice Herlihy, Yossi Lev,
6  * and Nir Shavit. See Herlihy's and Shivit's book "The Art of Multiprocessor Programming".
7  * http://www.amazon.com/Art-Multiprocessor-Programming-Maurice-Herlihy/dp/0123705916/
8  *
9  * See also Kir Fraser's dissertation "Practical Lock Freedom".
10  * www.cl.cam.ac.uk/techreports/UCAM-CL-TR-579.pdf
11  *
12  * This code is written for the x86 memory-model. The algorithim depends on certain stores and
13  * loads being ordered. Be careful, this code probably won't work correctly on platforms with
14  * weaker memory models if you don't add memory barriers in the right places.
15  */
16 #include <stdio.h>
17 #include <string.h>
18
19 #include "common.h"
20 #include "runtime.h"
21 #include "struct.h"
22 #include "mem.h"
23 #include "tls.h"
24
25 // Setting MAX_LEVEL to 0 essentially makes this data structure the Harris-Michael lock-free list
26 // in list.c
27 #define MAX_LEVEL 31
28
29 typedef struct node {
30     uint64_t key;
31     uint64_t value;
32     int top_level;
33     struct node *next[];
34 } node_t;
35
36 typedef struct skiplist {
37     node_t *head;
38 } skiplist_t;
39
40 static int random_level (void) {
41     unsigned r = nbd_rand();
42     if (r&1)
43         return 0;
44     int n = __builtin_ctz(r)-1;
45 #if MAX_LEVEL < 31
46     if (n > MAX_LEVEL)
47         return MAX_LEVEL;
48 #endif
49     assert(n <= MAX_LEVEL);
50     return n;
51 }
52
53 node_t *node_alloc (int level, uint64_t key, uint64_t value) {
54     assert(level >= 0 && level <= MAX_LEVEL);
55     size_t sz = sizeof(node_t) + (level + 1) * sizeof(node_t *);
56     node_t *item = (node_t *)nbd_malloc(sz);
57     memset(item, 0, sz);
58     item->key   = key;
59     item->value = value;
60     item->top_level = level;
61     return item;
62 }
63
64 skiplist_t *sl_alloc (void) {
65     skiplist_t *skiplist = (skiplist_t *)nbd_malloc(sizeof(skiplist_t));
66     skiplist->head = node_alloc(MAX_LEVEL, 0, 0);
67     memset(skiplist->head->next, 0, (MAX_LEVEL+1) * sizeof(skiplist_t *));
68     return skiplist;
69 }
70
71 static node_t *find_preds (node_t *preds[MAX_LEVEL+1], int n, skiplist_t *skiplist, uint64_t key, int help_remove) {
72     node_t *pred = skiplist->head;
73     node_t *item = NULL;
74     TRACE("s3", "find_preds: searching for key %p in skiplist (head is %p)", key, pred);
75
76     // Optimization for small lists. No need to traverse empty higher levels.
77     assert(MAX_LEVEL > 2);
78     int start_level = 2;
79     while (pred->next[start_level+1] != NULL) {
80         start_level += start_level - 1;
81         if (EXPECT_FALSE(start_level >= MAX_LEVEL)) {
82             start_level = MAX_LEVEL;
83             break;
84         }
85     }
86     if (EXPECT_FALSE(start_level < n)) {
87         start_level = n;
88     }
89
90     // Traverse the levels of the skiplist from the top level to the bottom
91     for (int level = start_level; level >= 0; --level) {
92         TRACE("s3", "find_preds: level %llu", level, 0);
93         item = pred->next[level];
94         if (EXPECT_FALSE(IS_TAGGED(item))) {
95             TRACE("s3", "find_preds: pred %p is marked for removal (item %p); retry", pred, item);
96             return find_preds(preds, n, skiplist, key, help_remove); // retry
97         }
98         while (item != NULL) {
99             node_t *next = item->next[level];
100             TRACE("s3", "find_preds: visiting item %p (next %p)", item, next);
101             TRACE("s3", "find_preds: key %p", item->key, 0);
102
103             // Marked items are logically removed, but not fully unlinked yet.
104             while (EXPECT_FALSE(IS_TAGGED(next))) {
105
106                 // Skip over partially removed items.
107                 if (!help_remove) {
108                     item = (node_t *)STRIP_TAG(item->next);
109                     if (EXPECT_FALSE(item == NULL))
110                         break;
111                     next = item->next[level];
112                     continue;
113                 }
114
115                 // Unlink partially removed items.
116                 node_t *other;
117                 if ((other = SYNC_CAS(&pred->next[level], item, STRIP_TAG(next))) == item) {
118                     item = (node_t *)STRIP_TAG(next);
119                     if (EXPECT_FALSE(item == NULL))
120                         break;
121                     next = item->next[level];
122                     TRACE("s3", "find_preds: unlinked item %p from pred %p", item, pred);
123                     TRACE("s3", "find_preds: now item is %p next is %p", item, next);
124
125                     // The thread that completes the unlink should free the memory.
126                     if (level == 0) { nbd_defer_free(other); }
127                 } else {
128                     TRACE("s3", "find_preds: lost race to unlink from pred %p; its link changed to %p", pred, other);
129                     if (IS_TAGGED(other))
130                         return find_preds(preds, n, skiplist, key, help_remove); // retry
131                     item = other;
132                     if (EXPECT_FALSE(item == NULL))
133                         break;
134                     next = item->next[level];
135                 }
136             }
137
138             if (EXPECT_FALSE(item == NULL))
139                 break;
140
141             // If we reached the key (or passed where it should be), we found a pred. Save it and continue down.
142             if (item->key >= key) {
143                 TRACE("s3", "find_preds: found pred %p item %p", pred, item);
144                 break;
145             }
146
147             pred = item;
148             item = next;
149         }
150         if (preds != NULL) {
151             preds[level] = pred;
152         }
153     }
154     if (n == -1 && item != NULL) {
155         assert(preds != NULL);
156         for (int level = start_level + 1; level <= item->top_level; ++level) {
157             preds[level] = skiplist->head;
158         }
159     }
160     return item;
161 }
162
163 // Fast find that does not help unlink partially removed nodes and does not return the node's predecessors.
164 uint64_t sl_lookup (skiplist_t *skiplist, uint64_t key) {
165     TRACE("s3", "sl_lookup: searching for key %p in skiplist %p", key, skiplist);
166     node_t *item = find_preds(NULL, 0, skiplist, key, FALSE);
167
168     // If we found an <item> matching the <key> return its value.
169     return (item && item->key == key) ? item->value : DOES_NOT_EXIST;
170 }
171
172 // Insert the <key> if it doesn't already exist in the <skiplist>
173 uint64_t sl_add (skiplist_t *skiplist, uint64_t key, uint64_t value) {
174     TRACE("s3", "sl_add: inserting key %p value %p", key, value);
175     node_t *preds[MAX_LEVEL+1];
176     node_t *item = NULL;
177     do {
178         int n = random_level();
179         node_t *next = find_preds(preds, n, skiplist, key, TRUE);
180
181         // If a node matching <key> already exists in the skiplist, return its value.
182         if (next != NULL && next->key == key) {
183             TRACE("s3", "sl_add: there is already an item %p (value %p) with the same key", next, next->value);
184             if (EXPECT_FALSE(item != NULL)) { nbd_free(item); }
185             return next->value;
186         }
187
188         // First insert <item> into the bottom level.
189         if (EXPECT_TRUE(item == NULL)) { item = node_alloc(n, key, value); }
190         TRACE("s3", "sl_add: attempting to insert item between %p and %p", preds[0], next);
191         item->next[0] = next;
192         for (int level = 1; level <= item->top_level; ++level) {
193             node_t *pred = preds[level];
194             item->next[level] = pred->next[level];
195         }
196         node_t *pred = preds[0];
197         node_t *other = SYNC_CAS(&pred->next[0], next, item);
198         if (other == next) {
199             TRACE("s3", "sl_add: successfully inserted item %p at level 0", item, 0);
200             break; // success
201         }
202         TRACE("s3", "sl_add: failed to change pred's link: expected %p found %p", next, other);
203
204     } while (1);
205
206     // Insert <item> into the skiplist from the bottom level up.
207     for (int level = 1; level <= item->top_level; ++level) {
208         do {
209             node_t *pred;
210             node_t *next;
211             do {
212                 pred = preds[level];
213                 next = pred->next[level];
214                 if (next == NULL) // item goes at the end of the list
215                     break;
216                 if (!IS_TAGGED(next) && next->key > key) // pred's link changed
217                     break;
218                 find_preds(preds, item->top_level, skiplist, key, TRUE);
219             } while (1);
220
221             do {
222                 // There in no need to continue linking in the item if another thread removed it.
223                 node_t *old_next = ((volatile node_t *)item)->next[level];
224                 if (IS_TAGGED(old_next))
225                     return DOES_NOT_EXIST; // success
226
227                 // Use a CAS so we to not inadvertantly remove a mark another thread placed on the item.
228                 if (next == old_next || SYNC_CAS(&item->next[level], old_next, next) == old_next)
229                     break;
230             } while (1);
231
232             TRACE("s3", "sl_add: attempting to insert item between %p and %p", pred, next);
233             node_t *other = SYNC_CAS(&pred->next[level], next, item);
234             if (other == next) {
235                 TRACE("s3", "sl_add: successfully inserted item %p at level %llu", item, level);
236                 break; // success
237             }
238             TRACE("s3", "sl_add: failed to change pred's link: expected %p found %p", next, other);
239
240         } while (1);
241     }
242     return value;
243 }
244
245 uint64_t sl_remove (skiplist_t *skiplist, uint64_t key) {
246     TRACE("s3", "sl_remove: removing item with key %p from skiplist %p", key, skiplist);
247     node_t *preds[MAX_LEVEL+1];
248     node_t *item = find_preds(preds, -1, skiplist, key, TRUE);
249     if (item == NULL || item->key != key) {
250         TRACE("s3", "sl_remove: remove failed, an item with a matching key does not exist in the skiplist", 0, 0);
251         return DOES_NOT_EXIST;
252     }
253
254     // Mark <item> removed at each level of the skiplist from the top down. This must be atomic. If multiple threads
255     // try to remove the same item only one of them should succeed. Marking the bottom level establishes which of 
256     // them succeeds.
257     for (int level = item->top_level; level >= 0; --level) {
258         if (EXPECT_FALSE(IS_TAGGED(item->next[level]))) {
259             TRACE("s3", "sl_remove: %p is already marked for removal by another thread", item, 0);
260             if (level == 0)
261                 return DOES_NOT_EXIST;
262             continue;
263         }
264         node_t *next = SYNC_FETCH_AND_OR(&item->next[level], TAG);
265         if (EXPECT_FALSE(IS_TAGGED(next))) {
266             TRACE("s3", "sl_remove: lost race -- %p is already marked for removal by another thread", item, 0);
267             if (level == 0)
268                 return DOES_NOT_EXIST;
269             continue;
270         }
271     }
272
273     uint64_t value = item->value;
274
275     // Unlink <item> from the top down.
276     int level = item->top_level;
277     while (level >= 0) {
278         node_t *pred = preds[level];
279         node_t *next = item->next[level];
280         TRACE("s3", "sl_remove: link item's pred %p to it's successor %p", pred, STRIP_TAG(next));
281         node_t *other = NULL;
282         if ((other = SYNC_CAS(&pred->next[level], item, STRIP_TAG(next))) != item) {
283             TRACE("s3", "sl_remove: unlink failed; pred's link changed from %p to %p", item, other);
284             // By marking the item earlier, we logically removed it. It is safe to leave the item partially
285             // unlinked. Another thread will finish physically removing it from the skiplist.
286             return value;
287         }
288         --level; 
289     }
290
291     // The thread that completes the unlink should free the memory.
292     nbd_defer_free(item); 
293     return value;
294 }
295
296 void sl_print (skiplist_t *skiplist) {
297     for (int level = MAX_LEVEL; level >= 0; --level) {
298         node_t *item = skiplist->head;
299         if (item->next[level] == NULL)
300             continue;
301         printf("(%d) ", level);
302         while (item) {
303             node_t *next = item->next[level];
304             printf("%s%p ", IS_TAGGED(next) ? "*" : "", item);
305             item = (node_t *)STRIP_TAG(next);
306         }
307         printf("\n");
308         fflush(stdout);
309     }
310
311     printf("\n");
312     node_t *item = skiplist->head;
313     while (item) {
314         int is_marked = IS_TAGGED(item->next[0]);
315         printf("%s%p:0x%llx [%d]", is_marked ? "*" : "", item, item->key, item->top_level);
316         for (int level = 1; level <= item->top_level; ++level) {
317             node_t *next = (node_t *)STRIP_TAG(item->next[level]);
318             is_marked = IS_TAGGED(item->next[0]);
319             printf(" %p%s", next, is_marked ? "*" : "");
320             if (item == skiplist->head && item->next[level] == NULL)
321                 break;
322         }
323         printf("\n");
324         fflush(stdout);
325         item = (node_t *)STRIP_TAG(item->next[0]);
326     }
327 }
328
329 #ifdef MAKE_skiplist_test
330 #include <errno.h>
331 #include <pthread.h>
332 #include <sys/time.h>
333
334 #include "runtime.h"
335
336 #define NUM_ITERATIONS 10000000
337
338 static volatile int wait_;
339 static long num_threads_;
340 static skiplist_t *sl_;
341
342 void *worker (void *arg) {
343
344     // Wait for all the worker threads to be ready.
345     SYNC_ADD(&wait_, -1);
346     do {} while (wait_); 
347
348     for (int i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
349         unsigned r = nbd_rand();
350         int key = (r & 0xF) + 1;
351         if (r & (1 << 8)) {
352             sl_add(sl_, key, 1);
353         } else {
354             sl_remove(sl_, key);
355         }
356
357         rcu_update();
358     }
359
360     return NULL;
361 }
362
363 int main (int argc, char **argv) {
364     nbd_init();
365     lwt_set_trace_level("s3");
366
367     char* program_name = argv[0];
368     pthread_t thread[MAX_NUM_THREADS];
369
370     if (argc > 2) {
371         fprintf(stderr, "Usage: %s num_threads\n", program_name);
372         return -1;
373     }
374
375     num_threads_ = 2;
376     if (argc == 2)
377     {
378         errno = 0;
379         num_threads_ = strtol(argv[1], NULL, 10);
380         if (errno) {
381             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
382             return -1;
383         }
384         if (num_threads_ <= 0) {
385             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
386             return -1;
387         }
388         if (num_threads_ > MAX_NUM_THREADS) {
389             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
390             return -1;
391         }
392     }
393
394     sl_ = sl_alloc();
395
396     struct timeval tv1, tv2;
397     gettimeofday(&tv1, NULL);
398
399     wait_ = num_threads_;
400
401     for (int i = 0; i < num_threads_; ++i) {
402         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
403         if (rc != 0) { perror("pthread_create"); return rc; }
404     }
405
406     for (int i = 0; i < num_threads_; ++i) {
407         pthread_join(thread[i], NULL);
408     }
409
410     gettimeofday(&tv2, NULL);
411     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
412     sl_print(sl_);
413     printf("Th:%ld Time:%dms\n", num_threads_, ms);
414
415     return 0;
416 }
417 #endif//skiplist_test