読者です 読者をやめる 読者になる 読者になる

ThreeRooks

 この記事はTSG Advent Calendar 2016 - Adventarの24日目の記事として書かれました。
 この記事は、僕がたまたま通すことができたAOJ-ICPC難易度1000のThreeRooksという問題を一緒に解いていこうという趣旨のものです。完全にネタバレなのでもし自力で解きたいと思っている人はすぐにブラウザバックをしてください。僕が思考したステップを順にあげていくので途中でわかったらそこから自力で解くのでもいいかもしれません。

 問題文はhttp://judge.u-aizu.ac.jp/onlinejudge/IMAGE/winter-camp-2011-day3.pdfにあります。とても広いチェスの盤面が与えられ、10^5個以下の障害物が適当に置かれるのでどのルークも互いに取られないような三つのルークの置き方の組み合わせの数を求めよという問題です。チェス盤をh*w、障害物の数がkと置いています。

 ~つ目の思考とか分けていますが、分け方は適当なのであまり気にしないでください。
 

一つ目の思考

 この問題におけるチェスの盤面は一辺が10^9もありますが、障害物はせいぜい10^5個であり、障害物がおかれていない行や列はまとめて考えることができます。一辺が大きいのには変わりないので処理によってはどうしてもlong longのサイズを超えてしまいます。適切にmodを取ったり、割り算のときはmodの逆元を取ったりしましょう。

二つ目の思考

 ルークが三つもあると難しいのでとりあえずルーク二つで考えてみます。
 ルークが二つのときは比較的簡単で、とりあえずルーク二つの置き方を全部列挙して、そのうち条件を満たさないものを除くという考え方に至ると思います。
 ルーク二つの置き方は、\frac{(w*h-k)*(w*h-k-1)}{2}通りです。
 ルーク二つの置き方のうち条件を満たさないものは、行で被るものと列で被るものがあります。行も同様なので列について考えます。障害物を置かれていない列がl本あるとすると、障害物を置かれていない列でルークが被る置き方はl*h*(h-1)/2通りあります。また、障害物が置かれている列については、障害物一つにつき一回分割が増えるだけなので一つ一つ処理をしてやればいいです。

三つ目の思考

 ルークが二つあるときはまずは置き方を全部数えてから条件を満たさないものを除くという考え方がうまくいきました。これを三つのときにも適用してみようと思います。では、どのようなときに条件を満たさないのかと考えると、次の三つの場合が考えられます。

  1. 三つのルークがそれぞれ取れる状態で一直線に並んでいる。(間に障害物が含まれていない)
  2. 二つのルークがお互いに取れる状態になっており、もう一つのルークがどちらのルークにも取られないところにいる。
  3. 二つのルークがお互いに取れる状態になっており、もう一つのルークがどちらかのルークとだけお互いに取れる状態になっている。(L字型の配置)
四つ目の思考

 まずは一つ目のケースについて考えます。
 この場合は簡単で、ルークが二つだった場合と本質的に同じことをするだけで式を少し変えるだけです。

五つ目の思考

 次に二つ目のケースです。
 ルーク二つの配置をお互いに取り合うように配置して、残りの一つのルークを適当に置いてみます。すると、1や3のケースも勿論含んでしまうのですが、何回含むのかを計算することができます。1のケースについては3回数えてしまい、2のケースについては2回数えてしまうことがわかります。
 二つのルークの置き方を二つ目の思考でやったようにして求め、それにh*w-k-2をかけてやった数をX、一つ目のケースの置き方をY通り、三つ目のケースの置き方をZ通りとすると、求めるべき条件を満たさない置き方はX-2*Y-Z通りとなります。Xは上記のように求まり、Yは四つ目の思考で求めているので、後は三つ目のケースであるZを求めればよいです。

六つ目の思考

 最後に三つ目のケースです。これができればこの問題が解けたことになります。そして、長々と書きましたがこのパートがこの問題の本質です。
 ルークの配置はL字型をしているので、ルークのうち二つは障害物を挟まず縦に並んでいます。また、残りの一つのルークはその二つのルークのうちどちらかの横に障害物を挟まず並んでいるはずです。よって、縦に並べるルークを決めたとき、残り一つのルークの置き方がいくつあるかを全て列挙することを考えますが、勿論これを満たす配置はかなり多いので普通に数えていては話になりません。
 まず、縦に並んでいるルーク二つの配置について考えます。障害物が置かれていない列が隣り合っている場合は圧縮することができる*1ので圧縮します。すると、列の数は最大で2*k+1になります。よって、端から端まで列を順に考えていくことが可能なサイズになるので、端から順に処理をしていきます。

七つ目の思考

 今いる列について、その各要素の左右にはどれだけスペース*2が空いているのかを考えます。障害物がない行についてはw-1で一定値なのでまとめて扱うことができ、障害物がある行については障害物に当たったときだけ更新をすればよいので、今いる行の各要素の左右にどれだけスペースがあるかは全体でもO(k)で求めることができます。また、各要素について上下のスペースも求めることができれば、(左右のスペース)*(上下のスペース)を計算することで、その要素をL字の角の部分にするL字を全て列挙することができます。そして、各要素について上下のスペースは入力時に受け取った情報からすぐにわかります。よって、各列の全ての要素について(左右のスペース)*(上下のスペース)を求めることで全てのL字を列挙することができます。しかし、列や行を圧縮したとしても、列数や行数は最大で2*10^5+1もあるので、このままでは計算量がO(10^{10})になってしまい、間に合いません。

八つ目の思考

 七つ目の思考において、今いる列の各要素について(左右のスペース)*(上下のスペース)を計算しましたが、これを効率よく計算できないかと考えます。まず、(上下のスペース)はそのスペースに属しているマスと共通であることがわかります。それを一つのグループとみなします。すると、そのグループにおけるL字の数を求めるには各要素の左右のスペースの総和を求めればいいことになります。そして、これは各要素の左右のスペースをBITで管理をすることによって対数時間で処理をすることができます。グループの更新は障害物が来るときにしか起きないので全体でもO(k)しかかかりません。よって、これらをまとめると解けます。計算量BITの更新の部分が一番重く、O(k \log k)です*3

実装

 なんとなく察してきたと思いますが、この問題は実装がメンドイです*4。僕もたくさんバグらせてしまいました。最後に僕のコードを貼って終わりにしますが、僕自身もコードを読解したくないぐらいなので真面目に読むものじゃないです。

#include <bits/stdc++.h>
#define rep(i,n) for(int i = 0; i < n; i++)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> P;
typedef pair<ll,P> PP;
const long long int MOD = 1000000007;
const long long int INF = 1000000000;
 
ll x, y, k, al;
ll idsize;
map<ll,ll> toid;
vector<P> xy, yx;
ll ans = 0;
ll nod[200001];
ll bit[200001];
vector<ll> vec[200001];
ll forvec[200001];
 
ll sum(ll i){
    ll s = 0;
    while(i > 0){
        s = (s+bit[i]+MOD)%MOD;
        i -= i&-i;
    }
    return s;
}
 
void add(ll i, ll x){
    while(i <= idsize){
        bit[i] = (bit[i]+x+MOD)%MOD;
        i += i&-i;
    }
}
 
ll mod_pow(ll a, ll b){
    ll ret = 1;
    while(b){
        if(b&1) ret = ret*a%MOD;
        a = a*a%MOD;
        b /= 2;
    }
    return ret;
}
 
int main(){
    cin >> x >> y >> k;
    al = x*y%MOD;
    if(x*y-k <= 2){
        cout << 0 << endl;
        return 0;
    }
    rep(i,k){
        ll xx, yy;
        cin >> xx >> yy;
        xy.push_back(P(xx,yy));
        yx.push_back(P(yy,xx));
    }
    sort(xy.begin(),xy.end());
    sort(yx.begin(),yx.end());
    ans = (al-k+MOD)*(al-k-1+MOD)%MOD;
    ans = ans*(al-k-2+MOD)%MOD;
    ans = ans*mod_pow(6,MOD-2)%MOD;
    //cout << ans << endl;
    {
        ll cnt = 0;
        ll ad;
        ll ad2;
        rep(i,xy.size()){
            if(i == 0 || xy[i].first != xy[i-1].first){
                cnt++;
            }
            if(i == 0 || xy[i].first != xy[i-1].first){
                if(xy[i].second-1 > 0){
                    ad = ((xy[i].second)*(xy[i].second-1)/2)%MOD;
                    ad = ad*(al-k-xy[i].second+5*MOD)%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = xy[i].second*(xy[i].second-1)%MOD;
                ad2 = ad2*(xy[i].second-2)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2)%MOD;
                ans = (ans-ad2+MOD)%MOD;
            } else{
                if(xy[i].second-xy[i-1].second-2 > 0){
                    ad = ((xy[i].second-xy[i-1].second-1)*(xy[i].second-xy[i-1].second-2)/2)%MOD;
                    ad = ad*(al-k-(xy[i].second-xy[i-1].second-1)+5*MOD)%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = (xy[i].second-xy[i-1].second-1)*(xy[i].second-xy[i-1].second-2)%MOD;
                ad2 = ad2*(xy[i].second-xy[i-1].second-3)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2)%MOD;
                ans = (ans-ad2+MOD)%MOD;
            }
            if(i == xy.size()-1 || xy[i].first != xy[i+1].first){
                if(y-xy[i].second-2 > 0){
                    ad = ((y-xy[i].second-1)*(y-xy[i].second-2)/2)%MOD;
                    ad = ad*(al-k-(y-xy[i].second-1))%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = (y-xy[i].second-1)*(y-xy[i].second-2)%MOD;
                ad2 = ad2*(y-xy[i].second-3)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2);
                ans = (ans-ad2+MOD)%MOD;
            }
        }
        //cout << ans << " " << cnt << endl;
        ad = (y*(y-1)/2)%MOD;
        ad = ad*(x-cnt)%MOD;
        ad = ad*(al-k-y+5*MOD)%MOD;
        ans = (ans-ad+MOD)%MOD;
        ad2 = y*(y-1)%MOD;
        ad2 = ad2*(y-2)%MOD;
        ad2 = ad2*(x-cnt)%MOD;
        ad2 = ad2*mod_pow(6,MOD-2)%MOD;
        ans = (ans-ad2+MOD)%MOD;
        //cout << ans << endl;
         
        cnt = 0;
        rep(i,yx.size()){
            if(i == 0 || yx[i].first != yx[i-1].first){
                cnt++;
            }
            if(i == 0 || yx[i].first != yx[i-1].first){
                if(yx[i].second-1 > 0){
                    ad = (yx[i].second*(yx[i].second-1)/2)%MOD;
                    ad = ad*(al-k-yx[i].second+5*MOD)%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = yx[i].second*(yx[i].second-1)%MOD;
                ad2 = ad2*(yx[i].second-2)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2)%MOD;
                ans = (ans-ad2+MOD)%MOD;
            } else{
                if(yx[i].second-yx[i-1].second-2 > 0){
                    ad = ((yx[i].second-yx[i-1].second-1)*(yx[i].second-yx[i-1].second-2)/2)%MOD;
                    ad = ad*(al-k-(yx[i].second-yx[i-1].second-1)+5*MOD)%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = (yx[i].second-yx[i-1].second-1)*(yx[i].second-yx[i-1].second-2)%MOD;
                ad2 = ad2*(yx[i].second-yx[i-1].second-3)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2)%MOD;
                ans = (ans-ad2+MOD)%MOD;
            }
            //cout << ans << endl;
            if(i == yx.size()-1 || yx[i].first != yx[i+1].first){
                if(x-yx[i].second-2 > 0){
                    ad = ((x-yx[i].second-1)*(x-yx[i].second-2)/2)%MOD;
                    ad = ad*(al-k-(x-yx[i].second-1)+5*MOD)%MOD;
                    ans = (ans-ad+MOD)%MOD;
                }
                ad2 = (x-yx[i].second-1)*(x-yx[i].second-2)%MOD;
                ad2 = ad2*(x-yx[i].second-3)%MOD;
                ad2 = ad2*mod_pow(6,MOD-2)%MOD;
                ans = (ans-ad2+MOD)%MOD;
            }
            //cout << ans << endl;
        }
        //cout << ans << " " << cnt << endl;
        ad = (x*(x-1)/2)%MOD;
        ad = ad*(y-cnt)%MOD;
        ad = ad*(al-k-x+2*MOD)%MOD;
        ans = (ans-ad+MOD)%MOD;
        ad2 = x*(x-1)%MOD;
        ad2 = ad2*(x-2)%MOD;
        ad2 = ad2*(y-cnt)%MOD;
        ad2 = ad2*mod_pow(6,MOD-2)%MOD;
        ans = (ans-ad2+MOD)%MOD;
    }
    //cout << ans << endl;
    {
        ll now = -1;
        ll last = -1;
        rep(i,xy.size()){
            if(xy[i].first > now+1){
                ll val = y*(xy[i].first-now-1)%MOD;
                nod[idsize] = val;
                idsize++;
            }
            if(xy[i].first == now){
                //que[idsize-1].push(xy[i].second-last-1);
                vec[idsize-1].push_back(xy[i].second-last-1);
                last = xy[i].second;
            } else{
                nod[idsize] = xy[i].second;
                toid[xy[i].first] = idsize;
                idsize++;
                last = xy[i].second;
            }
            if(i == xy.size()-1 || xy[i].first != xy[i+1].first){
                //que[idsize-1].push(y-last-1);
                vec[idsize-1].push_back(y-last-1);
            }
            now = xy[i].first;
        }
        if(x > now+1){
            ll val = y*(x-now-1)%MOD;
            nod[idsize] = val;
            idsize++;
        }
    }
    //cout << "nod:";
    rep(i,idsize){
        //cout << nod[i] << " ";
        add(i+1,nod[i]);
    }
    //cout << endl;
    {
        ll nowx = -1, nowy = -1;
        rep(i,yx.size()){
            //cout << yx[i].first << " " << yx[i].second << endl;
            ll id = toid[yx[i].second];
            if(yx[i].first > nowy+1){
                //cout << "___1___" << endl;
                ll val = yx[i].first-nowy-1;
                val = val*(x-1)%MOD;
                val = val*(sum(idsize)-x+MOD)%MOD;
                ans = (ans+val)%MOD;
            }
            if(yx[i].first != nowy){
                //cout << "___2___" << endl;
                ll val = yx[i].second-1;
                if(val > 0){
                    //cout << "_2_" << endl;
                    val = val*(sum(id)-yx[i].second+MOD)%MOD;
                    ans = (ans+val)%MOD;
                }
                nowy = yx[i].first;
                nowx = yx[i].second;
            } else{
                //cout << "___3___" << endl;
                ll val = yx[i].second-nowx-2;
                if(val > 0){
                    val = val*(sum(id)-sum(toid[nowx]+1)-val-1+5*MOD)%MOD;
                    ans = (ans+val)%MOD;
                }
                nowx = yx[i].second;
            }
            if(i == yx.size()-1 || yx[i].first != yx[i+1].first){
                //cout << "___4___" << endl;
                ll val = x-nowx-2;
                if(val > 0){
                    val = val*(sum(idsize)-sum(toid[nowx]+1)-val-1+5*MOD)%MOD;
                    ans = (ans+val)%MOD;
                }
            }
            //ll q = que[id].front();
            ll q = vec[id][forvec[id]];
            forvec[id]++;
            //que[id].pop();
            ll dif = (q-nod[id]+MOD)%MOD;
            add(id+1,dif);
            nod[id] = q;
        }
        ll val = y-nowy-1;
        //cout << "val " << val <<  " " << sum(idsize) << endl;
        val = val*(x-1)%MOD;
        val = val*(sum(idsize)-x+MOD)%MOD;
        ans = (ans+val)%MOD;
    }
    cout << ans << endl;
}

*1:各ルーク二つの配置について、残りのルークが(圧縮した列数-1)分だけは横に必ず動ける

*2:障害物や壁にぶつかるまでのマス

*3:厳密にはmodの逆元を取る部分でも毎回\log ({10^9+7})かかっていますが、まぁ今回は無視で

*4:少なくとも僕にとっては