Remove unused codes
[c11tester.git] / snapshot.cc
1 #include <inttypes.h>
2 #include <sys/mman.h>
3 #include <unistd.h>
4 #include <signal.h>
5 #include <stdlib.h>
6 #include <string.h>
7 #include <errno.h>
8 #include <sys/wait.h>
9
10 #include "hashtable.h"
11 #include "snapshot.h"
12 #include "mymemory.h"
13 #include "common.h"
14 #include "context.h"
15 #include "model.h"
16
17
18 #if USE_MPROTECT_SNAPSHOT
19
20 /** PageAlignedAdressUpdate return a page aligned address for the
21  * address being added as a side effect the numBytes are also changed.
22  */
23 static void * PageAlignAddressUpward(void *addr)
24 {
25         return (void *)((((uintptr_t)addr) + PAGESIZE - 1) & ~(PAGESIZE - 1));
26 }
27
28 /* Each SnapShotRecord lists the firstbackingpage that must be written to
29  * revert to that snapshot */
30 struct SnapShotRecord {
31         unsigned int firstBackingPage;
32 };
33
34 /** @brief Backing store page */
35 typedef unsigned char snapshot_page_t[PAGESIZE];
36
37 /* List the base address of the corresponding page in the backing store so we
38  * know where to copy it to */
39 struct BackingPageRecord {
40         void *basePtrOfPage;
41 };
42
43 /* Struct for each memory region */
44 struct MemoryRegion {
45         void *basePtr;  // base of memory region
46         int sizeInPages;        // size of memory region in pages
47 };
48
49 /** ReturnPageAlignedAddress returns a page aligned address for the
50  * address being added as a side effect the numBytes are also changed.
51  */
52 static void * ReturnPageAlignedAddress(void *addr)
53 {
54         return (void *)(((uintptr_t)addr) & ~(PAGESIZE - 1));
55 }
56
57 /* Primary struct for snapshotting system */
58 struct mprot_snapshotter {
59         mprot_snapshotter(unsigned int numbackingpages, unsigned int numsnapshots, unsigned int nummemoryregions);
60         ~mprot_snapshotter();
61
62         struct MemoryRegion *regionsToSnapShot; //This pointer references an array of memory regions to snapshot
63         snapshot_page_t *backingStore;  //This pointer references an array of snapshotpage's that form the backing store
64         void *backingStoreBasePtr;      //This pointer references an array of snapshotpage's that form the backing store
65         struct BackingPageRecord *backingRecords;       //This pointer references an array of backingpagerecord's (same number of elements as backingstore
66         struct SnapShotRecord *snapShots;       //This pointer references the snapshot array
67
68         unsigned int lastSnapShot;      //Stores the next snapshot record we should use
69         unsigned int lastBackingPage;   //Stores the next backingpage we should use
70         unsigned int lastRegion;        //Stores the next memory region to be used
71
72         unsigned int maxRegions;        //Stores the max number of memory regions we support
73         unsigned int maxBackingPages;   //Stores the total number of backing pages
74         unsigned int maxSnapShots;      //Stores the total number of snapshots we allow
75
76         MEMALLOC
77 };
78
79 static struct mprot_snapshotter *mprot_snap = NULL;
80
81 mprot_snapshotter::mprot_snapshotter(unsigned int backing_pages, unsigned int snapshots, unsigned int regions) :
82         lastSnapShot(0),
83         lastBackingPage(0),
84         lastRegion(0),
85         maxRegions(regions),
86         maxBackingPages(backing_pages),
87         maxSnapShots(snapshots)
88 {
89         regionsToSnapShot = (struct MemoryRegion *)model_malloc(sizeof(struct MemoryRegion) * regions);
90         backingStoreBasePtr = (void *)model_malloc(sizeof(snapshot_page_t) * (backing_pages + 1));
91         //Page align the backingstorepages
92         backingStore = (snapshot_page_t *)PageAlignAddressUpward(backingStoreBasePtr);
93         backingRecords = (struct BackingPageRecord *)model_malloc(sizeof(struct BackingPageRecord) * backing_pages);
94         snapShots = (struct SnapShotRecord *)model_malloc(sizeof(struct SnapShotRecord) * snapshots);
95 }
96
97 mprot_snapshotter::~mprot_snapshotter()
98 {
99         model_free(regionsToSnapShot);
100         model_free(backingStoreBasePtr);
101         model_free(backingRecords);
102         model_free(snapShots);
103 }
104
105 /** mprot_handle_pf is the page fault handler for mprotect based snapshotting
106  * algorithm.
107  */
108 static void mprot_handle_pf(int sig, siginfo_t *si, void *unused)
109 {
110         if (si->si_code == SEGV_MAPERR) {
111                 model_print("Segmentation fault at %p\n", si->si_addr);
112                 model_print("For debugging, place breakpoint at: %s:%d\n",
113                                                                 __FILE__, __LINE__);
114                 // print_trace(); // Trace printing may cause dynamic memory allocation
115                 exit(EXIT_FAILURE);
116         }
117         void* addr = ReturnPageAlignedAddress(si->si_addr);
118
119         unsigned int backingpage = mprot_snap->lastBackingPage++;       //Could run out of pages...
120         if (backingpage == mprot_snap->maxBackingPages) {
121                 model_print("Out of backing pages at %p\n", si->si_addr);
122                 exit(EXIT_FAILURE);
123         }
124
125         //copy page
126         memcpy(&(mprot_snap->backingStore[backingpage]), addr, sizeof(snapshot_page_t));
127         //remember where to copy page back to
128         mprot_snap->backingRecords[backingpage].basePtrOfPage = addr;
129         //set protection to read/write
130         if (mprotect(addr, sizeof(snapshot_page_t), PROT_READ | PROT_WRITE)) {
131                 perror("mprotect");
132                 // Handle error by quitting?
133         }
134 }
135
136 static void mprot_snapshot_init(unsigned int numbackingpages,
137                                                                                                                                 unsigned int numsnapshots, unsigned int nummemoryregions,
138                                                                                                                                 unsigned int numheappages)
139 {
140         /* Setup a stack for our signal handler....  */
141         stack_t ss;
142         ss.ss_sp = PageAlignAddressUpward(model_malloc(SIGSTACKSIZE + PAGESIZE - 1));
143         ss.ss_size = SIGSTACKSIZE;
144         ss.ss_flags = 0;
145         sigaltstack(&ss, NULL);
146
147         struct sigaction sa;
148         sa.sa_flags = SA_SIGINFO | SA_NODEFER | SA_RESTART | SA_ONSTACK;
149         sigemptyset(&sa.sa_mask);
150         sa.sa_sigaction = mprot_handle_pf;
151 #ifdef MAC
152         if (sigaction(SIGBUS, &sa, NULL) == -1) {
153                 perror("sigaction(SIGBUS)");
154                 exit(EXIT_FAILURE);
155         }
156 #endif
157         if (sigaction(SIGSEGV, &sa, NULL) == -1) {
158                 perror("sigaction(SIGSEGV)");
159                 exit(EXIT_FAILURE);
160         }
161
162         mprot_snap = new mprot_snapshotter(numbackingpages, numsnapshots, nummemoryregions);
163
164         // EVIL HACK: We need to make sure that calls into the mprot_handle_pf method don't cause dynamic links
165         // The problem is that we end up protecting state in the dynamic linker...
166         // Solution is to call our signal handler before we start protecting stuff...
167
168         siginfo_t si;
169         memset(&si, 0, sizeof(si));
170         si.si_addr = ss.ss_sp;
171         mprot_handle_pf(SIGSEGV, &si, NULL);
172         mprot_snap->lastBackingPage--;  //remove the fake page we copied
173
174         void *basemySpace = model_malloc((numheappages + 1) * PAGESIZE);
175         void *pagealignedbase = PageAlignAddressUpward(basemySpace);
176         user_snapshot_space = create_mspace_with_base(pagealignedbase, numheappages * PAGESIZE, 1);
177         snapshot_add_memory_region(pagealignedbase, numheappages);
178
179         void *base_model_snapshot_space = model_malloc((numheappages + 1) * PAGESIZE);
180         pagealignedbase = PageAlignAddressUpward(base_model_snapshot_space);
181         model_snapshot_space = create_mspace_with_base(pagealignedbase, numheappages * PAGESIZE, 1);
182         snapshot_add_memory_region(pagealignedbase, numheappages);
183 }
184
185 static void mprot_startExecution(ucontext_t * context, VoidFuncPtr entryPoint) {
186         /* setup the shared-stack context */
187         create_context(context, fork_snap->mStackBase, model_calloc(STACK_SIZE_DEFAULT, 1), STACK_SIZE_DEFAULT, entryPoint);
188 }
189
190 static void mprot_add_to_snapshot(void *addr, unsigned int numPages)
191 {
192         unsigned int memoryregion = mprot_snap->lastRegion++;
193         if (memoryregion == mprot_snap->maxRegions) {
194                 model_print("Exceeded supported number of memory regions!\n");
195                 exit(EXIT_FAILURE);
196         }
197
198         DEBUG("snapshot region %p-%p (%u page%s)\n",
199                                 addr, (char *)addr + numPages * PAGESIZE, numPages,
200                                 numPages > 1 ? "s" : "");
201         mprot_snap->regionsToSnapShot[memoryregion].basePtr = addr;
202         mprot_snap->regionsToSnapShot[memoryregion].sizeInPages = numPages;
203 }
204
205 static snapshot_id mprot_take_snapshot()
206 {
207         for (unsigned int region = 0;region < mprot_snap->lastRegion;region++) {
208                 if (mprotect(mprot_snap->regionsToSnapShot[region].basePtr, mprot_snap->regionsToSnapShot[region].sizeInPages * sizeof(snapshot_page_t), PROT_READ) == -1) {
209                         perror("mprotect");
210                         model_print("Failed to mprotect inside of takeSnapShot\n");
211                         exit(EXIT_FAILURE);
212                 }
213         }
214         unsigned int snapshot = mprot_snap->lastSnapShot++;
215         if (snapshot == mprot_snap->maxSnapShots) {
216                 model_print("Out of snapshots\n");
217                 exit(EXIT_FAILURE);
218         }
219         mprot_snap->snapShots[snapshot].firstBackingPage = mprot_snap->lastBackingPage;
220
221         return snapshot;
222 }
223
224 static void mprot_roll_back(snapshot_id theID)
225 {
226 #if USE_MPROTECT_SNAPSHOT == 2
227         if (mprot_snap->lastSnapShot == (theID + 1)) {
228                 for (unsigned int page = mprot_snap->snapShots[theID].firstBackingPage;page < mprot_snap->lastBackingPage;page++) {
229                         memcpy(mprot_snap->backingRecords[page].basePtrOfPage, &mprot_snap->backingStore[page], sizeof(snapshot_page_t));
230                 }
231                 return;
232         }
233 #endif
234
235         HashTable< void *, bool, uintptr_t, 4, model_malloc, model_calloc, model_free> duplicateMap;
236         for (unsigned int region = 0;region < mprot_snap->lastRegion;region++) {
237                 if (mprotect(mprot_snap->regionsToSnapShot[region].basePtr, mprot_snap->regionsToSnapShot[region].sizeInPages * sizeof(snapshot_page_t), PROT_READ | PROT_WRITE) == -1) {
238                         perror("mprotect");
239                         model_print("Failed to mprotect inside of takeSnapShot\n");
240                         exit(EXIT_FAILURE);
241                 }
242         }
243         for (unsigned int page = mprot_snap->snapShots[theID].firstBackingPage;page < mprot_snap->lastBackingPage;page++) {
244                 if (!duplicateMap.contains(mprot_snap->backingRecords[page].basePtrOfPage)) {
245                         duplicateMap.put(mprot_snap->backingRecords[page].basePtrOfPage, true);
246                         memcpy(mprot_snap->backingRecords[page].basePtrOfPage, &mprot_snap->backingStore[page], sizeof(snapshot_page_t));
247                 }
248         }
249         mprot_snap->lastSnapShot = theID;
250         mprot_snap->lastBackingPage = mprot_snap->snapShots[theID].firstBackingPage;
251         mprot_take_snapshot();  //Make sure current snapshot is still good...All later ones are cleared
252 }
253
254 #else   /* !USE_MPROTECT_SNAPSHOT */
255
256 #define SHARED_MEMORY_DEFAULT  (200 * ((size_t)1 << 20))        // 100mb for the shared memory
257 #define STACK_SIZE_DEFAULT      (((size_t)1 << 20) * 20)        // 20 mb out of the above 100 mb for my stack
258
259 struct fork_snapshotter {
260         /** @brief Pointer to the shared (non-snapshot) memory heap base
261          * (NOTE: this has size SHARED_MEMORY_DEFAULT - sizeof(*fork_snap)) */
262         void *mSharedMemoryBase;
263
264         /** @brief Pointer to the shared (non-snapshot) stack region */
265         void *mStackBase;
266
267         /** @brief Size of the shared stack */
268         size_t mStackSize;
269
270         /**
271          * @brief Stores the ID that we are attempting to roll back to
272          *
273          * Used in inter-process communication so that each process can
274          * determine whether or not to take over execution (w/ matching ID) or
275          * exit (we're rolling back even further). Dubiously marked 'volatile'
276          * to prevent compiler optimizations from messing with the
277          * inter-process behavior.
278          */
279         volatile snapshot_id mIDToRollback;
280
281
282
283         /** @brief Inter-process tracking of the next snapshot ID */
284         snapshot_id currSnapShotID;
285 };
286
287 static struct fork_snapshotter *fork_snap = NULL;
288 ucontext_t shared_ctxt;
289
290 /** @statics
291  *   These variables are necessary because the stack is shared region and
292  *   there exists a race between all processes executing the same function.
293  *   To avoid the problem above, we require variables allocated in 'safe' regions.
294  *   The bug was actually observed with the forkID, these variables below are
295  *   used to indicate the various contexts to which to switch to.
296  *
297  *   @private_ctxt: the context which is internal to the current process. Used
298  *   for running the internal snapshot/rollback loop.
299  *   @exit_ctxt: a special context used just for exiting from a process (so we
300  *   can use swapcontext() instead of setcontext() + hacks)
301  *   @snapshotid: it is a running counter for the various forked processes
302  *   snapshotid. it is incremented and set in a persistently shared record
303  */
304 static ucontext_t private_ctxt;
305 static ucontext_t exit_ctxt;
306 static snapshot_id snapshotid = 0;
307
308 /**
309  * @brief Create a new context, with a given stack and entry function
310  * @param ctxt The context structure to fill
311  * @param stack The stack to run the new context in
312  * @param stacksize The size of the stack
313  * @param func The entry point function for the context
314  */
315 static void create_context(ucontext_t *ctxt, void *stack, size_t stacksize,
316                                                                                                          void (*func)(void))
317 {
318         getcontext(ctxt);
319         ctxt->uc_stack.ss_sp = stack;
320         ctxt->uc_stack.ss_size = stacksize;
321         ctxt->uc_link = NULL;
322         makecontext(ctxt, func, 0);
323 }
324
325 /** @brief An empty function, used for an "empty" context which just exits a
326  *  process */
327 static void fork_exit()
328 {
329         _Exit(EXIT_SUCCESS);
330 }
331
332 static void createSharedMemory()
333 {
334         //step 1. create shared memory.
335         void *memMapBase = mmap(0, SHARED_MEMORY_DEFAULT + STACK_SIZE_DEFAULT, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANON, -1, 0);
336         if (memMapBase == MAP_FAILED) {
337                 perror("mmap");
338                 exit(EXIT_FAILURE);
339         }
340
341         //Setup snapshot record at top of free region
342         fork_snap = (struct fork_snapshotter *)memMapBase;
343         fork_snap->mSharedMemoryBase = (void *)((uintptr_t)memMapBase + sizeof(*fork_snap));
344         fork_snap->mStackBase = (void *)((uintptr_t)memMapBase + SHARED_MEMORY_DEFAULT);
345         fork_snap->mStackSize = STACK_SIZE_DEFAULT;
346         fork_snap->mIDToRollback = -1;
347         fork_snap->currSnapShotID = 0;
348 }
349
350 /**
351  * Create a new mspace pointer for the non-snapshotting (i.e., inter-process
352  * shared) memory region. Only for fork-based snapshotting.
353  *
354  * @return The shared memory mspace
355  */
356 mspace create_shared_mspace()
357 {
358         if (!fork_snap)
359                 createSharedMemory();
360         return create_mspace_with_base((void *)(fork_snap->mSharedMemoryBase), SHARED_MEMORY_DEFAULT - sizeof(*fork_snap), 1);
361 }
362
363 static void fork_snapshot_init(unsigned int numbackingpages,
364                                                                                                                          unsigned int numsnapshots, unsigned int nummemoryregions,
365                                                                                                                          unsigned int numheappages)
366 {
367         if (!fork_snap)
368                 createSharedMemory();
369
370         model_snapshot_space = create_mspace(numheappages * PAGESIZE, 1);
371 }
372
373 volatile int modellock = 0;
374
375 static void fork_loop() {
376         /* switch back here when takesnapshot is called */
377         snapshotid = fork_snap->currSnapShotID;
378         if (model->params.nofork) {
379                 setcontext(&shared_ctxt);
380                 _Exit(EXIT_SUCCESS);
381         }
382
383         while (true) {
384                 pid_t forkedID;
385                 fork_snap->currSnapShotID = snapshotid + 1;
386
387                 modellock = 1;
388                 forkedID = fork();
389                 modellock = 0;
390
391                 if (0 == forkedID) {
392                         setcontext(&shared_ctxt);
393                 } else {
394                         DEBUG("parent PID: %d, child PID: %d, snapshot ID: %d\n",
395                                                 getpid(), forkedID, snapshotid);
396
397                         while (waitpid(forkedID, NULL, 0) < 0) {
398                                 /* waitpid() may be interrupted */
399                                 if (errno != EINTR) {
400                                         perror("waitpid");
401                                         exit(EXIT_FAILURE);
402                                 }
403                         }
404
405                         if (fork_snap->mIDToRollback != snapshotid)
406                                 _Exit(EXIT_SUCCESS);
407                 }
408         }
409 }
410
411 static void fork_startExecution(ucontext_t *context, VoidFuncPtr entryPoint) {
412         /* setup an "exiting" context */
413         int exit_stack_size = 256;
414         create_context(&exit_ctxt, snapshot_calloc(exit_stack_size, 1), exit_stack_size, fork_exit);
415
416         /* setup the system context */
417         create_context(context, fork_snap->mStackBase, STACK_SIZE_DEFAULT, entryPoint);
418         /* switch to a new entryPoint context, on a new stack */
419         create_context(&private_ctxt, snapshot_calloc(STACK_SIZE_DEFAULT, 1), STACK_SIZE_DEFAULT, fork_loop);
420 }
421
422 static snapshot_id fork_take_snapshot() {
423         model_swapcontext(&shared_ctxt, &private_ctxt);
424         DEBUG("TAKESNAPSHOT RETURN\n");
425         return snapshotid;
426 }
427
428 static void fork_roll_back(snapshot_id theID)
429 {
430         DEBUG("Rollback\n");
431         fork_snap->mIDToRollback = theID;
432         model_swapcontext(model->get_system_context(), &exit_ctxt);
433         fork_snap->mIDToRollback = -1;
434 }
435
436 #endif  /* !USE_MPROTECT_SNAPSHOT */
437
438 /**
439  * @brief Initializes the snapshot system
440  * @param entryPoint the function that should run the program.
441  */
442 void snapshot_system_init(unsigned int numbackingpages,
443                                                                                                         unsigned int numsnapshots, unsigned int nummemoryregions,
444                                                                                                         unsigned int numheappages)
445 {
446 #if USE_MPROTECT_SNAPSHOT
447         mprot_snapshot_init(numbackingpages, numsnapshots, nummemoryregions, numheappages);
448 #else
449         fork_snapshot_init(numbackingpages, numsnapshots, nummemoryregions, numheappages);
450 #endif
451 }
452
453 void startExecution(ucontext_t *context, VoidFuncPtr entryPoint)
454 {
455 #if USE_MPROTECT_SNAPSHOT
456         mprot_startExecution(context, entryPoint);
457 #else
458         fork_startExecution(context, entryPoint);
459 #endif
460 }
461
462 /** Assumes that addr is page aligned. */
463 void snapshot_add_memory_region(void *addr, unsigned int numPages)
464 {
465 #if USE_MPROTECT_SNAPSHOT
466         mprot_add_to_snapshot(addr, numPages);
467 #else
468         /* not needed for fork-based snapshotting */
469 #endif
470 }
471
472 /** Takes a snapshot of memory.
473  * @return The snapshot identifier.
474  */
475 snapshot_id take_snapshot()
476 {
477 #if USE_MPROTECT_SNAPSHOT
478         return mprot_take_snapshot();
479 #else
480         return fork_take_snapshot();
481 #endif
482 }
483
484 /** Rolls the memory state back to the given snapshot identifier.
485  *  @param theID is the snapshot identifier to rollback to.
486  */
487 void snapshot_roll_back(snapshot_id theID)
488 {
489 #if USE_MPROTECT_SNAPSHOT
490         mprot_roll_back(theID);
491 #else
492         fork_roll_back(theID);
493 #endif
494 }