d9e461260c6511bf16190c9d2323f0849b96eec1
[nbds] / struct / list.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Harris-Michael lock-free list-based set
6  * http://www.research.ibm.com/people/m/michael/spaa-2002.pdf
7  */
8 #include <stdio.h>
9 #include <string.h>
10
11 #include "common.h"
12 #include "struct.h"
13 #include "mem.h"
14
15 typedef struct node {
16     uint64_t key;
17     uint64_t value;
18     struct node *next;
19 } node_t;
20
21 typedef struct list {
22     node_t *head;
23     node_t *last;
24 } list_t;
25
26 node_t *node_alloc (uint64_t key, uint64_t value) {
27     node_t *item = (node_t *)nbd_malloc(sizeof(node_t));
28     memset(item, 0, sizeof(node_t));
29     item->key   = key;
30     item->value = value;
31     return item;
32 }
33
34 list_t *ll_alloc (void) {
35     list_t *list = (list_t *)nbd_malloc(sizeof(list_t));
36     list->head = node_alloc(0, 0);
37     list->last = node_alloc((uint64_t)-1, 0);
38     list->head->next = list->last;
39     return list;
40 }
41
42 static node_t *find_pred (node_t **pred_ptr, list_t *list, uint64_t key, int help_remove) {
43     node_t *pred = list->head;
44     node_t *item = pred->next;
45     TRACE("l3", "find_pred: searching for key %p in list (head is %p)", key, pred);
46 #ifndef NDEBUG
47     int count = 0;
48 #endif
49
50     do {
51         node_t *next = item->next;
52         TRACE("l3", "find_pred: visiting item %p (next %p)", item, next);
53         TRACE("l3", "find_pred: key %p", item->key, item->value);
54
55         // Marked items are logically removed, but not unlinked yet.
56         while (EXPECT_FALSE(IS_TAGGED(next))) {
57
58             // Skip over partially removed items.
59             if (!help_remove) {
60                 item = (node_t *)STRIP_TAG(item->next);
61                 next = item->next;
62                 continue;
63             }
64
65             // Unlink partially removed items.
66             node_t *other;
67             if ((other = SYNC_CAS(&pred->next, item, STRIP_TAG(next))) == item) {
68                 item = (node_t *)STRIP_TAG(next);
69                 next = item->next;
70                 TRACE("l3", "find_pred: unlinked item %p from pred %p", item, pred);
71                 TRACE("l3", "find_pred: now item is %p next is %p", item, next);
72
73                 // The thread that completes the unlink should free the memory.
74                 nbd_defer_free(other);
75             } else {
76                 TRACE("l3", "find_pred: lost race to unlink item from pred %p; its link changed to %p", pred, other);
77                 if (IS_TAGGED(other))
78                     return find_pred(pred_ptr, list, key, help_remove); // retry
79                 item = other;
80                 next = item->next;
81             }
82         }
83
84         // If we reached the key (or passed where it should be), we found the right predesssor
85         if (item->key >= key) {
86             TRACE("l3", "find_pred: found pred %p item %p", pred, item);
87             if (pred_ptr != NULL) {
88                 *pred_ptr = pred;
89             }
90             return item;
91         }
92
93         assert(count++ < 18);
94         pred = item;
95         item = next;
96
97     } while (1);
98 }
99
100 // Fast find. Do not help unlink partially removed nodes and do not return the found item's predecessor.
101 uint64_t ll_lookup (list_t *list, uint64_t key) {
102     TRACE("l3", "ll_lookup: searching for key %p in list %p", key, list);
103     node_t *item = find_pred(NULL, list, key, FALSE);
104
105     // If we found an <item> matching the <key> return its value.
106     return (item->key == key) ? item->value : DOES_NOT_EXIST;
107 }
108
109 // Insert the <key>, if it doesn't already exist in the <list>
110 uint64_t ll_add (list_t *list, uint64_t key, uint64_t value) {
111     TRACE("l3", "ll_add: inserting key %p value %p", key, value);
112     node_t *pred;
113     node_t *item = NULL;
114     do {
115         node_t *next = find_pred(&pred, list, key, TRUE);
116
117         // If a node matching <key> already exists in the list, return its value.
118         if (next->key == key) {
119             TRACE("l3", "ll_add: there is already an item %p (value %p) with the same key", next, next->value);
120             if (EXPECT_FALSE(item != NULL)) { nbd_free(item); }
121             return next->value;
122         }
123
124         TRACE("l3", "ll_add: attempting to insert item between %p and %p", pred, next);
125         if (EXPECT_TRUE(item == NULL)) { item = node_alloc(key, value); }
126         item->next = next;
127         node_t *other = SYNC_CAS(&pred->next, next, item);
128         if (other == next) {
129             TRACE("l3", "ll_add: successfully inserted item %p", item, 0);
130             return DOES_NOT_EXIST; // success
131         }
132         TRACE("l3", "ll_add: failed to change pred's link: expected %p found %p", next, other);
133
134     } while (1);
135 }
136
137 uint64_t ll_remove (list_t *list, uint64_t key) {
138     TRACE("l3", "ll_remove: removing item with key %p from list %p", key, list);
139     node_t *pred;
140     node_t *item = find_pred(&pred, list, key, TRUE);
141     if (item->key != key) {
142         TRACE("l3", "ll_remove: remove failed, an item with a matching key does not exist in the list", 0, 0);
143         return DOES_NOT_EXIST;
144     }
145
146     // Mark <item> removed. This must be atomic. If multiple threads try to remove the same item
147     // only one of them should succeed.
148     if (EXPECT_FALSE(IS_TAGGED(item->next))) {
149         TRACE("l3", "ll_remove: %p is already marked for removal by another thread", item, 0);
150         return DOES_NOT_EXIST;
151     }
152     node_t *next = SYNC_FETCH_AND_OR(&item->next, TAG);
153     if (EXPECT_FALSE(IS_TAGGED(next))) {
154         TRACE("l3", "ll_remove: lost race -- %p is already marked for removal by another thread", item, 0);
155         return DOES_NOT_EXIST;
156     }
157
158     uint64_t value = item->value;
159
160     // Unlink <item> from the list.
161     TRACE("l3", "ll_remove: link item's pred %p to it's successor %p", pred, next);
162     node_t *other;
163     if ((other = SYNC_CAS(&pred->next, item, next)) != item) {
164         TRACE("l3", "ll_remove: unlink failed; pred's link changed from %p to %p", item, other);
165         // By marking the item earlier, we logically removed it. It is safe to leave the item.
166         // Another thread will finish physically removing it from the list.
167         return value;
168     } 
169
170     // The thread that completes the unlink should free the memory.
171     nbd_defer_free(item); 
172     return value;
173 }
174
175 void ll_print (list_t *list) {
176     node_t *item;
177     item = list->head;
178     while (item) {
179         printf("0x%llx ", item->key);
180         fflush(stdout);
181         item = item->next;
182     }
183     printf("\n");
184 }
185
186 #ifdef MAKE_list_test
187 #include <errno.h>
188 #include <pthread.h>
189 #include <sys/time.h>
190
191 #include "runtime.h"
192
193 #define NUM_ITERATIONS 10000000
194
195 static volatile int wait_;
196 static long num_threads_;
197 static list_t *ll_;
198
199 void *worker (void *arg) {
200     int id = (int)(size_t)arg;
201
202     unsigned int rand_seed = id+1;//rdtsc_l();
203
204     // Wait for all the worker threads to be ready.
205     SYNC_ADD(&wait_, -1);
206     do {} while (wait_); 
207
208     for (int i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
209         int n = rand_r(&rand_seed);
210         int key = (n & 0xF) + 1;
211         if (n & (1 << 8)) {
212             ll_add(ll_, key, 1);
213         } else {
214             ll_remove(ll_, key);
215         }
216
217         rcu_update();
218     }
219
220     return NULL;
221 }
222
223 int main (int argc, char **argv) {
224     nbd_init();
225     //lwt_set_trace_level("m0l0");
226
227     char* program_name = argv[0];
228     pthread_t thread[MAX_NUM_THREADS];
229
230     if (argc > 2) {
231         fprintf(stderr, "Usage: %s num_threads\n", program_name);
232         return -1;
233     }
234
235     num_threads_ = 2;
236     if (argc == 2)
237     {
238         errno = 0;
239         num_threads_ = strtol(argv[1], NULL, 10);
240         if (errno) {
241             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
242             return -1;
243         }
244         if (num_threads_ <= 0) {
245             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
246             return -1;
247         }
248         if (num_threads_ > MAX_NUM_THREADS) {
249             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
250             return -1;
251         }
252     }
253
254     ll_ = ll_alloc();
255
256     struct timeval tv1, tv2;
257     gettimeofday(&tv1, NULL);
258
259     wait_ = num_threads_;
260
261     for (int i = 0; i < num_threads_; ++i) {
262         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
263         if (rc != 0) { perror("pthread_create"); return rc; }
264     }
265
266     for (int i = 0; i < num_threads_; ++i) {
267         pthread_join(thread[i], NULL);
268     }
269
270     gettimeofday(&tv2, NULL);
271     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
272     printf("Th:%ld Time:%dms\n", num_threads_, ms);
273     ll_print(ll_);
274     lwt_dump("lwt.out");
275
276     return 0;
277 }
278 #endif//list_test