<p>以下是 @kcsquared 解法的三个 NumPy 版本,均可在 10 秒时限内轻松处理最坏情况:</p>
<pre><code>def numpy1(n, k, m):
dp = np.zeros((k+1, m), np.int32)
dp[0][0] = 1
for i in range(1, n+1):
dp[1:,] += dp[:-1, (np.arange(m) + i) % m]
dp %= 10**9 + 7
return dp[k][0]
</code></pre>
<pre><code>def numpy2(n, k, m):
dp = np.zeros((k+1, m), np.int32)
dp[0][0] = 1
i = range(m)
for _ in range(n):
i = np.roll(i, 1)
dp[1:,] += dp[:-1, i]
dp %= 10**9 + 7
return dp[k][0]
</code></pre>
<pre><code>def numpy3(n, k, m):
dp = np.zeros((k+1, m), np.int32)
dp[0][0] = 1
for i in range(n):
dp[1:,] += np.roll(dp[:-1,], i, axis=1)
dp %= 10**9 + 7
return dp[k][0]
</code></pre>
<p>小规模输入和最坏情况下的基准测试结果:</p>
<pre><code>n = 19 k = 11 m = 13
-
22856.8 μs 23409.3 μs 23421.4 μs naive
496.9 μs 500.2 μs 524.7 μs dtjc
918.6 μs 928.6 μs 936.3 μs kcsquared
173.5 μs 183.6 μs 191.9 μs numpy1
402.2 μs 403.5 μs 411.1 μs numpy2
297.8 μs 318.4 μs 320.1 μs numpy3
n = 200 k = 100 m = 200
-
2033.6 ms 2177.3 ms 2178.1 ms dtjc
1410.6 ms 1420.2 ms 1430.5 ms kcsquared
19.5 ms 19.8 ms 20.3 ms numpy1
22.5 ms 22.9 ms 23.0 ms numpy2
26.8 ms 27.3 ms 27.3 ms numpy3
n = 1000 k = 100 m = 1000
-
508.0 ms 516.1 ms 519.2 ms numpy1
518.3 ms 518.8 ms 526.3 ms numpy2
495.1 ms 496.4 ms 499.2 ms numpy3
</code></pre>
<p>基准代码(<a href="https://tio.run/##tVb/bts2EP7fT3HAEJi0ZUeKkRYV5mIYBgz7Y@sDGILAWHSjWaI0kiraPNueYa@U3lHUTztrV3RCEjEk7@77vjueWH@yj5XaPT@fdFWCzUuZW8jLutIWtKylsAu3klupbVUVpls8VuVDroTNK2XaLadGHSdbCt2kR3F8lAs/oZqy/gTCgKoXi0yeQIn8g2QqgHMAJY8XgI@WttEKTFMy@nVxONxACfs9hG7L6DlVuoUCuZpgYlqo95JFAah1xDECJycsClerN7CG13yx@KkHyP6olOQOUmb/PHaIAtD70MPKT3CGt6DiHoEHGg7LBPBiPaI17dZAFkZ6A7/K2nibqI/IEd1okv7gNNNr5VS4pEGoz0fzVyO0zFgpPqa10DY1@ZOMURUbgK2sKFJU360YP1tWmRtx2Lyld4v8599@TX9/9wvsIQphtQIXZeGWfoB3NVZI/uQUJlqqAvKYO8VBfsyN7dSYxUTppsguZXQzhdDvpcE9DeYb2JQOrGA2sYaIw@0t3PF5YdCzuXCwmePisLpmic@XLEex@QvyeDKDRlTVYGxeFGBKURRSg30UilLR6TYW4EeXoxeUymoU6HAI3RFIqfzbikcTnswnr2FPvJtDmOAPJbx17I5Uo3XaCTC4oVq80H8A2FsOgchUfpDaYGm2PsKLeuQjF50bLcsppfgiTYh8FgzRJAdGpuspA3dyqiyB9f7SLDmgRfJd3N/su/MzPuPoacaYBPcdkFpiNGuBLrWq3j5JXRnGzmvSHXsYTuE53flin2Su0y2fJIta3yAcGkRx0IkQU2dh6FJ4kanxtJ02GZkQp77bzFidZzzuvg@PHEcdpp7YqJjViFPeRtBVUbA8wGr8N7b5txLb/S8JUtdz09HxqBPEHYDA3rqfsPtK/BabCcOeYgIwVdG4Th0QqwepA0JUCpvSd9@DqTUyYEuFoG8y/P/cDUo3WGJ9kDc@3rxZYme@v7@arHEBammaAg/SHoy0jE/6Bt0eyKoHOT3whNAZ4j2CWom7mrBClA@ZiN33cuu@5OmxkEIzDq6HuHBbkWWM3LOVQ847@vv2xfkh3k3Pf8trxUbyMAu33o47wK4tOljokNxv01SJUqbpQMzr0/FG1vEy6HAN24Qxku5LUrFuje4L0cwPfutdNln0BisdSy7a8aDdc3A3qcAJEQy3gcD3F/@@8@9d4s2i0A9aIR2dGJY3r7fRCf752wClm9EsJjiSr/BL1yK4C7GNR/QHRz2I/xr9evBXFLycx971sTFsF5yGgwTfEuz@pWDPz58B" rel="nofollow noreferrer" title="Python 3 – Try It Online">Try it online!</a>):</p>
<pre><code>from timeit import repeat
from itertools import combinations
from functools import lru_cache
import numpy as np
def naive(n, k, m):
    """Brute force: count k-element subsets of {1, ..., n} whose sum is a
    multiple of m, reduced modulo 10**9 + 7.

    Enumerates every combination, so it is only usable for small n.
    """
    hits = 0
    for subset in combinations(range(1, n + 1), k):
        if sum(subset) % m == 0:
            hits += 1
    return hits % (10**9 + 7)
@lru_cache(None)
def dtjc(n, k, m, r=0):
    """Memoized recursion: count k-subsets of {1, ..., n} whose sum plus r is a
    multiple of m (r being the residue in range(m) of the numbers already
    taken), reduced modulo 10**9 + 7.

    Each step decides on the largest remaining number n: skip it, or take it
    and fold it into the running residue. Recursion depth grows linearly with n.
    """
    if k > n:
        # Fewer numbers remain than still have to be picked.
        return 0
    if k == 0:
        # Nothing left to pick: success exactly when the residue is 0.
        return int(r == 0)
    without_n = dtjc(n - 1, k, m, r)
    with_n = dtjc(n - 1, k - 1, m, (r + n) % m)
    return (without_n + with_n) % (10**9 + 7)
def kcsquared(max_part_size: int, total_num_parts: int, mod: int) -> int:
    """Count ways to choose `total_num_parts` distinct integers from
    1..max_part_size whose sum is divisible by `mod`, modulo 10**9 + 7.

    Bottom-up 0/1-knapsack DP: dp[j][r] is the number of j-element selections
    (from the parts processed so far) whose sum is congruent to r mod `mod`.
    """
    BIG_MOD = 10 ** 9 + 7
    # Optimization if no partitions exist
    if total_num_parts > max_part_size:
        return 0
    # The empty selection has sum 0, which every `mod` divides. This must be
    # handled before the largest-sum shortcut, which would otherwise return 0.
    if total_num_parts == 0:
        return 1
    largest_sum = ((max_part_size * (max_part_size + 1) // 2)
                   - ((max_part_size - total_num_parts) *
                      (max_part_size - total_num_parts + 1) // 2))
    # With at least one part every sum is >= 1; if even the largest possible
    # sum is below `mod`, none of them can be a multiple of it.
    if largest_sum < mod:
        return 0
    dp = [[0] * mod for _ in range(total_num_parts + 1)]
    dp[0][0] = 1
    for curr_max_part in range(1, max_part_size + 1):
        # Walk the counts downwards so each part is used at most once.
        for curr_num_parts in reversed(range(total_num_parts)):
            src = dp[curr_num_parts]
            dst = dp[curr_num_parts + 1]
            for rem in range(mod):
                target = (rem + curr_max_part) % mod
                dst[target] = (dst[target] + src[rem]) % BIG_MOD
    return dp[total_num_parts][0]
def numpy1(n, k, m):
    """Count k-element subsets of {1, ..., n} with sum divisible by m, mod 10**9 + 7.

    dp[j, r] counts j-element subsets of the numbers seen so far whose sum
    is congruent to -r (mod m); since -0 == 0, dp[k, 0] is the answer.
    """
    BIG_MOD = 10**9 + 7
    dp = np.zeros((k + 1, m), np.int32)
    dp[0, 0] = 1
    residues = np.arange(m)  # loop-invariant, so build it once
    for i in range(1, n + 1):
        # The fancy-indexed right-hand side is a copy taken before the
        # in-place add, so each number i is used at most once per subset.
        dp[1:] += dp[:-1, (residues + i) % m]
        # Entries stay below BIG_MOD (< 2**30), so the sum of two fits in int32.
        dp %= BIG_MOD
    # Plain int, like the other solutions (avoids silent int32 overflow in callers).
    return int(dp[k, 0])
def numpy2(n, k, m):
    """Count k-element subsets of {1, ..., n} with sum divisible by m, mod 10**9 + 7.

    dp[j, r] counts j-element subsets of the numbers seen so far whose sum
    is congruent to r (mod m).
    """
    BIG_MOD = 10**9 + 7
    dp = np.zeros((k + 1, m), np.int32)
    dp[0, 0] = 1
    # After t rolls, source[r] == (r - t) % m, so the t-th update adds the
    # number t to every subset counted in the row above.
    source = np.arange(m)
    for _ in range(n):
        source = np.roll(source, 1)
        # Fancy indexing copies the right-hand side before the in-place add.
        dp[1:] += dp[:-1, source]
        # Entries stay below BIG_MOD (< 2**30), so the sum of two fits in int32.
        dp %= BIG_MOD
    # Plain int, like the other solutions (avoids silent int32 overflow in callers).
    return int(dp[k, 0])
def numpy3(n, k, m):
    """Count k-element subsets of {1, ..., n} with sum divisible by m, mod 10**9 + 7.

    dp[j, r] counts j-element subsets of the numbers seen so far whose sum
    is congruent to r (mod m).
    """
    BIG_MOD = 10**9 + 7
    dp = np.zeros((k + 1, m), np.int32)
    dp[0, 0] = 1
    # The numbers are 1..n (not 0..n-1): np.roll(row, i)[r] == row[(r - i) % m],
    # i.e. the roll by i adds the number i to each subset of the row above.
    for i in range(1, n + 1):
        # np.roll returns a new array, so every row is updated from old values.
        dp[1:] += np.roll(dp[:-1], i, axis=1)
        # Entries stay below BIG_MOD (< 2**30), so the sum of two fits in int32.
        dp %= BIG_MOD
    # Plain int, like the other solutions (avoids silent int32 overflow in callers).
    return int(dp[k, 0])
def test(args, solutions, number, format_time):
print('n = %d k = %d m = %d' % args)
print('-' * 55)
for _ in range(1):
results = set()
for func in solutions:
times = sorted(repeat(lambda: dtjc.cache_clear() or results.add(func(*args)), number=number))[:3]
print(*(format_time(t / number) for t in times), func.__name__)
print('results set:', results)
assert len(results) == 1
print()
# Small case: every solution, including the brute-force one, 10 calls per round.
test(args=(19, 11, 13),
     solutions=[naive, dtjc, kcsquared, numpy1, numpy2, numpy3],
     number=10,
     format_time=lambda time: '%7.1f μs ' % (time * 1e6))
# Medium case: brute force dropped, single call per round.
test(args=(200, 100, 200),
     solutions=[dtjc, kcsquared, numpy1, numpy2, numpy3],
     number=1,
     format_time=lambda time: '%6.1f ms ' % (time * 1e3))
# Largest configuration benchmarked here: NumPy solutions only.
test(args=(1000, 100, 1000),
     solutions=[numpy1, numpy2, numpy3],
     number=1,
     format_time=lambda time: '%5.1f ms ' % (time * 1e3))
</code></pre>