Introduce a faster merge_states

merge_states is now hash-based, uses the new edge-sorting with
src first and can be executed in parallel.

* spot/twa/twagraph.cc: Here
* tests/python/mergedge.py: Test
This commit is contained in:
Philipp Schlehuber-Caissier 2022-05-16 00:04:04 +02:00
parent 71c2a7b1a6
commit d8cc0c5acb
2 changed files with 253 additions and 193 deletions

View file

@ -372,13 +372,12 @@ namespace spot
throw std::runtime_error( throw std::runtime_error(
"twa_graph::merge_states() does not work on alternating automata"); "twa_graph::merge_states() does not work on alternating automata");
const unsigned nthreads = get_nthreads();
typedef graph_t::edge_storage_t tr_t; typedef graph_t::edge_storage_t tr_t;
g_.sort_edges_([](const tr_t& lhs, const tr_t& rhs) g_.sort_edges_srcfirst_([](const tr_t& lhs, const tr_t& rhs)
{ {
if (lhs.src < rhs.src) assert(lhs.src == rhs.src);
return true;
if (lhs.src > rhs.src)
return false;
if (lhs.acc < rhs.acc) if (lhs.acc < rhs.acc)
return true; return true;
if (lhs.acc > rhs.acc) if (lhs.acc > rhs.acc)
@ -449,7 +448,7 @@ namespace spot
// Represents which states share a hash // Represents which states share a hash
// Head is in the unordered_map, // Head is in the unordered_map,
// hash_linked_list is like a linked list structure // hash_linked_list is like a linked list structure
// of false pointers // of fake pointers
auto hash_linked_list = std::vector<unsigned>(n_states, -1u); auto hash_linked_list = std::vector<unsigned>(n_states, -1u);
auto s_to_hash = std::vector<size_t>(n_states, 0); auto s_to_hash = std::vector<size_t>(n_states, 0);
@ -530,8 +529,8 @@ namespace spot
}; };
static auto checked1 = std::vector<char>(); thread_local auto checked1 = std::vector<char>();
static auto checked2 = std::vector<char>(); thread_local auto checked2 = std::vector<char>();
auto [i1, nsl1, sl1, e1] = e_idx[s1]; auto [i1, nsl1, sl1, e1] = e_idx[s1];
auto [i2, nsl2, sl2, e2] = e_idx[s2]; auto [i2, nsl2, sl2, e2] = e_idx[s2];
@ -585,12 +584,10 @@ namespace spot
// More efficient version? // More efficient version?
// Skip checked edges // Skip checked edges
// Last element serves as break // Last element serves as break
for (; checked1[idx1 - i1]; ++idx1) while (checked1[idx1 - i1])
{ ++idx1;
} while (checked2[idx2 - i2])
for (; checked2[idx2 - i2]; ++idx2) ++idx2;
{
}
// If one is out of bounds, so is the other // If one is out of bounds, so is the other
if (idx1 == e1) if (idx1 == e1)
{ {
@ -614,12 +611,22 @@ namespace spot
const unsigned nb_states = num_states(); const unsigned nb_states = num_states();
std::vector<unsigned> remap(nb_states, -1U); std::vector<unsigned> remap(nb_states, -1U);
for (unsigned i = 0; i != nb_states; ++i) // Check each hash
auto check_ix = [&](unsigned ix)
{ {
auto j = spe && (*sp)[i] ? player_map.at(s_to_hash[i]).first // Reduce cache miss
: env_map.at(s_to_hash[i]).first; thread_local auto v = std::vector<unsigned>();
for (; j<i; j=hash_linked_list[j]) v.clear();
for (auto i = ix; i != -1U; i = hash_linked_list[i])
v.push_back(i);
const unsigned N = v.size();
for (unsigned idx = 0; idx < N; ++idx)
{ {
auto i = v[idx];
for (unsigned jdx = 0; jdx < idx; ++jdx)
{
auto j = v[jdx];
if (state_equal(j, i)) if (state_equal(j, i))
{ {
remap[i] = (remap[j] != -1U) ? remap[j] : j; remap[i] = (remap[j] != -1U) ? remap[j] : j;
@ -655,6 +662,56 @@ namespace spot
} }
} }
} }
};
auto upd = [](auto& b, const auto&e, unsigned it)
{
while ((it > 0) & (b != e))
{
--it;
++b;
}
};
auto worker = [&upd, check_ix, nthreads](unsigned pid, auto begp, auto endp,
auto bege, auto ende)
{
upd(begp, endp, pid);
upd(bege, ende, pid);
for (; begp != endp; upd(begp, endp, nthreads))
check_ix(begp->second.first);
for (; bege != ende; upd(bege, ende, nthreads))
check_ix(bege->second.first);
};
{
auto begp = player_map.begin();
auto endp = player_map.end();
auto bege = env_map.begin();
auto ende = env_map.end();
if ((nthreads == 1) & (num_states() > 1000)) // Bound?
{
worker(0, begp, endp, bege, ende);
}
else
{
static auto tv = std::vector<std::thread>();
assert(tv.empty());
tv.resize(nthreads);
for (unsigned pid = 0; pid < nthreads; ++pid)
tv[pid] = std::thread(
[worker, pid, begp, endp, bege, ende]()
{
worker(pid, begp, endp, bege, ende);
return;
});
for (auto& t : tv)
t.join();
tv.clear();
}
}
for (auto& e: edges()) for (auto& e: edges())
if (remap[e.dst] != -1U) if (remap[e.dst] != -1U)
@ -765,7 +822,7 @@ namespace spot
comp_classes_.clear(); comp_classes_.clear();
// get all compatible classes // get all compatible classes
// Candidate classes share a hash // Candidate classes share a hash
// A state is compatible to a class if it is compatble // A state is compatible to a class if it is compatible
// to any of its states // to any of its states
auto& cand_classes = equiv_class_[hi]; auto& cand_classes = equiv_class_[hi];
unsigned n_c_classes = cand_classes.size(); unsigned n_c_classes = cand_classes.size();

View file

@ -23,6 +23,9 @@ import spot
from unittest import TestCase from unittest import TestCase
tc = TestCase() tc = TestCase()
for nthread in range(1, 16, 2):
spot.set_nthreads(nthread)
tc.assertEqual(spot.get_nthreads(), nthread)
aut = spot.automaton("""HOA: v1 States: 1 Start: 0 AP: 1 "a" aut = spot.automaton("""HOA: v1 States: 1 Start: 0 AP: 1 "a"
Acceptance: 1 Inf(0) --BODY-- State: 0 [0] 0 [0] 0 {0} --END--""") Acceptance: 1 Inf(0) --BODY-- State: 0 [0] 0 [0] 0 {0} --END--""")
tc.assertEqual(aut.num_edges(), 2) tc.assertEqual(aut.num_edges(), 2)