whitespace -> use tabs
[satcheck.git] / snapshot.cc
1 /*      Copyright (c) 2015 Regents of the University of California
2  *
3  *      Author: Brian Demsky <bdemsky@uci.edu>
4  *
5  *      This program is free software; you can redistribute it and/or
6  *      modify it under the terms of the GNU General Public License
7  *      version 2 as published by the Free Software Foundation.
8  */
9
10 #include <inttypes.h>
11 #include <sys/mman.h>
12 #include <unistd.h>
13 #include <signal.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <errno.h>
17 #include <sys/wait.h>
18
19 #include "hashtable.h"
20 #include "snapshot.h"
21 #include "mymemory.h"
22 #include "common.h"
23 #include "context.h"
24
25 /** PageAlignedAdressUpdate return a page aligned address for the
26  * address being added as a side effect the numBytes are also changed.
27  */
28 static void * PageAlignAddressUpward(void *addr)
29 {
30         return (void *)((((uintptr_t)addr) + PAGESIZE - 1) & ~(PAGESIZE - 1));
31 }
32
33 #if USE_MPROTECT_SNAPSHOT
34
35 /* Each SnapShotRecord lists the firstbackingpage that must be written to
36  * revert to that snapshot */
37 struct SnapShotRecord {
38         unsigned int firstBackingPage;
39 };
40
41 /** @brief Backing store page */
42 typedef unsigned char snapshot_page_t[PAGESIZE];
43
44 /* List the base address of the corresponding page in the backing store so we
45  * know where to copy it to */
46 struct BackingPageRecord {
47         void *basePtrOfPage;
48 };
49
50 /* Struct for each memory region */
51 struct MemoryRegion {
52         void *basePtr; // base of memory region
53         int sizeInPages; // size of memory region in pages
54 };
55
56 /** ReturnPageAlignedAddress returns a page aligned address for the
57  * address being added as a side effect the numBytes are also changed.
58  */
59 static void * ReturnPageAlignedAddress(void *addr)
60 {
61         return (void *)(((uintptr_t)addr) & ~(PAGESIZE - 1));
62 }
63
64 /* Primary struct for snapshotting system */
65 struct mprot_snapshotter {
66         mprot_snapshotter(unsigned int numbackingpages, unsigned int numsnapshots, unsigned int nummemoryregions);
67         ~mprot_snapshotter();
68
69         struct MemoryRegion *regionsToSnapShot; //This pointer references an array of memory regions to snapshot
70         snapshot_page_t *backingStore; //This pointer references an array of snapshotpage's that form the backing store
71         void *backingStoreBasePtr; //This pointer references an array of snapshotpage's that form the backing store
72         struct BackingPageRecord *backingRecords; //This pointer references an array of backingpagerecord's (same number of elements as backingstore
73         struct SnapShotRecord *snapShots; //This pointer references the snapshot array
74
75         unsigned int lastSnapShot; //Stores the next snapshot record we should use
76         unsigned int lastBackingPage; //Stores the next backingpage we should use
77         unsigned int lastRegion; //Stores the next memory region to be used
78
79         unsigned int maxRegions; //Stores the max number of memory regions we support
80         unsigned int maxBackingPages; //Stores the total number of backing pages
81         unsigned int maxSnapShots; //Stores the total number of snapshots we allow
82
83         MEMALLOC;
84 };
85
86 static struct mprot_snapshotter *mprot_snap = NULL;
87
88 mprot_snapshotter::mprot_snapshotter(unsigned int backing_pages, unsigned int snapshots, unsigned int regions) :
89         lastSnapShot(0),
90         lastBackingPage(0),
91         lastRegion(0),
92         maxRegions(regions),
93         maxBackingPages(backing_pages),
94         maxSnapShots(snapshots)
95 {
96         regionsToSnapShot = (struct MemoryRegion *)model_malloc(sizeof(struct MemoryRegion) * regions);
97         backingStoreBasePtr = (void *)model_malloc(sizeof(snapshot_page_t) * (backing_pages + 1));
98         //Page align the backingstorepages
99         backingStore = (snapshot_page_t *)PageAlignAddressUpward(backingStoreBasePtr);
100         backingRecords = (struct BackingPageRecord *)model_malloc(sizeof(struct BackingPageRecord) * backing_pages);
101         snapShots = (struct SnapShotRecord *)model_malloc(sizeof(struct SnapShotRecord) * snapshots);
102 }
103
104 mprot_snapshotter::~mprot_snapshotter()
105 {
106         model_free(regionsToSnapShot);
107         model_free(backingStoreBasePtr);
108         model_free(backingRecords);
109         model_free(snapShots);
110 }
111
112 /** mprot_handle_pf is the page fault handler for mprotect based snapshotting
113  * algorithm.
114  */
115 static void mprot_handle_pf(int sig, siginfo_t *si, void *unused)
116 {
117         if (si->si_code == SEGV_MAPERR) {
118                 model_print("Segmentation fault at %p\n", si->si_addr);
119                 model_print("For debugging, place breakpoint at: %s:%d\n",
120                                 __FILE__, __LINE__);
121                 // print_trace(); // Trace printing may cause dynamic memory allocation
122                 exit(EXIT_FAILURE);
123         }
124         void* addr = ReturnPageAlignedAddress(si->si_addr);
125
126         unsigned int backingpage = mprot_snap->lastBackingPage++; //Could run out of pages...
127         if (backingpage == mprot_snap->maxBackingPages) {
128                 model_print("Out of backing pages at %p\n", si->si_addr);
129                 exit(EXIT_FAILURE);
130         }
131
132         //copy page
133         memcpy(&(mprot_snap->backingStore[backingpage]), addr, sizeof(snapshot_page_t));
134         //remember where to copy page back to
135         mprot_snap->backingRecords[backingpage].basePtrOfPage = addr;
136         //set protection to read/write
137         if (mprotect(addr, sizeof(snapshot_page_t), PROT_READ | PROT_WRITE)) {
138                 perror("mprotect");
139                 // Handle error by quitting?
140         }
141 }
142
143 static void mprot_snapshot_init(unsigned int numbackingpages,
144                 unsigned int numsnapshots, unsigned int nummemoryregions,
145                 unsigned int numheappages, VoidFuncPtr entryPoint)
146 {
147         /* Setup a stack for our signal handler....  */
148         stack_t ss;
149         ss.ss_sp = PageAlignAddressUpward(model_malloc(SIGSTACKSIZE + PAGESIZE - 1));
150         ss.ss_size = SIGSTACKSIZE;
151         ss.ss_flags = 0;
152         sigaltstack(&ss, NULL);
153
154         struct sigaction sa;
155         sa.sa_flags = SA_SIGINFO | SA_NODEFER | SA_RESTART | SA_ONSTACK;
156         sigemptyset(&sa.sa_mask);
157         sa.sa_sigaction = mprot_handle_pf;
158 #ifdef MAC
159         if (sigaction(SIGBUS, &sa, NULL) == -1) {
160                 perror("sigaction(SIGBUS)");
161                 exit(EXIT_FAILURE);
162         }
163 #endif
164         if (sigaction(SIGSEGV, &sa, NULL) == -1) {
165                 perror("sigaction(SIGSEGV)");
166                 exit(EXIT_FAILURE);
167         }
168
169         mprot_snap = new mprot_snapshotter(numbackingpages, numsnapshots, nummemoryregions);
170
171         // EVIL HACK: We need to make sure that calls into the mprot_handle_pf method don't cause dynamic links
172         // The problem is that we end up protecting state in the dynamic linker...
173         // Solution is to call our signal handler before we start protecting stuff...
174
175         siginfo_t si;
176         memset(&si, 0, sizeof(si));
177         si.si_addr = ss.ss_sp;
178         mprot_handle_pf(SIGSEGV, &si, NULL);
179         mprot_snap->lastBackingPage--; //remove the fake page we copied
180
181         
182         void *basemySpace = model_malloc((numheappages -32 + 1) * PAGESIZE);
183         void *pagealignedbase = PageAlignAddressUpward(basemySpace);
184         user_snapshot_space = pagealignedbase;
185         snapshot_add_memory_region(pagealignedbase, numheappages-32);
186         
187         void *basethreadSpace = model_malloc(33 * PAGESIZE);
188         pagealignedbase = PageAlignAddressUpward(basethreadSpace);
189         thread_snapshot_space = create_mspace_with_base(pagealignedbase, 32 * PAGESIZE, 1);
190         snapshot_add_memory_region(pagealignedbase, 32);
191
192         void *base_model_snapshot_space = model_malloc((numheappages + 1) * PAGESIZE);
193         pagealignedbase = PageAlignAddressUpward(base_model_snapshot_space);
194         model_snapshot_space = create_mspace_with_base(pagealignedbase, numheappages * PAGESIZE, 1);
195         snapshot_add_memory_region(pagealignedbase, numheappages);
196
197         snapshot_struct = (struct snapshot_heap_data *) model_malloc(sizeof(struct snapshot_heap_data));
198         snapshot_struct->allocation_ptr=user_snapshot_space;
199         snapshot_struct->top_ptr=(void *)(((char *)user_snapshot_space)+((numheappages-32)*PAGESIZE));
200         
201         entryPoint();
202 }
203
204 static void mprot_add_to_snapshot(void *addr, unsigned int numPages)
205 {
206         unsigned int memoryregion = mprot_snap->lastRegion++;
207         if (memoryregion == mprot_snap->maxRegions) {
208                 model_print("Exceeded supported number of memory regions!\n");
209                 exit(EXIT_FAILURE);
210         }
211
212         DEBUG("snapshot region %p-%p (%u page%s)\n",
213                         addr, (char *)addr + numPages * PAGESIZE, numPages,
214                         numPages > 1 ? "s" : "");
215         mprot_snap->regionsToSnapShot[memoryregion].basePtr = addr;
216         mprot_snap->regionsToSnapShot[memoryregion].sizeInPages = numPages;
217 }
218
219 static snapshot_id mprot_take_snapshot()
220 {
221         for (unsigned int region = 0; region < mprot_snap->lastRegion; region++) {
222                 if (mprotect(mprot_snap->regionsToSnapShot[region].basePtr, mprot_snap->regionsToSnapShot[region].sizeInPages * sizeof(snapshot_page_t), PROT_READ) == -1) {
223                         perror("mprotect");
224                         model_print("Failed to mprotect inside of takeSnapShot\n");
225                         exit(EXIT_FAILURE);
226                 }
227         }
228         unsigned int snapshot = mprot_snap->lastSnapShot++;
229         if (snapshot == mprot_snap->maxSnapShots) {
230                 model_print("Out of snapshots\n");
231                 exit(EXIT_FAILURE);
232         }
233         mprot_snap->snapShots[snapshot].firstBackingPage = mprot_snap->lastBackingPage;
234
235         return snapshot;
236 }
237
238 static void mprot_roll_back(snapshot_id theID)
239 {
240 #if USE_MPROTECT_SNAPSHOT == 2
241         if (mprot_snap->lastSnapShot == (theID + 1)) {
242                 for (unsigned int page = mprot_snap->snapShots[theID].firstBackingPage; page < mprot_snap->lastBackingPage; page++) {
243                         memcpy(mprot_snap->backingRecords[page].basePtrOfPage, &mprot_snap->backingStore[page], sizeof(snapshot_page_t));
244                 }
245                 return;
246         }
247 #endif
248
249         HashTable< void *, bool, uintptr_t, 4, model_malloc, model_calloc, model_free> duplicateMap;
250         for (unsigned int region = 0; region < mprot_snap->lastRegion; region++) {
251                 if (mprotect(mprot_snap->regionsToSnapShot[region].basePtr, mprot_snap->regionsToSnapShot[region].sizeInPages * sizeof(snapshot_page_t), PROT_READ | PROT_WRITE) == -1) {
252                         perror("mprotect");
253                         model_print("Failed to mprotect inside of takeSnapShot\n");
254                         exit(EXIT_FAILURE);
255                 }
256         }
257         for (unsigned int page = mprot_snap->snapShots[theID].firstBackingPage; page < mprot_snap->lastBackingPage; page++) {
258                 if (!duplicateMap.contains(mprot_snap->backingRecords[page].basePtrOfPage)) {
259                         duplicateMap.put(mprot_snap->backingRecords[page].basePtrOfPage, true);
260                         memcpy(mprot_snap->backingRecords[page].basePtrOfPage, &mprot_snap->backingStore[page], sizeof(snapshot_page_t));
261                 }
262         }
263         mprot_snap->lastSnapShot = theID;
264         mprot_snap->lastBackingPage = mprot_snap->snapShots[theID].firstBackingPage;
265         mprot_take_snapshot(); //Make sure current snapshot is still good...All later ones are cleared
266 }
267
268 #else /* !USE_MPROTECT_SNAPSHOT */
269
270 #define SHARED_MEMORY_DEFAULT  (100 * ((size_t)1 << 20)) // 100mb for the shared memory
271 #define STACK_SIZE_DEFAULT      (((size_t)1 << 20) * 20)  // 20 mb out of the above 100 mb for my stack
272
273 struct fork_snapshotter {
274         /** @brief Pointer to the shared (non-snapshot) memory heap base
275          * (NOTE: this has size SHARED_MEMORY_DEFAULT - sizeof(*fork_snap)) */
276         void *mSharedMemoryBase;
277
278         /** @brief Pointer to the shared (non-snapshot) stack region */
279         void *mStackBase;
280
281         /** @brief Size of the shared stack */
282         size_t mStackSize;
283
284         /**
285          * @brief Stores the ID that we are attempting to roll back to
286          *
287          * Used in inter-process communication so that each process can
288          * determine whether or not to take over execution (w/ matching ID) or
289          * exit (we're rolling back even further). Dubiously marked 'volatile'
290          * to prevent compiler optimizations from messing with the
291          * inter-process behavior.
292          */
293         volatile snapshot_id mIDToRollback;
294
295         /**
296          * @brief The context for the shared (non-snapshot) stack
297          *
298          * This context is passed between the various processes which represent
299          * various snapshot states. It should be used primarily for the
300          * "client-side" code, not the main snapshot loop.
301          */
302         ucontext_t shared_ctxt;
303
304         /** @brief Inter-process tracking of the next snapshot ID */
305         snapshot_id currSnapShotID;
306 };
307
308 static struct fork_snapshotter *fork_snap = NULL;
309
310 /** @statics
311 *   These variables are necessary because the stack is shared region and
312 *   there exists a race between all processes executing the same function.
313 *   To avoid the problem above, we require variables allocated in 'safe' regions.
314 *   The bug was actually observed with the forkID, these variables below are
315 *   used to indicate the various contexts to which to switch to.
316 *
317 *   @private_ctxt: the context which is internal to the current process. Used
318 *   for running the internal snapshot/rollback loop.
319 *   @exit_ctxt: a special context used just for exiting from a process (so we
320 *   can use swapcontext() instead of setcontext() + hacks)
321 *   @snapshotid: it is a running counter for the various forked processes
322 *   snapshotid. it is incremented and set in a persistently shared record
323 */
324 static ucontext_t private_ctxt;
325 static ucontext_t exit_ctxt;
326 static snapshot_id snapshotid = 0;
327
328 /**
329  * @brief Create a new context, with a given stack and entry function
330  * @param ctxt The context structure to fill
331  * @param stack The stack to run the new context in
332  * @param stacksize The size of the stack
333  * @param func The entry point function for the context
334  */
335 static void create_context(ucontext_t *ctxt, void *stack, size_t stacksize,
336                 void (*func)(void))
337 {
338         getcontext(ctxt);
339         ctxt->uc_stack.ss_sp = stack;
340         ctxt->uc_stack.ss_size = stacksize;
341         makecontext(ctxt, func, 0);
342 }
343
344 /** @brief An empty function, used for an "empty" context which just exits a
345  *  process */
346 static void fork_exit()
347 {
348         /* Intentionally empty */
349 }
350
351 static void createSharedMemory()
352 {
353         //step 1. create shared memory.
354         void *memMapBase = mmap(0, SHARED_MEMORY_DEFAULT + STACK_SIZE_DEFAULT, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANON, -1, 0);
355         if (memMapBase == MAP_FAILED) {
356                 perror("mmap");
357                 exit(EXIT_FAILURE);
358         }
359
360         //Setup snapshot record at top of free region
361         fork_snap = (struct fork_snapshotter *)memMapBase;
362         fork_snap->mSharedMemoryBase = (void *)((uintptr_t)memMapBase + sizeof(*fork_snap));
363         fork_snap->mStackBase = (void *)((uintptr_t)memMapBase + SHARED_MEMORY_DEFAULT);
364         fork_snap->mStackSize = STACK_SIZE_DEFAULT;
365         fork_snap->mIDToRollback = -1;
366         fork_snap->currSnapShotID = 0;
367 }
368
369 /**
370  * Create a new mspace pointer for the non-snapshotting (i.e., inter-process
371  * shared) memory region. Only for fork-based snapshotting.
372  *
373  * @return The shared memory mspace
374  */
375 mspace create_shared_mspace()
376 {
377         if (!fork_snap)
378                 createSharedMemory();
379         return create_mspace_with_base((void *)(fork_snap->mSharedMemoryBase), SHARED_MEMORY_DEFAULT - sizeof(*fork_snap), 1);
380 }
381
382 static void fork_snapshot_init(unsigned int numbackingpages,
383                 unsigned int numsnapshots, unsigned int nummemoryregions,
384                 unsigned int numheappages, VoidFuncPtr entryPoint)
385 {
386         if (!fork_snap)
387                 createSharedMemory();
388
389         void *base_model_snapshot_space = malloc((numheappages + 1) * PAGESIZE);
390         void *pagealignedbase = PageAlignAddressUpward(base_model_snapshot_space);
391         model_snapshot_space = create_mspace_with_base(pagealignedbase, numheappages * PAGESIZE, 1);
392
393         /* setup an "exiting" context */
394         char stack[128];
395         create_context(&exit_ctxt, stack, sizeof(stack), fork_exit);
396
397         /* setup the shared-stack context */
398         create_context(&fork_snap->shared_ctxt, fork_snap->mStackBase,
399                         STACK_SIZE_DEFAULT, entryPoint);
400         /* switch to a new entryPoint context, on a new stack */
401         model_swapcontext(&private_ctxt, &fork_snap->shared_ctxt);
402
403         /* switch back here when takesnapshot is called */
404         snapshotid = fork_snap->currSnapShotID;
405
406         while (true) {
407                 pid_t forkedID;
408                 fork_snap->currSnapShotID = snapshotid + 1;
409                 forkedID = fork();
410
411                 if (0 == forkedID) {
412                         setcontext(&fork_snap->shared_ctxt);
413                 } else {
414                         DEBUG("parent PID: %d, child PID: %d, snapshot ID: %d\n",
415                                 getpid(), forkedID, snapshotid);
416
417                         while (waitpid(forkedID, NULL, 0) < 0) {
418                                 /* waitpid() may be interrupted */
419                                 if (errno != EINTR) {
420                                         perror("waitpid");
421                                         exit(EXIT_FAILURE);
422                                 }
423                         }
424
425                         if (fork_snap->mIDToRollback != snapshotid)
426                                 exit(EXIT_SUCCESS);
427                 }
428         }
429 }
430
431 static snapshot_id fork_take_snapshot()
432 {
433         model_swapcontext(&fork_snap->shared_ctxt, &private_ctxt);
434         DEBUG("TAKESNAPSHOT RETURN\n");
435         return snapshotid;
436 }
437
438 static void fork_roll_back(snapshot_id theID)
439 {
440         DEBUG("Rollback\n");
441         fork_snap->mIDToRollback = theID;
442         model_swapcontext(&fork_snap->shared_ctxt, &exit_ctxt);
443         fork_snap->mIDToRollback = -1;
444 }
445
446 #endif /* !USE_MPROTECT_SNAPSHOT */
447
448 /**
449  * @brief Initializes the snapshot system
450  * @param entryPoint the function that should run the program.
451  */
452 void snapshot_system_init(unsigned int numbackingpages,
453                 unsigned int numsnapshots, unsigned int nummemoryregions,
454                 unsigned int numheappages, VoidFuncPtr entryPoint)
455 {
456 #if USE_MPROTECT_SNAPSHOT
457         mprot_snapshot_init(numbackingpages, numsnapshots, nummemoryregions, numheappages, entryPoint);
458 #else
459         fork_snapshot_init(numbackingpages, numsnapshots, nummemoryregions, numheappages, entryPoint);
460 #endif
461 }
462
463 /** Assumes that addr is page aligned. */
464 void snapshot_add_memory_region(void *addr, unsigned int numPages)
465 {
466 #if USE_MPROTECT_SNAPSHOT
467         mprot_add_to_snapshot(addr, numPages);
468 #else
469         /* not needed for fork-based snapshotting */
470 #endif
471 }
472
473 /** Takes a snapshot of memory.
474  * @return The snapshot identifier.
475  */
476 snapshot_id take_snapshot()
477 {
478 #if USE_MPROTECT_SNAPSHOT
479         return mprot_take_snapshot();
480 #else
481         return fork_take_snapshot();
482 #endif
483 }
484
485 /** Rolls the memory state back to the given snapshot identifier.
486  *  @param theID is the snapshot identifier to rollback to.
487  */
488 void snapshot_roll_back(snapshot_id theID)
489 {
490 #if USE_MPROTECT_SNAPSHOT
491         mprot_roll_back(theID);
492 #else
493         fork_roll_back(theID);
494 #endif
495 }