diff --git a/spot/misc/satsolver.cc b/spot/misc/satsolver.cc index f02d5ca8f..6ab312992 100644 --- a/spot/misc/satsolver.cc +++ b/spot/misc/satsolver.cc @@ -128,22 +128,64 @@ namespace spot start(); } + satsolver::~satsolver() + { + delete cnf_tmp_; + delete cnf_stream_; + delete nclauses_; + } + void satsolver::start() { cnf_tmp_ = create_tmpfile("sat-", ".cnf"); cnf_stream_ = new std::ofstream(cnf_tmp_->name(), std::ios_base::trunc); cnf_stream_->exceptions(std::ofstream::failbit | std::ofstream::badbit); + nclauses_ = new clause_counter(); + + // Add empty line for the header + *cnf_stream_ << " \n"; } - satsolver::~satsolver() + void satsolver::end_clause() { - delete cnf_tmp_; - delete cnf_stream_; + *cnf_stream_ << '\n'; + *nclauses_ += 1; } - std::ostream& satsolver::operator()() + void satsolver::add(std::initializer_list values) { - return *cnf_stream_; + for (auto& v : values) + { + *cnf_stream_ << v << ' '; + if (!v) // ..., 0) + end_clause(); + } + } + + void satsolver::add(int v) + { + *cnf_stream_ << v << ' '; + if (!v) // 0 + end_clause(); + } + + int satsolver::get_nb_clauses() const + { + return nclauses_->nb_clauses(); + } + + std::pair satsolver::stats(int nvars) + { + int nclaus = nclauses_->nb_clauses(); + cnf_stream_->seekp(0); + *cnf_stream_ << "p cnf " << nvars << ' ' << nclaus; + return std::make_pair(nvars, nclaus); + } + + std::pair satsolver::stats() + { + *cnf_stream_ << "p cnf 1 2\n-1 0\n1 0\n"; + return std::make_pair(1, 2); } satsolver::solution_pair diff --git a/spot/misc/satsolver.hh b/spot/misc/satsolver.hh index 75ef7d0de..6a28b565a 100644 --- a/spot/misc/satsolver.hh +++ b/spot/misc/satsolver.hh @@ -24,6 +24,7 @@ #include #include #include +#include namespace spot { @@ -68,14 +69,16 @@ namespace spot /// \brief Interface with a SAT solver. /// - /// Call start() to create some temporary file, then send DIMACs - /// text to the stream returned by operator(), and finally call - /// get_solution(). + /// Call start() to initialize the cnf file. This class provides the + /// necessary functions to handle the cnf file, add clauses, count them, + /// update the header, add some comments... + /// It is not possible to write in the file without having to call these + /// functions. /// /// The satsolver called can be configured via the - /// SPOT_SATSOLVER environment variable. It - /// defaults to - /// "satsolver -verb=0 %I >%O" + /// SPOT_SATSOLVER environment variable. It must be this set + /// following this: "satsolver -verb=0 %I >%O". + /// /// where %I and %O are replaced by input and output files. class SPOT_API satsolver { @@ -83,15 +86,67 @@ namespace spot satsolver(); ~satsolver(); + /// \brief Initialize private attributes void start(); - std::ostream& operator()(); + + /// \brief Add a list of lit. to the current clause. + void add(std::initializer_list values); + + /// \brief Add a single lit. to the current clause. + void add(int v); + + /// \breif Get the current number of clauses. + int get_nb_clauses() const; + + /// \breif Update cnf_file's header with the correct stats. + std::pair stats(int nvars); + + /// \breif Create an unsatisfiable cnf_file, return stats about it. + std::pair stats(); + + /// \breif Add a comment in cnf file. + template + void comment_rec(T single) + { + *cnf_stream_ << single << ' '; + } + + /// \breif Add a comment in cnf_file. + template + void comment_rec(T first, Args... args) + { + *cnf_stream_ << first << ' '; + comment_rec(args...); + } + + /// \breif Add a comment in the cnf_file, starting with 'c'. + template + void comment(T single) + { + *cnf_stream_ << "c " << single << ' '; + } + + /// \breif Add comment in the cnf_file, starting with 'c'. + template + void comment(T first, Args... args) + { + *cnf_stream_ << "c " << first << ' '; + comment_rec(args...); + } typedef std::vector solution; typedef std::pair solution_pair; solution_pair get_solution(); + + private: + /// \breif End the current clause and increment the counter. + void end_clause(); + private: temporary_file* cnf_tmp_; std::ostream* cnf_stream_; + clause_counter* nclauses_; + }; /// \brief Extract the solution of a SAT solver output. diff --git a/spot/twaalgos/dtbasat.cc b/spot/twaalgos/dtbasat.cc index 0f250fa9c..a54c3a67d 100644 --- a/spot/twaalgos/dtbasat.cc +++ b/spot/twaalgos/dtbasat.cc @@ -43,8 +43,10 @@ #define DEBUG 0 #if DEBUG #define dout out << "c " +#define cnf_comment(...) solver.comment(__VA_ARGS__) #define trace std::cerr #else +#define cnf_comment(...) while (0) solver.comment(__VA_ARGS__) #define dout while (0) std::cout #define trace dout #endif @@ -294,12 +296,10 @@ namespace spot typedef std::pair sat_stats; static - sat_stats dtba_to_sat(std::ostream& out, + sat_stats dtba_to_sat(satsolver solver, const const_twa_graph_ptr& ref, dict& d, bool state_based) { - clause_counter nclauses; - // Compute the AP used in the hard way. bdd ap = bddtrue; for (auto& t: ref->edges()) @@ -324,21 +324,15 @@ namespace spot // empty automaton is impossible if (d.cand_size == 0) - { - out << "p cnf 1 2\n-1 0\n1 0\n"; - return std::make_pair(1, 2); - } - - // An empty line for the header - out << " \n"; + return solver.stats(); #if DEBUG debug_dict = ref->get_dict(); - dout << "ref_size: " << ref_size << '\n'; - dout << "cand_size: " << d.cand_size << '\n'; + solver.comment("ref_size", ref_size, '\n'); + solver.comment("cand_size", d.cand_size, '\n'); #endif - dout << "symmetry-breaking clauses\n"; + cnf_comment("symmetry-breaking clauses\n"); unsigned j = 0; bdd all = bddtrue; while (all != bddfalse) @@ -350,16 +344,15 @@ namespace spot { transition t(i, s, k); int ti = d.transid[t]; - dout << "¬" << t << '\n'; - out << -ti << " 0\n"; - ++nclauses; + cnf_comment("¬", t, '\n'); + solver.add({-ti, 0}); } ++j; } - if (!nclauses.nb_clauses()) - dout << "(none)\n"; + if (!solver.get_nb_clauses()) + cnf_comment("(none)\n"); - dout << "(1) the candidate automaton is complete\n"; + cnf_comment("(1) the candidate automaton is complete\n"); for (unsigned q1 = 0; q1 < d.cand_size; ++q1) { bdd all = bddtrue; @@ -369,36 +362,32 @@ namespace spot all -= s; #if DEBUG - dout; + solver.comment(""); for (unsigned q2 = 0; q2 < d.cand_size; q2++) { transition t(q1, s, q2); - out << t << "δ"; + solver.comment_rec(t, "δ"); if (q2 != d.cand_size) - out << " ∨ "; + solver.comment_rec(" ∨ "); } - out << '\n'; + solver.comment_rec('\n'); #endif for (unsigned q2 = 0; q2 < d.cand_size; q2++) { transition t(q1, s, q2); int ti = d.transid[t]; - - out << ti << ' '; + solver.add(ti); } - out << "0\n"; - - ++nclauses; + solver.add(0); } } - dout << "(2) the initial state is reachable\n"; + cnf_comment("(2) the initial state is reachable\n"); { unsigned init = ref->get_init_state_number(); - dout << state_pair(0, init) << '\n'; - out << d.prodid[state_pair(0, init)] << " 0\n"; - ++nclauses; + cnf_comment(state_pair(0, init), '\n'); + solver.add({d.prodid[state_pair(0, init)], 0}); } for (std::map::const_iterator pit = d.prodid.begin(); @@ -407,8 +396,8 @@ namespace spot unsigned q1 = pit->first.a; unsigned q1p = pit->first.b; - dout << "(3) augmenting paths based on Cand[" << q1 - << "] and Ref[" << q1p << "]\n"; + cnf_comment("(3) augmenting paths based on Cand[", q1, "] and Ref[", + q1p, "]\n"); for (auto& tr: ref->out(q1p)) { unsigned dp = tr.dst; @@ -429,10 +418,8 @@ namespace spot if (pit->second == succ) continue; - dout << pit->first << " ∧ " << t << "δ → " << p2 << '\n'; - out << -pit->second << ' ' << -ti << ' ' - << succ << " 0\n"; - ++nclauses; + cnf_comment(pit->first, " ∧ ", t, "δ → ", p2, '\n'); + solver.add({-pit->second, -ti, succ, 0}); } } } @@ -463,8 +450,8 @@ namespace spot { path p1(q1, q1p, q2, q2p); - dout << "(4&5) matching paths from reference based on " - << p1 << '\n'; + cnf_comment("(4&5) matching paths from reference based on", + p1, '\n'); int pid1; if (q1 == q2 && q1p == q2p) @@ -495,11 +482,9 @@ namespace spot int ti = d.transid[t]; int ta = d.transacc[t]; - dout << p1 << "R ∧ " << t << "δ → ¬" << t - << "F\n"; - out << -pid1 << ' ' << -ti << ' ' - << -ta << " 0\n"; - ++nclauses; + cnf_comment(p1, "R ∧", t, "δ → ¬", t, + "F\n"); + solver.add({-pid1, -ti, -ta, 0}); } @@ -521,11 +506,8 @@ namespace spot transition t(q2, s, q3); int ti = d.transid[t]; - dout << p1 << "R ∧ " << t << "δ → " << p2 - << "R\n"; - out << -pid1 << ' ' << -ti << ' ' - << pid2 << " 0\n"; - ++nclauses; + cnf_comment(p1, "R ∧", t, "δ →", p2, "R\n"); + solver.add({-pid1, -ti, pid2, 0}); } } } @@ -555,8 +537,8 @@ namespace spot for (unsigned q2 = 0; q2 < d.cand_size; ++q2) { path p1(q1, q1p, q2, q2p); - dout << "(6&7) matching paths from candidate based on " - << p1 << '\n'; + cnf_comment("(6&7) matching paths from candidate based on", + p1, '\n'); int pid1; if (q1 == q2 && q1p == q2p) @@ -588,11 +570,8 @@ namespace spot int ti = d.transid[t]; int ta = d.transacc[t]; - dout << p1 << "C ∧ " << t << "δ → " << t - << "F\n"; - out << -pid1 << ' ' << -ti << ' ' << ta - << " 0\n"; - ++nclauses; + cnf_comment(p1, "C ∧", t, "δ →", t, "F\n"); + solver.add({-pid1, -ti, ta, 0}); } } else // (7) no loop @@ -613,12 +592,9 @@ namespace spot int ti = d.transid[t]; int ta = d.transacc[t]; - dout << p1 << "C ∧ " << t << "δ ∧ ¬" - << t << "F → " << p2 << "C\n"; - - out << -pid1 << ' ' << -ti << ' ' - << ta << ' ' << pid2 << " 0\n"; - ++nclauses; + cnf_comment(p1, "C ∧", t, "δ ∧ ¬", t, + "F →", p2, "C\n"); + solver.add({-pid1, -ti, ta, pid2, 0}); } } } @@ -626,9 +602,7 @@ namespace spot } } } - out.seekp(0); - out << "p cnf " << d.nvars << ' ' << nclauses.nb_clauses(); - return std::make_pair(d.nvars, nclauses.nb_clauses()); + return solver.stats(d.nvars); } static twa_graph_ptr @@ -757,7 +731,7 @@ namespace spot timer_map t; t.start("encode"); - sat_stats s = dtba_to_sat(solver(), a, d, state_based); + sat_stats s = dtba_to_sat(solver, a, d, state_based); t.stop("encode"); t.start("solve"); solution = solver.get_solution(); diff --git a/spot/twaalgos/dtwasat.cc b/spot/twaalgos/dtwasat.cc index 581dd9625..51bff5b47 100644 --- a/spot/twaalgos/dtwasat.cc +++ b/spot/twaalgos/dtwasat.cc @@ -51,8 +51,10 @@ #define DEBUG 0 #if DEBUG #define dout out << "c " +#define cnf_comment(...) solver.comment(__VA_ARGS__) #define trace std::cerr #else +#define cnf_comment(...) while (0) solver.comment(__VA_ARGS__) #define dout while (0) std::cout #define trace dout #endif @@ -596,13 +598,12 @@ namespace spot typedef std::pair sat_stats; static - sat_stats dtwa_to_sat(std::ostream& out, const_twa_graph_ptr ref, + sat_stats dtwa_to_sat(satsolver solver, const_twa_graph_ptr ref, dict& d, bool state_based, bool colored) { #if DEBUG debug_dict = ref->get_dict(); #endif - clause_counter nclauses; // Compute the AP used in the hard way. bdd ap = bddtrue; @@ -629,23 +630,17 @@ namespace spot // empty automaton is impossible if (d.cand_size == 0) - { - out << "p cnf 1 2\n-1 0\n1 0\n"; - return std::make_pair(1, 2); - } - - // An empty line for the header - out << " \n"; + return solver.stats(); #if DEBUG debug_ref_acc = &ref->acc(); debug_cand_acc = &d.cacc; - dout << "ref_size: " << ref_size << '\n'; - dout << "cand_size: " << d.cand_size << '\n'; + solver.comment("ref_size:", ref_size, '\n'); + solver.comment("cand_size:", d.cand_size, '\n'); #endif auto& racc = ref->acc(); - dout << "symmetry-breaking clauses\n"; + cnf_comment("symmetry-breaking clauses\n"); int j = 0; bdd all = bddtrue; while (all != bddfalse) @@ -657,16 +652,15 @@ namespace spot { transition t(i, s, k); int ti = d.transid[t]; - dout << "¬" << t << '\n'; - out << -ti << " 0\n"; - ++nclauses; + cnf_comment("¬", t, '\n'); + solver.add({-ti, 0}); } ++j; } - if (!nclauses.nb_clauses()) - dout << "(none)\n"; + if (!solver.get_nb_clauses()) + cnf_comment("(none)\n"); - dout << "(8) the candidate automaton is complete\n"; + cnf_comment("(8) the candidate automaton is complete\n"); for (unsigned q1 = 0; q1 < d.cand_size; ++q1) { bdd all = bddtrue; @@ -676,15 +670,15 @@ namespace spot all -= s; #if DEBUG - dout; + solver.comment(""); for (unsigned q2 = 0; q2 < d.cand_size; ++q2) { transition t(q1, s, q2); - out << t << "δ"; + solver.comment_rec(t, "δ"); if (q2 != d.cand_size) - out << " ∨ "; + solver.comment_rec(" ∨ "); } - out << '\n'; + solver.comment_rec('\n'); #endif for (unsigned q2 = 0; q2 < d.cand_size; ++q2) @@ -692,26 +686,24 @@ namespace spot transition t(q1, s, q2); int ti = d.transid[t]; - out << ti << ' '; + solver.add(ti); } - out << "0\n"; - ++nclauses; + solver.add(0); } } - dout << "(9) the initial state is reachable\n"; + cnf_comment("(9) the initial state is reachable\n"); { unsigned init = ref->get_init_state_number(); - dout << path(0, init) << '\n'; - out << d.pathid[path(0, init)] << " 0\n"; - ++nclauses; + cnf_comment(path(0, init), '\n'); + solver.add({d.pathid[path(0, init)], 0}); } if (colored) { unsigned nacc = d.cand_nacc; - dout << "transitions belong to exactly one of the " - << nacc << " acceptance set\n"; + cnf_comment("transitions belong to exactly one of the", nacc, + "acceptance set\n"); bdd all = bddtrue; while (all != bddfalse) { @@ -730,25 +722,23 @@ namespace spot { transition_acc tj(q1, l, {j}, q2); int taj = d.transaccid[tj]; - out << -tai << ' ' << -taj << " 0\n"; - ++nclauses; + solver.add({-tai, -taj, 0}); } } for (unsigned i = 0; i < nacc; ++i) { transition_acc ti(q1, l, {i}, q2); int tai = d.transaccid[ti]; - out << tai << ' '; + solver.add(tai); } - out << "0\n"; - ++nclauses; + solver.add(0); } } } if (!d.all_silly_cand_acc.empty()) { - dout << "no transition with silly acceptance\n"; + cnf_comment("no transition with silly acceptance\n"); bdd all = bddtrue; while (all != bddfalse) { @@ -758,25 +748,24 @@ namespace spot for (unsigned q2 = 0; q2 < d.cand_size; ++q2) for (auto& s: d.all_silly_cand_acc) { - dout << "no (" << q1 << ',' - << bdd_format_formula(debug_dict, l) - << ',' << s << ',' << q2 << ")\n"; + cnf_comment("no (", q1, ',', + bdd_format_formula(debug_dict, l), ',', s, + ',', q2, ")\n"); for (unsigned v: s.sets()) { transition_acc ta(q1, l, d.cacc.mark(v), q2); int tai = d.transaccid[ta]; assert(tai != 0); - out << ' ' << -tai; + solver.add(-tai); } for (unsigned v: d.cacc.comp(s).sets()) { transition_acc ta(q1, l, d.cacc.mark(v), q2); int tai = d.transaccid[ta]; assert(tai != 0); - out << ' ' << tai; + solver.add(tai); } - out << " 0\n"; - ++nclauses; + solver.add(0); } } } @@ -786,8 +775,8 @@ namespace spot { if (!sm.reachable_state(q1p)) continue; - dout << "(10) augmenting paths based on Cand[" << q1 - << "] and Ref[" << q1p << "]\n"; + cnf_comment("(10) augmenting paths based on Cand[", q1, + "] and Ref[", q1p, "]\n"); path p1(q1, q1p); int p1id = d.pathid[p1]; @@ -811,9 +800,8 @@ namespace spot if (p1id == succ) continue; - dout << p1 << " ∧ " << t << "δ → " << p2 << '\n'; - out << -p1id << ' ' << -ti << ' ' << succ << " 0\n"; - ++nclauses; + cnf_comment(p1, "∧", t, "δ →", p2, '\n'); + solver.add({-p1id, -ti, succ, 0}); } } } @@ -854,7 +842,8 @@ namespace spot path p(q1, q1p, q2, q2p, d.all_cand_acc[f], refhist); - dout << "(11&12&13) paths from " << p << '\n'; + cnf_comment("(11&12&13) paths from ", p, + '\n'); int pid = d.pathid[p]; @@ -892,10 +881,9 @@ namespace spot for (auto& v: missing) { #if DEBUG - dout << (rejloop ? - "(11) " : "(12) ") - << p << " ∧ " - << t << "δ → ("; + solver.comment((rejloop ? + "(11) " : "(12) "), p, + " ∧ ", t, "δ → ("); const char* orsep = ""; for (int s: v) { @@ -905,21 +893,23 @@ namespace spot ta(q2, l, d.cacc.mark(-s - 1), q1); - out << orsep << "¬" << ta; + solver.comment_rec(orsep, + "¬", ta); } else { transition_acc ta(q2, l, d.cacc.mark(s), q1); - out << orsep << ta; + solver.comment_rec(orsep, + ta); } - out << "FC"; + solver.comment_rec("FC"); orsep = " ∨ "; } - out << ")\n"; + solver.comment_rec(")\n"); #endif // DEBUG - out << -pid << ' ' << -ti; + solver.add({-pid, -ti}); for (int s: v) if (s < 0) { @@ -929,7 +919,7 @@ namespace spot q1); int tai = d.transaccid[ta]; assert(tai != 0); - out << ' ' << -tai; + solver.add(-tai); } else { @@ -938,10 +928,9 @@ namespace spot d.cacc.mark(s), q1); int tai = d.transaccid[ta]; assert(tai != 0); - out << ' ' << tai; + solver.add(tai); } - out << " 0\n"; - ++nclauses; + solver.add(0); } } // (13) augmenting paths (always). @@ -964,8 +953,8 @@ namespace spot if (pid == p2id) continue; #if DEBUG - dout << "(13) " << p << " ∧ " - << t << "δ "; + solver.comment("(13) ", p, " ∧ ", + t, "δ "); auto biga_ = d.all_cand_acc[f]; for (unsigned m = 0; @@ -977,12 +966,13 @@ namespace spot const char* not_ = "¬"; if (biga_.has(m)) not_ = ""; - out << " ∧ " << not_ - << ta << "FC"; + solver.comment_rec(" ∧ ", not_, + ta, "FC"); } - out << " → " << p2 << '\n'; + solver.comment_rec(" → ", p2, + '\n'); #endif - out << -pid << ' ' << -ti << ' '; + solver.add({-pid, -ti}); auto biga = d.all_cand_acc[f]; for (unsigned m = 0; m < d.cand_nacc; ++m) @@ -993,11 +983,9 @@ namespace spot int tai = d.transaccid[ta]; if (biga.has(m)) tai = -tai; - out << tai << ' '; + solver.add(tai); } - - out << p2id << " 0\n"; - ++nclauses; + solver.add({p2id, 0}); } } } @@ -1007,9 +995,7 @@ namespace spot } } } - out.seekp(0); - out << "p cnf " << d.nvars << ' ' << nclauses.nb_clauses(); - return std::make_pair(d.nvars, nclauses.nb_clauses()); + return solver.stats(d.nvars); } static twa_graph_ptr @@ -1136,7 +1122,7 @@ namespace spot timer_map t; t.start("encode"); - sat_stats s = dtwa_to_sat(solver(), a, d, state_based, colored); + sat_stats s = dtwa_to_sat(solver, a, d, state_based, colored); t.stop("encode"); t.start("solve"); solution = solver.get_solution();