124043d2b4363222c711c644a15a599869b86c51
[nbds] / struct / list.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Harris-Michael lock-free list-based set
6  * http://www.research.ibm.com/people/m/michael/spaa-2002.pdf
7  */
8 #include <stdio.h>
9 #include <string.h>
10
11 #include "common.h"
12 #include "struct.h"
13 #include "mem.h"
14
15 typedef struct node {
16     uint64_t key;
17     uint64_t value;
18     struct node *next;
19 } node_t;
20
21 typedef struct list {
22     node_t *head;
23 } list_t;
24
25 node_t *node_alloc (uint64_t key, uint64_t value) {
26     node_t *item = (node_t *)nbd_malloc(sizeof(node_t));
27     memset(item, 0, sizeof(node_t));
28     item->key   = key;
29     item->value = value;
30     return item;
31 }
32
33 list_t *ll_alloc (void) {
34     list_t *list = (list_t *)nbd_malloc(sizeof(list_t));
35     list->head = node_alloc((uint64_t)-1, 0);
36     list->head->next = NULL;
37     return list;
38 }
39
40 static node_t *find_pred (node_t **pred_ptr, list_t *list, uint64_t key, int help_remove) {
41     node_t *pred = list->head;
42     node_t *item = pred->next;
43     TRACE("l3", "find_pred: searching for key %p in list (head is %p)", key, pred);
44
45     while (item != NULL) {
46         node_t *next = item->next;
47         TRACE("l3", "find_pred: visiting item %p (next %p)", item, next);
48         TRACE("l3", "find_pred: key %p", item->key, item->value);
49
50         // Marked items are logically removed, but not unlinked yet.
51         while (EXPECT_FALSE(IS_TAGGED(next))) {
52
53             // Skip over partially removed items.
54             if (!help_remove) {
55                 item = (node_t *)STRIP_TAG(item->next);
56                 if (EXPECT_FALSE(item == NULL))
57                     break;
58                 next = item->next;
59                 continue;
60             }
61
62             // Unlink partially removed items.
63             node_t *other;
64             if ((other = SYNC_CAS(&pred->next, item, STRIP_TAG(next))) == item) {
65                 item = (node_t *)STRIP_TAG(next);
66                 if (EXPECT_FALSE(item == NULL))
67                     break;
68                 next = item->next;
69                 TRACE("l3", "find_pred: unlinked item %p from pred %p", item, pred);
70                 TRACE("l3", "find_pred: now item is %p next is %p", item, next);
71
72                 // The thread that completes the unlink should free the memory.
73                 nbd_defer_free(other);
74             } else {
75                 TRACE("l3", "find_pred: lost race to unlink from pred %p; its link changed to %p", pred, other);
76                 if (IS_TAGGED(other))
77                     return find_pred(pred_ptr, list, key, help_remove); // retry
78                 item = other;
79                 if (EXPECT_FALSE(item == NULL))
80                     break;
81                 next = item->next;
82             }
83         }
84
85         if (EXPECT_FALSE(item == NULL))
86             break;
87
88         // If we reached the key (or passed where it should be), we found the right predesssor
89         if (item->key >= key) {
90             TRACE("l3", "find_pred: found pred %p item %p", pred, item);
91             if (pred_ptr != NULL) {
92                 *pred_ptr = pred;
93             }
94             return item;
95         }
96
97         pred = item;
98         item = next;
99
100     }
101
102     // <key> is not in the list.
103     if (pred_ptr != NULL) {
104         *pred_ptr = pred;
105     }
106     return NULL;
107 }
108
109 // Fast find. Do not help unlink partially removed nodes and do not return the found item's predecessor.
110 uint64_t ll_lookup (list_t *list, uint64_t key) {
111     TRACE("l3", "ll_lookup: searching for key %p in list %p", key, list);
112     node_t *item = find_pred(NULL, list, key, FALSE);
113
114     // If we found an <item> matching the <key> return its value.
115     return (item && item->key == key) ? item->value : DOES_NOT_EXIST;
116 }
117
118 // Insert the <key>, if it doesn't already exist in the <list>
119 uint64_t ll_add (list_t *list, uint64_t key, uint64_t value) {
120     TRACE("l3", "ll_add: inserting key %p value %p", key, value);
121     node_t *pred;
122     node_t *item = NULL;
123     do {
124         node_t *next = find_pred(&pred, list, key, TRUE);
125
126         // If a node matching <key> already exists in the list, return its value.
127         if (next != NULL && next->key == key) {
128             TRACE("l3", "ll_add: there is already an item %p (value %p) with the same key", next, next->value);
129             if (EXPECT_FALSE(item != NULL)) { nbd_free(item); }
130             return next->value;
131         }
132
133         TRACE("l3", "ll_add: attempting to insert item between %p and %p", pred, next);
134         if (EXPECT_TRUE(item == NULL)) { item = node_alloc(key, value); }
135         item->next = next;
136         node_t *other = SYNC_CAS(&pred->next, next, item);
137         if (other == next) {
138             TRACE("l3", "ll_add: successfully inserted item %p", item, 0);
139             return DOES_NOT_EXIST; // success
140         }
141         TRACE("l3", "ll_add: failed to change pred's link: expected %p found %p", next, other);
142
143     } while (1);
144 }
145
146 uint64_t ll_remove (list_t *list, uint64_t key) {
147     TRACE("l3", "ll_remove: removing item with key %p from list %p", key, list);
148     node_t *pred;
149     node_t *item = find_pred(&pred, list, key, TRUE);
150     if (item == NULL || item->key != key) {
151         TRACE("l3", "ll_remove: remove failed, an item with a matching key does not exist in the list", 0, 0);
152         return DOES_NOT_EXIST;
153     }
154
155     // Mark <item> removed. This must be atomic. If multiple threads try to remove the same item
156     // only one of them should succeed.
157     if (EXPECT_FALSE(IS_TAGGED(item->next))) {
158         TRACE("l3", "ll_remove: %p is already marked for removal by another thread", item, 0);
159         return DOES_NOT_EXIST;
160     }
161     node_t *next = SYNC_FETCH_AND_OR(&item->next, TAG);
162     if (EXPECT_FALSE(IS_TAGGED(next))) {
163         TRACE("l3", "ll_remove: lost race -- %p is already marked for removal by another thread", item, 0);
164         return DOES_NOT_EXIST;
165     }
166
167     uint64_t value = item->value;
168
169     // Unlink <item> from the list.
170     TRACE("l3", "ll_remove: link item's pred %p to it's successor %p", pred, next);
171     node_t *other;
172     if ((other = SYNC_CAS(&pred->next, item, next)) != item) {
173         TRACE("l3", "ll_remove: unlink failed; pred's link changed from %p to %p", item, other);
174         // By marking the item earlier, we logically removed it. It is safe to leave the item.
175         // Another thread will finish physically removing it from the list.
176         return value;
177     } 
178
179     // The thread that completes the unlink should free the memory.
180     nbd_defer_free(item); 
181     return value;
182 }
183
184 void ll_print (list_t *list) {
185     node_t *item;
186     item = list->head->next;
187     while (item) {
188         printf("0x%llx ", item->key);
189         fflush(stdout);
190         item = item->next;
191     }
192     printf("\n");
193 }
194
195 #ifdef MAKE_list_test
196 #include <errno.h>
197 #include <pthread.h>
198 #include <sys/time.h>
199
200 #include "runtime.h"
201
202 #define NUM_ITERATIONS 10000000
203
204 static volatile int wait_;
205 static long num_threads_;
206 static list_t *ll_;
207
208 void *worker (void *arg) {
209
210     // Wait for all the worker threads to be ready.
211     SYNC_ADD(&wait_, -1);
212     do {} while (wait_); 
213
214     for (int i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
215         unsigned r = nbd_rand();
216         int key = (r & 0xF);
217         if (r & (1 << 8)) {
218             ll_add(ll_, key, 1);
219         } else {
220             ll_remove(ll_, key);
221         }
222
223         rcu_update();
224     }
225
226     return NULL;
227 }
228
229 int main (int argc, char **argv) {
230     nbd_init();
231     //lwt_set_trace_level("m0l0");
232
233     char* program_name = argv[0];
234     pthread_t thread[MAX_NUM_THREADS];
235
236     if (argc > 2) {
237         fprintf(stderr, "Usage: %s num_threads\n", program_name);
238         return -1;
239     }
240
241     num_threads_ = 2;
242     if (argc == 2)
243     {
244         errno = 0;
245         num_threads_ = strtol(argv[1], NULL, 10);
246         if (errno) {
247             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
248             return -1;
249         }
250         if (num_threads_ <= 0) {
251             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
252             return -1;
253         }
254         if (num_threads_ > MAX_NUM_THREADS) {
255             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
256             return -1;
257         }
258     }
259
260     ll_ = ll_alloc();
261
262     struct timeval tv1, tv2;
263     gettimeofday(&tv1, NULL);
264
265     wait_ = num_threads_;
266
267     for (int i = 0; i < num_threads_; ++i) {
268         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
269         if (rc != 0) { perror("pthread_create"); return rc; }
270     }
271
272     for (int i = 0; i < num_threads_; ++i) {
273         pthread_join(thread[i], NULL);
274     }
275
276     gettimeofday(&tv2, NULL);
277     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
278     ll_print(ll_);
279     printf("Th:%ld Time:%dms\n", num_threads_, ms);
280
281     return 0;
282 }
283 #endif//list_test