]> pd.if.org Git - nbds/blob - struct/skiplist.c
add references to literature for skiplist
[nbds] / struct / skiplist.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * C implementation of the lock-free skiplist data-structure created by Maurice Herlihy, 
6  * Yossi Lev, and Nir Shavit. See "The Art of Multiprocessor Programming"
7  * http://www.amazon.com/Art-Multiprocessor-Programming-Maurice-Herlihy/dp/0123705916/
8  *
9  * See also Kir Fraser's dissertation "Practical Lock Freedom"
10  * www.cl.cam.ac.uk/techreports/UCAM-CL-TR-579.pdf
11  *
12  */
13 #include <stdio.h>
14 #include <string.h>
15
16 #include "common.h"
17 #include "struct.h"
18 #include "mem.h"
19
20 #define MAX_LEVEL 3
21
22 typedef struct node {
23     uint64_t key;
24     uint64_t value;
25     int top_level;
26     struct node *next[];
27 } node_t;
28
29 typedef struct skiplist {
30     node_t *head;
31     node_t *last;
32     int top_level;
33 } skiplist_t;
34
35 static int random_level (int r) {
36     if (r&1)
37         return 0;
38     int n = __builtin_ctz(r);
39     if (n < MAX_LEVEL)
40         return n;
41     return MAX_LEVEL;
42 }
43
44 node_t *node_alloc (int top_level, uint64_t key, uint64_t value) {
45     assert(top_level >= 0 && top_level <= MAX_LEVEL);
46     size_t sz = sizeof(node_t) + (top_level + 1) * sizeof(node_t *);
47     node_t *item = (node_t *)nbd_malloc(sz);
48     memset(item, 0, sz);
49     item->key   = key;
50     item->value = value;
51     item->top_level = top_level;
52     return item;
53 }
54
55 skiplist_t *skiplist_alloc (void) {
56     skiplist_t *skiplist = (skiplist_t *)nbd_malloc(sizeof(skiplist_t));
57     skiplist->head = node_alloc(MAX_LEVEL, 0, 0);
58     skiplist->last = node_alloc(MAX_LEVEL, (uint64_t)-1, 0);
59     for (int level = 0; level <= MAX_LEVEL; ++level) {
60         skiplist->head->next[level] = skiplist->last;
61     }
62     return skiplist;
63 }
64
65 static node_t *find_preds (node_t **preds, skiplist_t *skiplist, uint64_t key, int help_remove) {
66     node_t *pred = skiplist->head;
67     node_t *item = NULL;
68     TRACE("s3", "find_preds: searching for key %p in skiplist (head is %p)", key, pred);
69 #ifndef NDEBUG
70     int count = 0;
71 #endif
72
73     // Traverse the levels of the skiplist from the top level to the bottom
74     for (int level = MAX_LEVEL; level >= 0; --level) {
75         TRACE("s3", "find_preds: level %llu", level, 0);
76         item = pred->next[level];
77         if (IS_TAGGED(item)) {
78             TRACE("s3", "find_preds: pred %p is marked for removal (item %p); retry", pred, item);
79             return find_preds(preds, skiplist, key, help_remove); // retry
80         }
81         do {
82             node_t *next = item->next[level];
83             TRACE("s3", "find_preds: visiting item %p (next %p)", item, next);
84             TRACE("s3", "find_preds: key %p", item->key, 0);
85
86             // Marked items are logically removed, but not fully unlinked yet.
87             while (EXPECT_FALSE(IS_TAGGED(next))) {
88
89                 // Skip over partially removed items.
90                 if (!help_remove) {
91                     item = (node_t *)STRIP_TAG(item->next);
92                     next = item->next[level];
93                     continue;
94                 }
95
96                 // Unlink partially removed items.
97                 node_t *other;
98                 if ((other = SYNC_CAS(&pred->next[level], item, STRIP_TAG(next))) == item) {
99                     item = (node_t *)STRIP_TAG(next);
100                     next = item->next[level];
101                     TRACE("s3", "find_preds: unlinked item %p from pred %p", item, pred);
102                     TRACE("s3", "find_preds: now item is %p next is %p", item, next);
103
104                     // The thread that completes the unlink should free the memory.
105                     if (level == 0) { nbd_defer_free(other); }
106                 } else {
107                     TRACE("s3", "find_preds: lost race to unlink item from pred %p; its link changed to %p",pred,other);
108                     if (IS_TAGGED(other))
109                         return find_preds(preds, skiplist, key, help_remove); // retry
110                     item = other;
111                     next = item->next[level];
112                 }
113             }
114
115             // If we reached the key (or passed where it should be), we found a pred. Save it and continue down.
116             if (item->key >= key) {
117                 TRACE("s3", "find_preds: found pred %p item %p", pred, item);
118                 if (preds != NULL) {
119                     preds[level] = pred;
120                 }
121                 break;
122             }
123
124             assert(count++ < 18);
125             pred = item;
126             item = next;
127
128         } while (1);
129     }
130     return item;
131 }
132
133 // Fast find that does not help unlink partially removed nodes and does not return the node's predecessors.
134 uint64_t skiplist_lookup (skiplist_t *skiplist, uint64_t key) {
135     TRACE("s3", "skiplist_lookup: searching for key %p in skiplist %p", key, skiplist);
136     node_t *item = find_preds(NULL, skiplist, key, FALSE);
137
138     // If we found an <item> matching the <key> return its value.
139     return (item->key == key) ? item->value : DOES_NOT_EXIST;
140 }
141
142 // Insert the <key> if it doesn't already exist in the <skiplist>
143 uint64_t skiplist_add_i (int n, skiplist_t *skiplist, uint64_t key, uint64_t value) {
144     TRACE("s3", "skiplist_add: inserting key %p value %p", key, value);
145     node_t *preds[MAX_LEVEL+1];
146     node_t *item = NULL;
147     do {
148         node_t *next = find_preds(preds, skiplist, key, TRUE);
149
150         // If a node matching <key> already exists in the skiplist, return its value.
151         if (next->key == key) {
152             TRACE("s3", "skiplist_add: there is already an item %p (value %p) with the same key", next, next->value);
153             if (EXPECT_FALSE(item != NULL)) { nbd_free(item); }
154             return next->value;
155         }
156
157         // First insert <item> into the bottom level.
158         if (EXPECT_TRUE(item == NULL)) { item = node_alloc(random_level(n), key, value); }
159         TRACE("s3", "skiplist_add: attempting to insert item between %p and %p", preds[0], next);
160         item->next[0] = next;
161         for (int level = 1; level <= item->top_level; ++level) {
162             item->next[level] = preds[level]->next[level];
163         }
164         node_t *other = SYNC_CAS(&preds[0]->next[0], next, item);
165         if (other == next) {
166             TRACE("s3", "skiplist_add: successfully inserted item %p at level 0", item, 0);
167             break; // success
168         }
169         TRACE("s3", "skiplist_add: failed to change pred's link: expected %p found %p", next, other);
170
171     } while (1);
172
173     // Insert <item> into the skiplist from the bottom level up.
174     for (int level = 1; level <= item->top_level; ++level) {
175         do {
176             node_t *pred = preds[level];
177             node_t *next = pred->next[level];
178             while (EXPECT_FALSE(IS_TAGGED(next) || next->key < key)) {
179                 find_preds(preds, skiplist, key, TRUE);
180                 pred = preds[level];
181                 next = pred->next[level];
182             }
183
184             do {
185                 // There in no need to continue linking in the item if another thread removed it.
186                 node_t *old_next = ((volatile node_t *)item)->next[level];
187                 if (IS_TAGGED(old_next))
188                     return DOES_NOT_EXIST; // success
189
190                 // Use a CAS so we to not inadvertantly remove a mark another thread placed on the item.
191                 if (next == old_next || SYNC_CAS(&item->next[level], old_next, next) == old_next)
192                     break;
193             } while (1);
194
195             TRACE("s3", "skiplist_add: attempting to insert item between %p and %p", pred, next);
196             node_t *other = SYNC_CAS(&pred->next[level], next, item);
197             if (other == next) {
198                 TRACE("s3", "skiplist_add: successfully inserted item %p at level %llu", item, level);
199                 break; // success
200             }
201             TRACE("s3", "skiplist_add: failed to change pred's link: expected %p found %p", next, other);
202
203         } while (1);
204     }
205     return value;
206 }
207
208 uint64_t skiplist_remove (skiplist_t *skiplist, uint64_t key) {
209     TRACE("s3", "skiplist_remove: removing item with key %p from skiplist %p", key, skiplist);
210     node_t *preds[MAX_LEVEL+1];
211     node_t *item = find_preds(preds, skiplist, key, TRUE);
212     if (item->key != key) {
213         TRACE("s3", "skiplist_remove: remove failed, an item with a matching key does not exist in the skiplist", 0, 0);
214         return DOES_NOT_EXIST;
215     }
216
217     // Mark <item> removed at each level of the skiplist from the top down. This must be atomic. If multiple threads
218     // try to remove the same item only one of them should succeed. Marking the bottom level establishes which of 
219     // them succeeds.
220     for (int level = item->top_level; level >= 0; --level) {
221         if (EXPECT_FALSE(IS_TAGGED(item->next[level]))) {
222             TRACE("s3", "skiplist_remove: %p is already marked for removal by another thread", item, 0);
223             if (level == 0)
224                 return DOES_NOT_EXIST;
225             continue;
226         }
227         node_t *next = SYNC_FETCH_AND_OR(&item->next[level], TAG);
228         if (EXPECT_FALSE(IS_TAGGED(next))) {
229             TRACE("s3", "skiplist_remove: lost race -- %p is already marked for removal by another thread", item, 0);
230             if (level == 0)
231                 return DOES_NOT_EXIST;
232             continue;
233         }
234     }
235
236     uint64_t value = item->value;
237
238     // Unlink <item> from the top down.
239     int level = item->top_level;
240     while (level >= 0) {
241         node_t *pred = preds[level];
242         node_t *next = item->next[level];
243         TRACE("s3", "skiplist_remove: link item's pred %p to it's successor %p", pred, STRIP_TAG(next));
244         node_t *other = NULL;
245         if ((other = SYNC_CAS(&pred->next[level], item, STRIP_TAG(next))) != item) {
246             TRACE("s3", "skiplist_remove: unlink failed; pred's link changed from %p to %p", item, other);
247             // By marking the item earlier, we logically removed it. It is safe to leave the item partially
248             // unlinked. Another thread will finish physically removing it from the skiplist.
249             return value;
250         }
251         --level; 
252     }
253
254     // The thread that completes the unlink should free the memory.
255     nbd_defer_free(item); 
256     return value;
257 }
258
259 void skiplist_print (skiplist_t *skiplist) {
260     for (int level = MAX_LEVEL; level >= 0; --level) {
261         printf("(%d) ", level);
262         node_t *item = skiplist->head;
263         while (item) {
264             node_t *next = item->next[level];
265             printf("%s%p:0x%llx ", IS_TAGGED(next) ? "*" : "", item, item->key);
266             item = (node_t *)STRIP_TAG(next);
267         }
268         printf("\n");
269         fflush(stdout);
270     }
271
272     node_t *item = skiplist->head;
273     while (item) {
274         assert(item->top_level <= MAX_LEVEL);
275         int is_marked = IS_TAGGED(item->next[0]);
276         printf("%s%p:0x%llx (%d", is_marked ? "*" : "", item, item->key, item->top_level);
277         for (int i = 1; i <= item->top_level; ++i) {
278             node_t *next = (node_t *)STRIP_TAG(item->next[i]);
279             printf(" %p:0x%llx", item->next[i], next ? next->key : 0);
280         }
281         printf(")\n");
282         fflush(stdout);
283         item = (node_t *)STRIP_TAG(item->next[0]);
284     }
285 }
286
287 #ifdef MAKE_skiplist_test
288 #include <errno.h>
289 #include <pthread.h>
290 #include <sys/time.h>
291
292 #include "runtime.h"
293
294 #define NUM_ITERATIONS 10000000
295
296 static volatile int wait_;
297 static long num_threads_;
298 static skiplist_t *skiplist_;
299
300 void *worker (void *arg) {
301     int id = (int)(size_t)arg;
302
303     unsigned int rand_seed = id+1;//rdtsc_l();
304
305     // Wait for all the worker threads to be ready.
306     SYNC_ADD(&wait_, -1);
307     do {} while (wait_); 
308
309     for (int i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
310         int n = rand_r(&rand_seed);
311         int key = (n & 0xF) + 1;
312         if (n & (1 << 8)) {
313             skiplist_add_i(n, skiplist_, key, 1);
314         } else {
315             skiplist_remove(skiplist_, key);
316         }
317
318         rcu_update();
319     }
320
321     return NULL;
322 }
323
324 int main (int argc, char **argv) {
325     nbd_init();
326     lwt_set_trace_level("s3");
327
328     char* program_name = argv[0];
329     pthread_t thread[MAX_NUM_THREADS];
330
331     if (argc > 2) {
332         fprintf(stderr, "Usage: %s num_threads\n", program_name);
333         return -1;
334     }
335
336     num_threads_ = 2;
337     if (argc == 2)
338     {
339         errno = 0;
340         num_threads_ = strtol(argv[1], NULL, 10);
341         if (errno) {
342             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
343             return -1;
344         }
345         if (num_threads_ <= 0) {
346             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
347             return -1;
348         }
349         if (num_threads_ > MAX_NUM_THREADS) {
350             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
351             return -1;
352         }
353     }
354
355     skiplist_ = skiplist_alloc();
356
357     struct timeval tv1, tv2;
358     gettimeofday(&tv1, NULL);
359
360     wait_ = num_threads_;
361
362     for (int i = 0; i < num_threads_; ++i) {
363         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
364         if (rc != 0) { perror("pthread_create"); return rc; }
365     }
366
367     for (int i = 0; i < num_threads_; ++i) {
368         pthread_join(thread[i], NULL);
369     }
370
371     gettimeofday(&tv2, NULL);
372     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
373     printf("Th:%ld Time:%dms\n", num_threads_, ms);
374     skiplist_print(skiplist_);
375     lwt_dump("lwt.out");
376
377     return 0;
378 }
379 #endif//skiplist_test