我的递归矩阵乘法算法有什么问题?

2024-09-26 22:52:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个递归矩阵乘法,该算法对于(2,2)矩阵很好,但是对于较大的矩阵,它会产生一个错误的矩阵 我不知道我做错了什么

这个算法是O(n^3)

这是我在斯坦福大学coursera专攻之后对算法研究的一部分

def default_matrix_multiplication(a, b):
    """
    Only for 2x2 matrices
    """
    if len(a) != 2 or len(a[0]) != 2 or len(b) != 2 or len(b[0]) != 2:
        raise Exception('Matrices should be 2x2!')

    new_matrix = [[a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]],
                  [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]]]

    return new_matrix


def matrix_addition(x, y):

    result = [
        [x[row][col] + y[row][col] for col in range(len(x[row]))] for row in range(len(x))
    ]
    return result


def get_matrix_dimensions(matrix):
    return len(matrix), len(matrix[0])


def split_matrix(matrix):
    # split matrices to 4 portions and returns a tuple of them
    len_matrix = len(matrix)
    mid = len_matrix // 2

    # [1, 2] [9, 10]
    # [3, 4] [11, 12]
    # [5, 6] [13, 14]
    # [7, 8] [15, 16]

    top_left = [
        [matrix[i][j] for j in range(mid)] for i in range(mid)
    ]

    bot_left = [
        [matrix[i][j] for j in range(mid)] for i in range(mid, len_matrix)
    ]
    top_right = [
        [matrix[i][j] for j in range(mid, len_matrix)] for i in range(mid)
    ]
    bot_right = [
        [matrix[i][j] for j in range(mid, len_matrix)] for i in range(mid, len_matrix)
    ]
    return top_left, bot_left, top_right, bot_right


def recMatrixMult(x, y):
    # recursively multiplies matrices
    if get_matrix_dimensions(x) != get_matrix_dimensions(y):
        raise Exception('Both Matrices are not the same dimensions')
    # base case
    if get_matrix_dimensions(x) == (2, 2):
        return default_matrix_multiplication(x, y)
    else:
        A, B, C, D = split_matrix(x)
        E, F, G, H = split_matrix(y)
    # the dot matrix multiplication is as the following 
    # where top_left = (A*E) + (B*G)
    # top_right = (C*E) + (D*G)
    # bot_left = (A*F) + (B*H)
    # bot_right = (C*F) + (D*H)
    # [(A*E) + (B*G), (C*E) + (D*G)],
    # [(A*F) + (B*H), (C*F) + (D*H)]
    AE = recMatrixMult(A, E)
    BG = recMatrixMult(B, G)
    AF = recMatrixMult(A, F)
    BH = recMatrixMult(B, H)

    CE = recMatrixMult(C, E)
    DG = recMatrixMult(D, G)
    CF = recMatrixMult(C, F)
    DH = recMatrixMult(D, H)

    top_left = matrix_addition(AE, BG)
    top_right = matrix_addition(AF, BH)

    bot_left = matrix_addition(CE, DG)
    bot_right = matrix_addition(CF, DH)

    z = []

    for i in range(len(top_right)):
        z.append([top_left[i] + top_right[i]])

    for i in range(len(bot_right)):
        z.append([bot_left[i] + bot_right[i]])
    return z


X = [
    [10, 9, 4, 3],  # Row0
    [8, 3, 4, 1],  # Row1
    [93, 1, 9, 3],  # Row2
    [2, 2, 7, 6]   # Row3
]
#   col0 col1 col2 col3
Y = [
    [4, 5, 3, 5],  # Row0
    [4, 1, 2, 1],  # Row1
    [9, 8, 3, 5],  # Row2
    [6, 3, 7, 9]  # Row3
]

print(recMatrixMult(X, Y))


Tags: inrightforlenreturntopdefbot

热门问题