ABC268_G : Random Student ID を(ほぼ)ソートだけで解く

問題概要など

atcoder.jp

 ここでは、新入生  i の学籍番号の期待値は、 S_j S_i の接頭辞であるような文字列の数を  A_i,  S_i S_j の接頭辞であるような文字列の数を  B_i としたときに(ただし、 A_i にのみ  S_i を含む)、 \frac{A_i - B_i + N}{2} を計算することで求めることができることを前提とします。詳しくは公式解説をご覧ください。

atcoder.jp

解法

  1. 各文字列  S_i について、末尾に | を付け足した文字列を  S_i' とします。
  2.  S_i S_i' を同じ配列に入れ、昇順ソートします。ソート後の  S_i の index を  idx_i,  S_i' の index を  idx_i' とします。
  3.  A_i = idx_i 以前に現れた 末尾が | でない文字列の数 - 末尾が | である文字列の数 です。
    また、 B_i = \frac{idx_i' - idx_i - 1}{2} です。
  4.  A_i, B_i は、ソート後の配列を前から順に見ていくことで  O(N) で求めることができます。よって配列のソートと合わせて期待値を  O(N\ log\ N) で求めることができ、これは十分高速です。

解説

 簡単な例で見ていきましょう。

  N=6
  S=( a,\ ab,\ ac,\ abc,\ b,\ bc )

とします。このとき、各  S_i について、末尾に | を付け足したものの配列を  S' とすると、

  S'=(a|,\ ab|,\ ac|,\ abc|,\ b|,\ bc|)

となります。よって、これらを合体させてソートした配列  T は、

  T=(a,\ ab,\ abc,\ abc|,\ ab|,\ ac,\ ac|,\ a|,\ b,\ bc,\ bc|,\ b|)

となります。
 分かりやすくするため、この配列 T の要素を順に縦に並べます。

  a
  ab
  abc
  abc|
  ab|
  ac
  ac|
  a|
  b
  bc
  bc|
  b|

 これを見ると、どの  S_i S_i' の間にも  S_i が接頭辞であるような文字列のみが過不足なく存在している ということが成り立つと推測できます。例えば、 ab ab| の間には、  ab が接頭辞である  abc abc| があり、 ac ac| の間には、 ac が接頭辞であるような文字列は存在していないため何もありません。
 また、このことを逆手に取ると、 S_i の接頭辞である文字列の数は、  idx_j \leqq idx_i \lt idx_j' であるような  j の数と等しい ということも分かります。例えば、 abc を挟むように存在している文字列は  a,\ a| ab,\ ab| の 2 つであるため、この 2 つが  S のうち  abc の接頭辞である文字列であることが分かります。

 では、何故このことが成り立つのでしょうか?

 ここで重要になのが、| という文字の ASCII コードは 124 である、ということです。z の ASCII コードは 122 であるため、ソートしたときに如何なる英小文字よりも後ろに来る ということが分かります。
 このことから、ソートしたときに  S_i より後かつ  S_i' より前に来る文字列  S_j は、「 |S_i| 文字目まで  S_i と一致しており、  |S_i| \lt |S_j| である」という条件を満たしている必要があります。すなわち、 S_i は、この条件を満たすような  S_j の接頭辞である ということが成立します。

 以上の議論により、この方法で  S_j S_i の接頭辞であるような文字列の数  A_i,  S_i S_j の接頭辞であるような文字列の数  B_i を求めることができると分かりました。

 余談ですが、この性質から、 S_i を開き括弧、 S_i' を閉じ括弧に置き換えて作成した括弧列は正しい括弧列であり、対応する括弧は必ず  S_i S_i' の関係になっています。
 先程の例で試すと ((())())(()) となることからも分かるかと思います。(この性質から何かしらの問題が生えるかもしれませんね)

実装例 (PyPy3, 798ms)

n = int(input())
t = []
mod = 998244353
cnt = 0
a = [0] * n
b = [0] * n
idx = [0] * n

for i in range(n):
    s = input()
    t.append([s, i])
    t.append([s + "|", i])
t.sort(key = lambda x: x[0])

for i in range(len(t)):
    if t[i][0][-1] != "|":
        cnt += 1
        idx[t[i][1]] = i
        a[t[i][1]] = cnt
    else:
        cnt -= 1
        b[t[i][1]] = (i - idx[t[i][1]] - 1) // 2

for i in range(n):
    ans = (a[i] - b[i] + n) * pow(2, mod-2, mod) % mod
    print(ans)

atcoder.jp