]> pd.if.org Git - nbds/blob - struct/list.c
tweak list
[nbds] / struct / list.c
1 /* 
2  * Written by Josh Dybnis and released to the public domain, as explained at
3  * http://creativecommons.org/licenses/publicdomain
4  *
5  * Harris-Michael lock-free list-based set
6  * http://www.research.ibm.com/people/m/michael/spaa-2002.pdf
7  */
8 #include <stdio.h>
9 #include <string.h>
10
11 #include "common.h"
12 #include "struct.h"
13 #include "mem.h"
14
15 typedef struct node {
16     uint64_t key;
17     uint64_t value;
18     struct node *next;
19 } node_t;
20
21 typedef struct list {
22     node_t *head;
23     node_t *last;
24 } list_t;
25
26 node_t *node_alloc (uint64_t key, uint64_t value) {
27     node_t *item = (node_t *)nbd_malloc(sizeof(node_t));
28     memset(item, 0, sizeof(node_t));
29     item->key   = key;
30     item->value = value;
31     return item;
32 }
33
34 list_t *list_alloc (void) {
35     list_t *list = (list_t *)nbd_malloc(sizeof(list_t));
36     list->head = node_alloc(0, 0);
37     list->last = node_alloc((uint64_t)-1, 0);
38     list->head->next = list->last;
39     return list;
40 }
41
42 static node_t *find_pred (node_t **pred_ptr, list_t *list, uint64_t key, int help_remove) {
43     node_t *pred = list->head;
44     node_t *item = pred->next;
45     TRACE("l3", "find_pred: searching for key %p in list (head is %p)", key, pred);
46 #ifndef NDEBUG
47     int count = 0;
48 #endif
49
50     do {
51         node_t *next = item->next;
52         TRACE("l3", "find_pred: visiting item %p (next %p)", item, next);
53         TRACE("l3", "find_pred: key %p value %p", item->key, item->value);
54
55         // Marked items are partially removed.
56         while (EXPECT_FALSE(IS_TAGGED(next))) {
57
58             // Skip over partially removed items.
59             if (!help_remove) {
60                 item = (node_t *)STRIP_TAG(item->next);
61                 next = item->next;
62                 continue;
63             }
64
65             // Unlink partially removed items.
66             node_t *other;
67             if ((other = SYNC_CAS(&pred->next, item, STRIP_TAG(next))) == item) {
68                 item = (node_t *)STRIP_TAG(next);
69                 next = item->next;
70                 TRACE("l3", "find_pred: unlinked item; %p is the new item (next is %p)", item, next);
71                 nbd_defer_free(other);
72             } else {
73                 TRACE("l3", "find_pred: lost race to unlink item from pred %p; its link changed to %p", pred, other);
74                 if (IS_TAGGED(other))
75                     return find_pred(pred_ptr, list, key, help_remove); // retry
76                 item = other;
77                 next = item->next;
78             }
79         }
80
81         // If we reached the key (or passed where it should be), we found the right predesssor
82         if (item->key >= key) {
83             TRACE("l3", "find_pred: returning pred %p and item %p", pred, item);
84             if (pred_ptr != NULL) {
85                 *pred_ptr = pred;
86             }
87             return item;
88         }
89
90         assert(count++ < 18);
91         pred = item;
92         item = next;
93
94     } while (1);
95 }
96
97 // Fast find. Do not help unlink partially removed nodes and do not return the found item's predecessor.
98 uint64_t list_lookup (list_t *list, uint64_t key) {
99     TRACE("l3", "list_lookup: searching for key %p in list %p", key, list);
100     node_t *item = find_pred(NULL, list, key, FALSE);
101
102     // If we found an <item> matching the <key> return its value.
103     return (item->key == key) ? item->value : DOES_NOT_EXIST;
104 }
105
106 // Insert the <key>, if it doesn't already exist in the <list>
107 uint64_t list_add (list_t *list, uint64_t key, uint64_t value) {
108     TRACE("l3", "list_add: inserting key %p value %p", key, value);
109     node_t *pred;
110     node_t *item = NULL;
111     do {
112         node_t *next = find_pred(&pred, list, key, TRUE);
113
114         // If a node matching <key> already exists in the list, return its value.
115         if (next->key == key) {
116             TRACE("l3", "list_add: there is already an item %p (value %p) with the same key", next, next->value);
117             if (EXPECT_FALSE(item != NULL)) { nbd_free(item); }
118             return next->value;
119         }
120
121         TRACE("l3", "list_add: attempting to insert item between %p and %p", pred, next);
122         if (EXPECT_TRUE(item == NULL)) { item = node_alloc(key, value); }
123         item->next = next;
124         node_t *other = SYNC_CAS(&pred->next, next, item);
125         if (other == next) {
126             TRACE("l3", "list_add: insert was successful", 0, 0);
127             return DOES_NOT_EXIST; // success
128         }
129         TRACE("l3", "list_add: failed to change pred's link: expected %p found %p", next, other);
130
131     } while (1);
132 }
133
134 uint64_t list_remove (list_t *list, uint64_t key) {
135     TRACE("l3", "list_remove: removing item with key %p from list %p", key, list);
136     node_t *pred;
137     node_t *item = find_pred(&pred, list, key, TRUE);
138     if (item->key != key) {
139         TRACE("l3", "list_remove: remove failed, an item with a matching key does not exist in the list", 0, 0);
140         return DOES_NOT_EXIST;
141     }
142
143     // Mark <item> removed. This must be atomic. If multiple threads try to remove the same item
144     // only one of them should succeed.
145     if (EXPECT_FALSE(IS_TAGGED(item->next))) {
146         TRACE("l3", "list_remove: %p is already marked for removal by another thread", item, 0);
147         return DOES_NOT_EXIST;
148     }
149     node_t *next = SYNC_FETCH_AND_OR(&item->next, TAG);
150     if (EXPECT_FALSE(IS_TAGGED(next))) {
151         TRACE("l3", "list_remove: lost race -- %p is already marked for removal by another thread", item, 0);
152         return DOES_NOT_EXIST;
153     }
154
155     uint64_t value = item->value;
156
157     // Unlink <item> from the list.
158     TRACE("l3", "list_remove: link item's pred %p to it's successor %p", pred, next);
159     node_t *other;
160     if ((other = SYNC_CAS(&pred->next, item, next)) != item) {
161         TRACE("l3", "list_remove: unlink failed; pred's link changed from %p to %p", item, other);
162         // By being marked, the item was logically removed. It is safe to leave it for
163         // another thread to finish physically removing it from the skiplist.
164         return value;
165     } 
166
167     // The thread that completes the unlink should free the memory.
168     nbd_defer_free(item); 
169     return value;
170 }
171
172 void list_print (list_t *list) {
173     node_t *item;
174     item = list->head;
175     while (item) {
176         printf("0x%llx ", item->key);
177         fflush(stdout);
178         item = item->next;
179     }
180     printf("\n");
181 }
182
183 #ifdef MAKE_list_test
184 #include <errno.h>
185 #include <pthread.h>
186 #include <sys/time.h>
187
188 #include "runtime.h"
189
190 #define NUM_ITERATIONS 10000000
191
192 static volatile int wait_;
193 static long num_threads_;
194 static list_t *list_;
195
196 void *worker (void *arg) {
197     int id = (int)(size_t)arg;
198
199     unsigned int rand_seed = id+1;//rdtsc_l();
200
201     // Wait for all the worker threads to be ready.
202     SYNC_ADD(&wait_, -1);
203     do {} while (wait_); 
204     __asm__ __volatile__("lfence"); 
205
206     for (int i = 0; i < NUM_ITERATIONS/num_threads_; ++i) {
207         int n = rand_r(&rand_seed);
208         int key = (n & 0xF) + 1;
209         if (n & (1 << 8)) {
210             list_add(list_, key, 1);
211         } else {
212             list_remove(list_, key);
213         }
214
215         rcu_update();
216     }
217
218     return NULL;
219 }
220
221 int main (int argc, char **argv) {
222     nbd_init();
223     //lwt_set_trace_level("m0l0");
224
225     char* program_name = argv[0];
226     pthread_t thread[MAX_NUM_THREADS];
227
228     if (argc > 2) {
229         fprintf(stderr, "Usage: %s num_threads\n", program_name);
230         return -1;
231     }
232
233     num_threads_ = 2;
234     if (argc == 2)
235     {
236         errno = 0;
237         num_threads_ = strtol(argv[1], NULL, 10);
238         if (errno) {
239             fprintf(stderr, "%s: Invalid argument for number of threads\n", program_name);
240             return -1;
241         }
242         if (num_threads_ <= 0) {
243             fprintf(stderr, "%s: Number of threads must be at least 1\n", program_name);
244             return -1;
245         }
246         if (num_threads_ > MAX_NUM_THREADS) {
247             fprintf(stderr, "%s: Number of threads cannot be more than %d\n", program_name, MAX_NUM_THREADS);
248             return -1;
249         }
250     }
251
252     list_ = list_alloc();
253
254     struct timeval tv1, tv2;
255     gettimeofday(&tv1, NULL);
256
257     __asm__ __volatile__("sfence"); 
258     wait_ = num_threads_;
259
260     for (int i = 0; i < num_threads_; ++i) {
261         int rc = nbd_thread_create(thread + i, i, worker, (void*)(size_t)i);
262         if (rc != 0) { perror("pthread_create"); return rc; }
263     }
264
265     for (int i = 0; i < num_threads_; ++i) {
266         pthread_join(thread[i], NULL);
267     }
268
269     gettimeofday(&tv2, NULL);
270     int ms = (int)(1000000*(tv2.tv_sec - tv1.tv_sec) + tv2.tv_usec - tv1.tv_usec) / 1000;
271     printf("Th:%ld Time:%dms\n", num_threads_, ms);
272     list_print(list_);
273     lwt_dump("lwt.out");
274
275     return 0;
276 }
277 #endif//list_test