]> pd.if.org Git - nbds/blob - struct/list.c
28c15aa9da2f4e9643b0c1b89a421f4208faf9e6
[nbds] / struct / list.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Harris-Michael lock-free list-based set
6  * http://www.research.ibm.com/people/m/michael/spaa-2002.pdf
7  */
8 #include <stdio.h>
9 #include <string.h>
10 #include <sys/time.h>
11
12 #include "common.h"
13 #include "lwt.h"
14 #include "rcu.h"
15 #include "mem.h"
16
17 #define NUM_ITERATIONS 10000000
18
19 #define PLACE_MARK(x) (((size_t)(x))|1)
20 #define CLEAR_MARK(x) (((size_t)(x))&~(size_t)1)
21 #define IS_MARKED(x)  ((size_t)(x))&1
22
23 typedef struct node {
24     struct node *next;
25     int key;
26 } node_t;
27
28 typedef struct list {
29     node_t head[1];
30     node_t last;
31 } list_t;
32
33 static void list_node_init (node_t *item, int key)
34 {
35     memset(item, 0, sizeof(node_t));
36     item->key = key;
37 }
38
39 node_t *list_node_alloc (int key)
40 {
41     node_t *item = (node_t *)nbd_malloc(sizeof(node_t));
42     list_node_init(item, key);
43     return item;
44 }
45
46 list_t *list_alloc (void)
47 {
48     list_t *list = (list_t *)nbd_malloc(sizeof(list_t));
49     list_node_init(list->head, INT_MIN);
50     list_node_init(&list->last, INT_MAX);
51     list->head->next = &list->last;
52     return list;
53 }
54
55 static void find_pred_and_item (node_t **pred_ptr, node_t **item_ptr, list_t *list, int key)
56 {
57     node_t *pred = list->head;
58     node_t *item = list->head->next; // head is never removed
59     TRACE("l3", "find_pred_and_item: searching for key %llu in list (head is %p)", key, pred);
60 #ifndef NDEBUG
61     int count = 0;
62 #endif
63     do {
64         // skip removed items
65         node_t *other, *next = item->next;
66         TRACE("l3", "find_pred_and_item: visiting item %p (next is %p)", item, next);
67         while (EXPECT_FALSE(IS_MARKED(next))) {
68             
69             // assist in unlinking partially removed items
70             if ((other = SYNC_CAS(&pred->next, item, CLEAR_MARK(next))) != item)
71             {
72                 TRACE("l3", "find_pred_and_item: failed to unlink item from pred %p, pred's next pointer was changed to %p", pred, other);
73                 return find_pred_and_item(pred_ptr, item_ptr, list, key); // retry
74             }
75
76             assert(count++ < 18);
77             item = (node_t *)CLEAR_MARK(next);
78             next = item->next;
79             TRACE("l3", "find_pred_and_item: unlinked item, %p is the new item (next is %p)", item, next);
80         }
81
82         if (item->key >= key) {
83             *pred_ptr = pred;
84             *item_ptr = item;
85             TRACE("l3", "find_pred_and_item: key found, returning pred %p and item %p", pred, item);
86             return;
87         }
88
89         assert(count++ < 18);
90         pred = item;
91         item = next;
92
93     } while (1);
94 }
95
96 int list_insert (list_t *list, node_t *item)
97 {
98     TRACE("l3", "list_insert: inserting %p (with key %llu)", item, item->key);
99     node_t *pred, *next, *other = (node_t *)-1;
100     do {
101         if (other != (node_t *)-1) {
102             TRACE("l3", "list_insert: failed to swap item into list; pred's next was changed to %p", other, 0);
103         }
104         find_pred_and_item(&pred, &next, list, item->key);
105
106         // fail if item already exists in list
107         if (next->key == item->key)
108         {
109             TRACE("l3", "list_insert: insert failed item with key already exists %p", next, 0);
110             return 0;
111         }
112
113         item->next = next;
114         TRACE("l3", "list_insert: attempting to insert item between %p and %p", pred, next);
115
116     } while ((other = __sync_val_compare_and_swap(&pred->next, next, item)) != next);
117
118     TRACE("l3", "list_insert: insert was successful", 0, 0);
119
120     // success
121     return 1;
122 }
123
124 node_t *list_remove (list_t *list, int key)
125 {
126     node_t *pred, *item, *next;
127
128     TRACE("l3", "list_remove: removing item with key %llu", key, 0);
129     find_pred_and_item(&pred, &item, list, key);
130     if (item->key != key)
131     {
132         TRACE("l3", "list_remove: remove failed, key does not exist in list", 0, 0);
133         return NULL;
134     }
135
136     // Mark <item> removed, must be atomic. If multiple threads try to remove the 
137     // same item only one of them should succeed
138     next = item->next;
139     node_t *other = (node_t *)-1;
140     if (IS_MARKED(next) || (other = __sync_val_compare_and_swap(&item->next, next, PLACE_MARK(next))) != next) {
141         if (other == (node_t *)-1) {
142             TRACE("l3", "list_remove: retry; %p is already marked for removal (it's next pointer is %p)", item, next);
143         } else {
144             TRACE("l3", "list_remove: retry; failed to mark %p for removal; it's next pointer was %p, but changed to %p", next, other);
145         }
146         return list_remove(list, key); // retry
147     }
148
149     // Remove <item> from list
150     TRACE("l3", "list_remove: link item's pred %p to it's successor %p", pred, next);
151     if ((other = __sync_val_compare_and_swap(&pred->next, item, next)) != item) {
152         TRACE("l3", "list_remove: link failed; pred's link changed from %p to %p", item, other);
153
154         // make sure item gets unlinked before returning it
155         node_t *d1, *d2;
156         find_pred_and_item(&d1, &d2, list, key);
157     } else {
158         TRACE("l3", "list_remove: link succeeded; pred's link changed from %p to %p", item, next);
159     }
160
161     return item;
162 }
163
164 void list_print (list_t *list)
165 {
166     node_t *item;
167     item = list->head;
168     while (item) {
169         printf("%d ", item->key);
170         fflush(stdout);
171         item = item->next;
172     }
173     printf("\n");
174 }
175
176 #ifdef MAKE_list_test
177 #include <errno.h>
178 #include <pthread.h>
179 #include "nbd.h"
180
181 static volatile int wait_;
182 static long num_threads_;
183 static list_t *list_;
184
185 void *worker (void *arg)
186 {
187     int id = (int)(size_t)arg;
188     nbd_thread_init(id);
189
190     unsigned int rand_seed = id+1;//rdtsc_l();
191
192     // Wait for all the worker threads to be ready.
193     __sync_fetch_and_add(&wait_, -1);
194     do {} while (wait_); 
195     __asm__ __volatile__("lfence"); 
196
197     int i;
198     for (i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
199         int n = rand_r(&rand_seed);
200         int key = (n & 0xF) + 1;
201         if (n & (1 << 8)) {
202             node_t *item = list_node_alloc(key);
203             int success = list_insert(list_, item);
204             if (!success) {
205                 nbd_free(item); 
206             }
207         } else {
208             node_t *item = list_remove(list_, key);
209             if (item) {
210                 nbd_defer_free(item);
211             }
212         }
213
214         rcu_update();
215     }
216
217     return NULL;
218 }
219
220 int main (int argc, char **argv)
221 {
222     nbd_init();
223     //lwt_set_trace_level("m0l0");
224
225     char* program_name = argv[0];
226     pthread_t thread[MAX_NUM_THREADS];
227
228     if (argc > 2) {
229         fprintf(stderr, "Usage: %s num_threads\n", program_name);
230         return -1;
231     }
232
233     num_threads_ = 2;
234     if (argc == 2)
235     {
236         errno = 0;
237         num_threads_ = strtol(argv[1], NULL, 10);
238         if (errno) {
239             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
240             return -1;
241         }
242         if (num_threads_ <= 0) {
243             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
244             return -1;
245         }
246         if (num_threads_ > MAX_NUM_THREADS) {
247             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
248             return -1;
249         }
250     }
251
252     list_ = list_alloc();
253
254     struct timeval tv1, tv2;
255     gettimeofday(&tv1, NULL);
256
257     __asm__ __volatile__("sfence"); 
258     wait_ = num_threads_;
259
260     int i;
261     for (i = 0; i < num_threads_; ++i) {
262         int rc = pthread_create(thread + i, NULL, worker, (void*)(size_t)i);
263         if (rc != 0) { perror("pthread_create"); return rc; }
264     }
265
266     for (i = 0; i < num_threads_; ++i) {
267         pthread_join(thread[i], NULL);
268     }
269
270     gettimeofday(&tv2, NULL);
271     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
272     printf("Th:%ld Time:%dms\n", num_threads_, ms);
273     list_print(list_);
274     lwt_dump("lwt.out");
275
276     return 0;
277 }
278 #endif//list_test