]> pd.if.org Git - nbds/blob - struct/list.c
2fad6ee340bff01ca7c7b4c65584d7551a619b92
[nbds] / struct / list.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Harris-Michael lock-free list-based set
6  * http://www.research.ibm.com/people/m/michael/spaa-2002.pdf
7  */
8 #include <stdio.h>
9 #include <string.h>
10 #include <sys/time.h>
11
12 #include "common.h"
13 #include "lwt.h"
14 #include "mem.h"
15
16 #define NUM_ITERATIONS 10000000
17
18 #define PLACE_MARK(x) (((size_t)(x))|1)
19 #define CLEAR_MARK(x) (((size_t)(x))&~(size_t)1)
20 #define IS_MARKED(x)  ((size_t)(x))&1
21
22 typedef struct node {
23     struct node *next;
24     int key;
25 } node_t;
26
27 typedef struct list {
28     node_t head[1];
29     node_t last;
30 } list_t;
31
32 static void list_node_init (node_t *item, int key) {
33     memset(item, 0, sizeof(node_t));
34     item->key = key;
35 }
36
37 node_t *list_node_alloc (int key) {
38     node_t *item = (node_t *)nbd_malloc(sizeof(node_t));
39     list_node_init(item, key);
40     return item;
41 }
42
43 list_t *list_alloc (void) {
44     list_t *list = (list_t *)nbd_malloc(sizeof(list_t));
45     list_node_init(list->head, INT_MIN);
46     list_node_init(&list->last, INT_MAX);
47     list->head->next = &list->last;
48     return list;
49 }
50
51 static void find_pred_and_item (node_t **pred_ptr, node_t **item_ptr, list_t *list, int key) {
52     node_t *pred = list->head;
53     node_t *item = list->head->next; // head is never removed
54     TRACE("l3", "find_pred_and_item: searching for key %llu in list (head is %p)", key, pred);
55 #ifndef NDEBUG
56     int count = 0;
57 #endif
58     do {
59         // skip removed items
60         node_t *other, *next = item->next;
61         TRACE("l3", "find_pred_and_item: visiting item %p (next is %p)", item, next);
62         while (EXPECT_FALSE(IS_MARKED(next))) {
63             
64             // assist in unlinking partially removed items
65             if ((other = SYNC_CAS(&pred->next, item, CLEAR_MARK(next))) != item)
66             {
67                 TRACE("l3", "find_pred_and_item: failed to unlink item from pred %p, pred's next pointer was changed to %p", pred, other);
68                 return find_pred_and_item(pred_ptr, item_ptr, list, key); // retry
69             }
70
71             assert(count++ < 18);
72             item = (node_t *)CLEAR_MARK(next);
73             next = item->next;
74             TRACE("l3", "find_pred_and_item: unlinked item, %p is the new item (next is %p)", item, next);
75         }
76
77         if (item->key >= key) {
78             *pred_ptr = pred;
79             *item_ptr = item;
80             TRACE("l3", "find_pred_and_item: key found, returning pred %p and item %p", pred, item);
81             return;
82         }
83
84         assert(count++ < 18);
85         pred = item;
86         item = next;
87
88     } while (1);
89 }
90
91 int list_insert (list_t *list, node_t *item) {
92     TRACE("l3", "list_insert: inserting %p (with key %llu)", item, item->key);
93     node_t *pred, *next, *other = (node_t *)-1;
94     do {
95         if (other != (node_t *)-1) {
96             TRACE("l3", "list_insert: failed to swap item into list; pred's next was changed to %p", other, 0);
97         }
98         find_pred_and_item(&pred, &next, list, item->key);
99
100         // fail if item already exists in list
101         if (next->key == item->key)
102         {
103             TRACE("l3", "list_insert: insert failed item with key already exists %p", next, 0);
104             return 0;
105         }
106
107         item->next = next;
108         TRACE("l3", "list_insert: attempting to insert item between %p and %p", pred, next);
109
110     } while ((other = __sync_val_compare_and_swap(&pred->next, next, item)) != next);
111
112     TRACE("l3", "list_insert: insert was successful", 0, 0);
113
114     // success
115     return 1;
116 }
117
118 node_t *list_remove (list_t *list, int key) {
119     node_t *pred, *item, *next;
120
121     TRACE("l3", "list_remove: removing item with key %llu", key, 0);
122     find_pred_and_item(&pred, &item, list, key);
123     if (item->key != key)
124     {
125         TRACE("l3", "list_remove: remove failed, key does not exist in list", 0, 0);
126         return NULL;
127     }
128
129     // Mark <item> removed, must be atomic. If multiple threads try to remove the 
130     // same item only one of them should succeed
131     next = item->next;
132     node_t *other = (node_t *)-1;
133     if (IS_MARKED(next) || (other = __sync_val_compare_and_swap(&item->next, next, PLACE_MARK(next))) != next) {
134         if (other == (node_t *)-1) {
135             TRACE("l3", "list_remove: retry; %p is already marked for removal (it's next pointer is %p)", item, next);
136         } else {
137             TRACE("l3", "list_remove: retry; failed to mark %p for removal; it's next pointer was %p, but changed to %p", next, other);
138         }
139         return list_remove(list, key); // retry
140     }
141
142     // Remove <item> from list
143     TRACE("l3", "list_remove: link item's pred %p to it's successor %p", pred, next);
144     if ((other = __sync_val_compare_and_swap(&pred->next, item, next)) != item) {
145         TRACE("l3", "list_remove: link failed; pred's link changed from %p to %p", item, other);
146
147         // make sure item gets unlinked before returning it
148         node_t *d1, *d2;
149         find_pred_and_item(&d1, &d2, list, key);
150     } else {
151         TRACE("l3", "list_remove: link succeeded; pred's link changed from %p to %p", item, next);
152     }
153
154     return item;
155 }
156
157 void list_print (list_t *list) {
158     node_t *item;
159     item = list->head;
160     while (item) {
161         printf("%d ", item->key);
162         fflush(stdout);
163         item = item->next;
164     }
165     printf("\n");
166 }
167
168 #ifdef MAKE_list_test
169 #include <errno.h>
170 #include <pthread.h>
171 #include "runtime.h"
172
173 static volatile int wait_;
174 static long num_threads_;
175 static list_t *list_;
176
177 void *worker (void *arg) {
178     int id = (int)(size_t)arg;
179
180     unsigned int rand_seed = id+1;//rdtsc_l();
181
182     // Wait for all the worker threads to be ready.
183     __sync_fetch_and_add(&wait_, -1);
184     do {} while (wait_); 
185     __asm__ __volatile__("lfence"); 
186
187     int i;
188     for (i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
189         int n = rand_r(&rand_seed);
190         int key = (n & 0xF) + 1;
191         if (n & (1 << 8)) {
192             node_t *item = list_node_alloc(key);
193             int success = list_insert(list_, item);
194             if (!success) {
195                 nbd_free(item); 
196             }
197         } else {
198             node_t *item = list_remove(list_, key);
199             if (item) {
200                 nbd_defer_free(item);
201             }
202         }
203
204         rcu_update();
205     }
206
207     return NULL;
208 }
209
210 int main (int argc, char **argv) {
211     nbd_init();
212     //lwt_set_trace_level("m0l0");
213
214     char* program_name = argv[0];
215     pthread_t thread[MAX_NUM_THREADS];
216
217     if (argc > 2) {
218         fprintf(stderr, "Usage: %s num_threads\n", program_name);
219         return -1;
220     }
221
222     num_threads_ = 2;
223     if (argc == 2)
224     {
225         errno = 0;
226         num_threads_ = strtol(argv[1], NULL, 10);
227         if (errno) {
228             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
229             return -1;
230         }
231         if (num_threads_ <= 0) {
232             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
233             return -1;
234         }
235         if (num_threads_ > MAX_NUM_THREADS) {
236             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
237             return -1;
238         }
239     }
240
241     list_ = list_alloc();
242
243     struct timeval tv1, tv2;
244     gettimeofday(&tv1, NULL);
245
246     __asm__ __volatile__("sfence"); 
247     wait_ = num_threads_;
248
249     int i;
250     for (i = 0; i < num_threads_; ++i) {
251         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
252         if (rc != 0) { perror("pthread_create"); return rc; }
253     }
254
255     for (i = 0; i < num_threads_; ++i) {
256         pthread_join(thread[i], NULL);
257     }
258
259     gettimeofday(&tv2, NULL);
260     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
261     printf("Th:%ld Time:%dms\n", num_threads_, ms);
262     list_print(list_);
263     lwt_dump("lwt.out");
264
265     return 0;
266 }
267 #endif//list_test