Adding support for SMT solvers
[Benchmarks_CSolver.git] / sudoku-csolver / csolversudoku.py
1 import pycsolver as ps
2 from ctypes import *
3 import numpy as np
4 import sys
5
class Solver:
    """Numeric identifiers for the back ends this benchmark can target.

    CSOLVER (and any ID below SERIALISE) uses CSolver without setting an
    interpreter. SERIALISE additionally serialises the problem before solving.
    IDs from ALLOY upward are translated by getSolverType() and handed to
    csolverlb.setInterpreter().
    """
    # Plain CSolver run: no interpreter is set and nothing is serialised.
    CSOLVER = 1
    # Serialise the constraint problem, then solve it.
    SERIALISE = 2
    # Every ID >= ALLOY selects an interpreter via getSolverType().
    ALLOY = 3
    Z3 = 4
    MATHSAT = 5
    SMTRAT = 6
def getSolverType(solverID):
    """Translate a Solver.* ID into the value passed to setInterpreter().

    The interpreter-backed IDs start at Solver.ALLOY (3). Subtracting a fixed
    offset of 2 maps ALLOY, Z3, MATHSAT and SMTRAT to 1, 2, 3 and 4.
    """
    # NOTE(review): assumes pycsolver numbers its interpreters from 1 in this
    # order -- confirm against the pycsolver interpreter enumeration.
    interpreterOffset = 2
    return solverID - interpreterOffset
15
16
17
def generateProblem(N):
    """Solve an empty N x N Sudoku and return the resulting filled grid.

    No givens are supplied, so the solver may return any valid grid. The
    default solverID of generateSudokuConstraints (-1) is used, so no
    interpreter is set and nothing is serialised.
    """
    givens = None
    return generateSudokuConstraints(N, givens)
20
def solveProblem(N, problem, solverID):
    """Solve the N x N Sudoku `problem` using the back end chosen by solverID.

    `problem` is an N x N grid in which 0 marks an empty cell and any other
    value is a given. Returns the solved grid as a list of lists.
    """
    return generateSudokuConstraints(N, sudoku=problem, solverID=solverID)
23
def replaceWithElemConstOptimization(elemConsts, problem, sudoku):
    """Substitute element constants for the given cells of `problem`, in place.

    Wherever sudoku[r][c] holds a non-zero value v, problem[r][c] is replaced
    by elemConsts[v - 1]. Cells that are 0 in `sudoku` are left untouched.
    Returns None.
    """
    for r, givens in enumerate(sudoku):
        for c, value in enumerate(givens):
            if value == 0:
                continue  # empty cell: keep its variable
            problem[r][c] = elemConsts[value - 1]
29                                 
def constantCellConstraint(csolverlb, solver, elemConsts, problem, sudoku):
    """Pin every given cell with an explicit equality constraint.

    For each non-zero sudoku[i][j] == v, adds the constraint
    problem[i][j] == elemConsts[v - 1] to `solver`. The cell variables in
    `problem` themselves are left in place; they are only constrained.
    """
    givens = ((i, j, v)
              for i, row in enumerate(sudoku)
              for j, v in enumerate(row)
              if v != 0)
    for i, j, v in givens:
        pin = generateEqualityConstraint(csolverlb, solver, problem[i][j], elemConsts[v - 1])
        csolverlb.addConstraint(solver, pin)
35
def generateEqualityConstraint(csolverlb, solver, e1, e2):
    """Build (but do not add) the predicate `e1 == e2` and return its handle.

    A fresh SATC_EQUALS predicate operator is created on every call and
    applied to the two elements. Callers decide what to do with the result,
    e.g. add it as a constraint or negate it first.
    """
    operands = (c_void_p * 2)(e1, e2)
    eqOperator = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
    return csolverlb.applyPredicate(solver, eqOperator, operands, c_uint(2))
42         
def extractItemInSetOptimization(csolverlb, solver, sudoku, N):
    """Create one element variable per cell, with domains pruned by the givens.

    Every non-zero ("given") value in `sudoku` is removed from the candidate
    values of all cells sharing its row, column or root x root block
    (root = int(sqrt(N)), so N is expected to be a perfect square). Each given
    cell's own domain is then pinned to exactly its value, so that it stays
    consistent with an equality constraint on that value (see
    constantCellConstraint) and never ends up empty.

    Args:
        csolverlb: the loaded CSolver library object.
        solver: the solver instance in which sets and variables are created.
        sudoku: N x N grid; 0 marks an empty cell, 1..N a given value.
        N: side length of the grid.

    Returns:
        N x N numpy array of the handles returned by csolverlb.getElementVar.
    """
    root = int(N ** 0.5)
    # candidates[i][j]: values still allowed in cell (i, j), in ascending order.
    candidates = [[list(range(1, N + 1)) for _ in range(N)] for _ in range(N)]

    def discard(r, c, value):
        # Remove `value` from cell (r, c)'s candidates if it is still there.
        if value in candidates[r][c]:
            candidates[r][c].remove(value)

    for i, row in enumerate(sudoku):
        for j, item in enumerate(row):
            if item == 0:
                continue
            # Top-left corner of the block holding (i, j). Floor division is
            # required: on Python 3 `/` yields a float, which cannot index a list.
            bi = (i // root) * root
            bj = (j // root) * root
            for k in range(N):
                discard(i, k, item)                           # same row
                discard(k, j, item)                           # same column
                discard(bi + k % root, bj + k // root, item)  # same block

    # The pruning above also strips each given from its OWN cell (the cell
    # lies in its own row). Restore it as the cell's only candidate.
    for i, row in enumerate(sudoku):
        for j, item in enumerate(row):
            if item != 0:
                candidates[i][j] = [item]

    # Create every domain set first, then every variable, keeping the
    # original order of solver calls.
    domainSets = [[None] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            vals = candidates[i][j]
            arr = (c_long * len(vals))(*vals)
            domainSets[i][j] = csolverlb.createSet(solver, c_uint(1), arr, c_uint(len(vals)))

    return np.array([[csolverlb.getElementVar(solver, domainSets[i][j]) for j in range(N)]
                     for i in range(N)])
67
def generateSudokuConstraints(N, sudoku=None, solverID=-1):
    """Encode an N x N Sudoku for CSolver, solve it, and return the grid.

    Every cell becomes an element variable over the shared set {1..N}, and
    mustHaveValue is asserted on each cell. All cells within a row, a column
    and a root x root block must be pairwise distinct
    (root = int(sqrt(N)), so N is expected to be a perfect square).

    Args:
        N: side length of the grid.
        sudoku: optional N x N grid of givens (0 = empty cell). When None, no
            cell is pinned and the solver may return any valid grid.
        solverID: one of the Solver.* IDs. SERIALISE calls
            csolverlb.serialize before solving. IDs >= Solver.ALLOY select an
            interpreter via getSolverType(). Any other value (including the
            default -1) sets no interpreter.

    Returns:
        The solution as a list of N lists of N values, as reported by
        csolverlb.getElementValue.

    If csolverlb.solve does not return 1, prints "Problem is unsolvable!" and
    exits the process with status 1.
    """
    import itertools

    csolverlb = ps.loadCSolver()
    solver = csolverlb.createCCSolver()
    if solverID >= Solver.ALLOY:
        csolverlb.setInterpreter(solver, getSolverType(solverID))

    # One shared domain set {1..N} for every cell variable.
    values = list(range(1, N + 1))
    valueArray = (c_long * len(values))(*values)
    domain = csolverlb.createSet(solver, c_uint(1), valueArray, c_uint(N))
    # NOTE: extractItemInSetOptimization(csolverlb, solver, sudoku, N) is an
    # alternative that builds per-cell domains pre-pruned by the givens.
    problem = np.array([[csolverlb.getElementVar(solver, domain) for _ in range(N)]
                        for _ in range(N)])
    # elemConsts[v - 1] is the element constant for value v.
    elemConsts = [csolverlb.getElementConst(solver, c_uint(1), v) for v in values]

    def valid(cells):
        """Constrain the elements of `cells` to be pairwise distinct."""
        for ei, ej in itertools.combinations(cells, 2):
            same = generateEqualityConstraint(csolverlb, solver, ei, ej)
            different = csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, same)
            csolverlb.addConstraint(solver, different)

    # Every cell must take some value from its domain.
    for row in problem:
        for elem in row:
            csolverlb.mustHaveValue(solver, elem)

    # Rows and columns hold distinct values.
    for i in range(N):
        valid(problem[:, i])
        valid(problem[i, :])

    # Each root x root block holds distinct values.
    root = int(N ** 0.5)
    origins = [root * b for b in range(root)]
    for bi in origins:
        for bj in origins:
            valid([problem[bi + k % root, bj + k // root] for k in range(N)])

    # Pin the givens when solving a puzzle rather than generating one.
    # (replaceWithElemConstOptimization is an alternative that substitutes the
    # constants into `problem` instead of adding equality constraints.)
    if sudoku is not None:
        constantCellConstraint(csolverlb, solver, elemConsts, problem, sudoku)

    if solverID == Solver.SERIALISE:
        csolverlb.serialize(solver)

    if csolverlb.solve(solver) != 1:
        # Single-argument print(...) produces the same output on Python 2 and 3.
        print("Problem is unsolvable!")
        sys.exit(1)
    return [[csolverlb.getElementValue(solver, elem) for elem in row] for row in problem]
150