ntk log ntk

競プロや非競技プロのやったことを書いています.

ABC176 D - Wizard in Maze

D - Wizard in Maze

久しぶりにコンテスト中にDまで解けた!嬉しい~~!

問題概要

グリッドグラフが与えられる.スタートのマスからゴールのマスまで移動したい.移動する操作は2種類ある.移動Aは現在のマスから上下左右に1マス移動できて,移動Bは現在のマスを中心とした,5×5の範囲内のマスに移動できる1

移動Bが必要な最小の回数を求めよ.ただし,スタートからゴールにどうやっても到達できなければ-1を出力せよ.

解法

2回BFSをする.

1回目はグリッドグラフ上で行う.移動Aのみで移動できる範囲で,各マスをグループ分けする. 各グループを一つの頂点とみなしたグラフを作成し,移動Bで移動できる頂点間に辺を張る. 作成したグラフ上で,2回目のBFSをし,スタートマスの頂点からゴールマスの頂点への最短距離を求める.

たとえば入力例4だと次のような感じ.

1回目のBFSでグループ分け.

f:id:ntk_ta01:20200823112042p:plain

各マスから,どのグループにワープできるか走査.壁でないマスはすべて調べる.

f:id:ntk_ta01:20200823112141p:plain

f:id:ntk_ta01:20200823112158p:plain

f:id:ntk_ta01:20200823112447p:plain

f:id:ntk_ta01:20200823112220p:plain

あとはグラフを作成する.このグラフ上でBFSすると求める最短距離がわかる.

f:id:ntk_ta01:20200823112254p:plain

コード

def main():
    from collections import deque
    H, W = (int(i) for i in input().split())
    Ch, Cw = (int(i)-1 for i in input().split())
    Dh, Dw = (int(i)-1 for i in input().split())
    A = [input() for i in range(H)]
    V = [[-1]*W for i in range(H)]  # V[h][w]を頂点iとする
    d = ((1, 0), (0, 1), (-1, 0), (0, -1))

    def bfs_2d(sy, sx, v):
        que = deque()
        V[sy][sx] = v
        que.append((sy, sx))
        while que:
            uh, uw = que.popleft()
            for dy, dx in d:
                next_h = uh + dy
                next_w = uw + dx
                if not(0 <= next_h < H and 0 <= next_w < W):
                    continue
                if V[next_h][next_w] != -1:
                    continue
                if A[next_h][next_w] == '#':
                    continue
                V[next_h][next_w] = v
                que.append((next_h, next_w))

    # グループ分け
    vertex = 0
    for h in range(H):
        for w in range(W):
            if A[h][w] == "#" or V[h][w] != -1:
                continue
            bfs_2d(h, w, vertex)
            vertex += 1

    # グラフを作成、全てのマスで5*5の範囲を走査
    G = [[] for _ in range(vertex)]
    used = [set() for _ in range(vertex)]
    for h in range(H):
        for w in range(W):
            if A[h][w] == "#":
                continue
            for dh in range(-2, 3):
                for dw in range(-2, 3):
                    if h + dh < 0 or H <= h + dh or w + dw < 0 or W <= w + dw:
                        continue
                    if A[h+dh][w+dw] == "#":
                        continue
                    if V[h][w] == V[h+dh][w+dw]:
                        continue
                    if V[h+dh][w+dw] in used[V[h][w]]:
                        continue
                    G[V[h][w]].append((V[h+dh][w+dw]))
                    used[V[h][w]].add((V[h+dh][w+dw]))

    # for h in range(H):
    #     print(V[h])

    # for i in range(vertex):
    #     print(G[i])

    # スタートからゴールまでbfsして最短距離を求める
    s = V[Ch][Cw]
    g = V[Dh][Dw]

    dist = [-1 for _ in range(vertex)]
    que = deque()
    dist[s] = 0
    que.append(s)
    while que:
        u = que.popleft()
        for v in G[u]:
            if dist[v] != -1:
                continue
            dist[v] = dist[u] + 1
            que.append(v)

    print(dist[g])


if __name__ == '__main__':
    main()

提出コードへのリンク


グループ分け,つまり連結成分ごとに考え,縮約してグラフを作る考察がきれいにできて気持ちよかった.ただBFSをバグらせていてとても時間をかけてしまった,もったいない…….バグは,スタートマスが頂点0から始まる,としてBFSを書いてしまっていた感じ.これのせいでTLEが出ていたから,マス(h,w)がどの頂点に属するかを持つ二次元配列Vを一次元にしたり,キューに突っ込んでいたタプルを整数に直したりと†定数倍高速化†も頑張ったのだけれど,そういう工夫は今回必要なかったらしい(ABC174 Fとかは必要だった).

ちなみにBFSは自分はこう書くのだけれど,

def bfs(s):
    que = deque([])
    dist = [INF]*N
    dist[s] = 0
    que.append(s)
    while que:
        pos = que.popleft()
        for v in G[pos]:
            if dist[v] != INF:
                continue
            dist[v] = dist[pos] + 1
            que.append(v)

先日こんなツイートを見た.これ壊れないんだ……(びっくり).

想定解法は01BFSらしいので,アルメリアさんの記事を読んでそれでも解くつもり.

(8/23 22:15追記) 01BFSと全マスについて,各5×5マスを走査するだけでいいダイクストラでも解いた.ダイクストラはちょっと遅くなるけど,間に合うしメモリ使用量は減った.

01BFSはdeque()の両端から追加していくBFSだった.アルメリアさんの記事にあるようにdequeを一本用意するか,もしくは競プロフレンズさんのツイートにあるように,コストの種類数だけキューを用意する(こっちのがいろいろできそう,0123BFSとか?).

ダイクストラの提出コード

ダイクストラのコード

# ダイクストラ
INF = 10**9 + 7


def main():
    from heapq import heappop, heappush
    H, W = (int(i) for i in input().split())
    Ch, Cw = (int(i)-1 for i in input().split())
    Dh, Dw = (int(i)-1 for i in input().split())
    A = [input() for i in range(H)]

    dist = [[INF]*W for _ in range(H)]
    dist[Ch][Cw] = 0
    que = [(0, Ch*W + Cw)]
    while que:
        u_dist, u = heappop(que)
        uh, uw = u//W, u % W
        if dist[uh][uw] < u_dist:
            continue
        for dh in range(-2, 3):
            for dw in range(-2, 3):
                nh = uh + dh
                nw = uw + dw
                if nh < 0 or H <= nh or nw < 0 or W <= nw:
                    continue
                if A[nh][nw] == "#":
                    continue
                cost = 0 if abs(dh) + abs(dw) <= 1 else 1
                if dist[nh][nw] > dist[uh][uw] + cost:
                    dist[nh][nw] = dist[uh][uw] + cost
                    heappush(que, (dist[nh][nw], nh*W + nw))

    if dist[Dh][Dw] == INF:
        print(-1)
    else:
        print(dist[Dh][Dw])


if __name__ == '__main__':
    main()

01BFSの提出コード

01BFSのコード

# 01BFS
INF = 10**9 + 7


def main():
    from collections import deque
    H, W = (int(i) for i in input().split())
    Ch, Cw = (int(i)-1 for i in input().split())
    Dh, Dw = (int(i)-1 for i in input().split())
    A = [input() for i in range(H)]

    def bfs_2d(c, sy, sx, gh, gw, H, W):
        que0 = deque()
        que1 = deque()
        dist = [[INF]*W for i in range(H)]

        dist[sy][sx] = 0
        que0.append((sy, sx))
        while que0 or que1:
            if que0:
                uh, uw = que0.popleft()
            else:
                uh, uw = que1.popleft()
            for dh in range(-2, 3):
                for dw in range(-2, 3):
                    next_h = uh + dh
                    next_w = uw + dw
                    cost = 0 if abs(dh) + abs(dw) <= 1 else 1
                    if not(0 <= next_h < H and 0 <= next_w < W):
                        continue
                    if c[next_h][next_w] == "#":
                        continue
                    if dist[next_h][next_w] > dist[uh][uw] + cost:  
                        # ここダイクストラに似てるな~と思いました
                        dist[next_h][next_w] = dist[uh][uw] + cost
                        if cost == 0:
                            que0.append((next_h, next_w))
                        else:
                            que1.append((next_h, next_w))
        return dist[gh][gw] if dist[gh][gw] != INF else -1

    ans = bfs_2d(A, Ch, Cw, Dh, Dw, H, W)
    print(ans)


if __name__ == '__main__':
    main()


  1. 一瞬「ん?」ってなりませんか?clar飛ばそうかと迷った(実際飛んでたし).サンプルを見て考えたり質問タブに気づけるとわかる.それでもわからなかったら自分でも質問飛ばすといいんだろうな(まだコンテスト中に質問したことなし).