Introduce a faster merge_states
merge_states is now hash-based, uses the new edge-sorting with src first and can be executed in parallel. * spot/twa/twagraph.cc: Here * tests/python/mergedge.py: Test
This commit is contained in:
parent
71c2a7b1a6
commit
d8cc0c5acb
2 changed files with 253 additions and 193 deletions
|
|
@ -372,13 +372,12 @@ namespace spot
|
|||
throw std::runtime_error(
|
||||
"twa_graph::merge_states() does not work on alternating automata");
|
||||
|
||||
const unsigned nthreads = get_nthreads();
|
||||
|
||||
typedef graph_t::edge_storage_t tr_t;
|
||||
g_.sort_edges_([](const tr_t& lhs, const tr_t& rhs)
|
||||
g_.sort_edges_srcfirst_([](const tr_t& lhs, const tr_t& rhs)
|
||||
{
|
||||
if (lhs.src < rhs.src)
|
||||
return true;
|
||||
if (lhs.src > rhs.src)
|
||||
return false;
|
||||
assert(lhs.src == rhs.src);
|
||||
if (lhs.acc < rhs.acc)
|
||||
return true;
|
||||
if (lhs.acc > rhs.acc)
|
||||
|
|
@ -449,7 +448,7 @@ namespace spot
|
|||
// Represents which states share a hash
|
||||
// Head is in the unordered_map,
|
||||
// hash_linked_list is like a linked list structure
|
||||
// of false pointers
|
||||
// of fake pointers
|
||||
|
||||
auto hash_linked_list = std::vector<unsigned>(n_states, -1u);
|
||||
auto s_to_hash = std::vector<size_t>(n_states, 0);
|
||||
|
|
@ -530,8 +529,8 @@ namespace spot
|
|||
};
|
||||
|
||||
|
||||
static auto checked1 = std::vector<char>();
|
||||
static auto checked2 = std::vector<char>();
|
||||
thread_local auto checked1 = std::vector<char>();
|
||||
thread_local auto checked2 = std::vector<char>();
|
||||
|
||||
auto [i1, nsl1, sl1, e1] = e_idx[s1];
|
||||
auto [i2, nsl2, sl2, e2] = e_idx[s2];
|
||||
|
|
@ -585,12 +584,10 @@ namespace spot
|
|||
// More efficient version?
|
||||
// Skip checked edges
|
||||
// Last element serves as break
|
||||
for (; checked1[idx1 - i1]; ++idx1)
|
||||
{
|
||||
}
|
||||
for (; checked2[idx2 - i2]; ++idx2)
|
||||
{
|
||||
}
|
||||
while (checked1[idx1 - i1])
|
||||
++idx1;
|
||||
while (checked2[idx2 - i2])
|
||||
++idx2;
|
||||
// If one is out of bounds, so is the other
|
||||
if (idx1 == e1)
|
||||
{
|
||||
|
|
@ -614,12 +611,22 @@ namespace spot
|
|||
const unsigned nb_states = num_states();
|
||||
std::vector<unsigned> remap(nb_states, -1U);
|
||||
|
||||
for (unsigned i = 0; i != nb_states; ++i)
|
||||
// Check each hash
|
||||
auto check_ix = [&](unsigned ix)
|
||||
{
|
||||
auto j = spe && (*sp)[i] ? player_map.at(s_to_hash[i]).first
|
||||
: env_map.at(s_to_hash[i]).first;
|
||||
for (; j<i; j=hash_linked_list[j])
|
||||
// Reduce cache miss
|
||||
thread_local auto v = std::vector<unsigned>();
|
||||
v.clear();
|
||||
for (auto i = ix; i != -1U; i = hash_linked_list[i])
|
||||
v.push_back(i);
|
||||
const unsigned N = v.size();
|
||||
|
||||
for (unsigned idx = 0; idx < N; ++idx)
|
||||
{
|
||||
auto i = v[idx];
|
||||
for (unsigned jdx = 0; jdx < idx; ++jdx)
|
||||
{
|
||||
auto j = v[jdx];
|
||||
if (state_equal(j, i))
|
||||
{
|
||||
remap[i] = (remap[j] != -1U) ? remap[j] : j;
|
||||
|
|
@ -655,6 +662,56 @@ namespace spot
|
|||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto upd = [](auto& b, const auto&e, unsigned it)
|
||||
{
|
||||
while ((it > 0) & (b != e))
|
||||
{
|
||||
--it;
|
||||
++b;
|
||||
}
|
||||
};
|
||||
|
||||
auto worker = [&upd, check_ix, nthreads](unsigned pid, auto begp, auto endp,
|
||||
auto bege, auto ende)
|
||||
{
|
||||
upd(begp, endp, pid);
|
||||
upd(bege, ende, pid);
|
||||
for (; begp != endp; upd(begp, endp, nthreads))
|
||||
check_ix(begp->second.first);
|
||||
for (; bege != ende; upd(bege, ende, nthreads))
|
||||
check_ix(bege->second.first);
|
||||
};
|
||||
|
||||
{
|
||||
auto begp = player_map.begin();
|
||||
auto endp = player_map.end();
|
||||
auto bege = env_map.begin();
|
||||
auto ende = env_map.end();
|
||||
|
||||
|
||||
if ((nthreads == 1) & (num_states() > 1000)) // Bound?
|
||||
{
|
||||
worker(0, begp, endp, bege, ende);
|
||||
}
|
||||
else
|
||||
{
|
||||
static auto tv = std::vector<std::thread>();
|
||||
assert(tv.empty());
|
||||
tv.resize(nthreads);
|
||||
for (unsigned pid = 0; pid < nthreads; ++pid)
|
||||
tv[pid] = std::thread(
|
||||
[worker, pid, begp, endp, bege, ende]()
|
||||
{
|
||||
worker(pid, begp, endp, bege, ende);
|
||||
return;
|
||||
});
|
||||
for (auto& t : tv)
|
||||
t.join();
|
||||
tv.clear();
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& e: edges())
|
||||
if (remap[e.dst] != -1U)
|
||||
|
|
@ -765,7 +822,7 @@ namespace spot
|
|||
comp_classes_.clear();
|
||||
// get all compatible classes
|
||||
// Candidate classes share a hash
|
||||
// A state is compatible to a class if it is compatble
|
||||
// A state is compatible to a class if it is compatible
|
||||
// to any of its states
|
||||
auto& cand_classes = equiv_class_[hi];
|
||||
unsigned n_c_classes = cand_classes.size();
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ import spot
|
|||
from unittest import TestCase
|
||||
tc = TestCase()
|
||||
|
||||
for nthread in range(1, 16, 2):
|
||||
spot.set_nthreads(nthread)
|
||||
tc.assertEqual(spot.get_nthreads(), nthread)
|
||||
aut = spot.automaton("""HOA: v1 States: 1 Start: 0 AP: 1 "a"
|
||||
Acceptance: 1 Inf(0) --BODY-- State: 0 [0] 0 [0] 0 {0} --END--""")
|
||||
tc.assertEqual(aut.num_edges(), 2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue