我的神经网络只能预测一个类

2024-05-18 19:23:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我制作了一个NeuralNetwork类,如下所示

import numpy as np
import matplotlib.pyplot as plt

class NeuralNetwork():
    
    def __init__(self, alpha, layer_dims):
        self.alpha = alpha
        self.L = len(layers_dims) - 1
        self.n = {}
        self.W = {}
        self.b = {}
        for l in range(1, len(layer_dims)):
            self.n['n' + str(l)] = layer_dims[l]
            self.W['W' + str(l)] = np.random.randn(layer_dims[l], layer_dims[l-1]) / np.sqrt(layer_dims[l-1])
            self.b['b' + str(l)] = np.zeros((layer_dims[l], 1))
    
    @staticmethod
    def sigmoid(z):
        return 1 / (1 + np.exp(-z))
    
    def sigmoid_derivative(self, z):
        return self.sigmoid(z) * (1 - self.sigmoid(z))
    
    @staticmethod
    def relu(z):
        return np.maximum(0, z)
    
    @staticmethod
    # NOT SURE WITH THIS SO JUST CHANGE LATER
    def relu_derivative(z):
        dadz = np.zeros(z.shape)
        dadz[z > 0] = 1
        return dadz
        
    def get_cost(self, Y):
        logprobs = Y * np.log(self.A['A' + str(self.L)]) + (1 - Y) * np.log(1 - self.A['A' + str(self.L)])
        return -1/self.m * np.sum(logprobs)
    
    def forward_propagation(self, X):
        self.Z = {}
        self.A = {}
        self.A['A0'] = X
        for l in range(1, self.L):
            self.Z['Z' + str(l)] = np.dot(self.W['W' + str(l)], self.A['A' + str(l-1)]) + self.b['b' + str(l)]
            self.A['A' + str(l)] = self.relu(self.Z['Z' + str(l)])
        self.Z['Z' + str(self.L)] = np.dot(self.W['W' + str(self.L)], self.A['A' + str(self.L-1)]) + self.b['b' + str(self.L)]
        self.A['A' + str(self.L)] = self.sigmoid(self.Z['Z' + str(self.L)])
    
    def backward_propagation(self, Y):
        self.dZ = {}
        self.dA = {}
        self.dW = {}
        self.db = {}
        self.dA['dA' + str(self.L)] = -Y/self.A['A' + str(self.L)] + (1-Y)/(1-self.A['A' + str(self.L)])
        self.dZ['dZ' + str(self.L)] = self.dA['dA' + str(self.L)] * self.sigmoid_derivative(self.Z['Z' + str(self.L)])
        self.dW['dW' + str(self.L)] = 1/self.m * np.dot(self.dZ['dZ' + str(self.L)], self.A['A' + str(self.L-1)].T)
        self.db['db' + str(self.L)] = 1/self.m * np.sum(self.dZ['dZ' + str(self.L)], axis=1, keepdims=True)
        for l in reversed(range(1, self.L)):
            self.dA['dA' + str(l)] = np.dot(self.W['W' + str(l + 1)].T, self.dZ['dZ' + str(l + 1)])
            self.dZ['dZ' + str(l)] = self.dA['dA' + str(l)] * self.relu_derivative(self.Z['Z' + str(l)])
            self.dW['dW' + str(l)] = 1/self.m * np.dot(self.dZ['dZ' + str(l)], self.A['A' + str(l-1)].T)
            self.db['db' + str(l)] = 1/self.m * np.sum(self.dZ['dZ' + str(l)], axis=1, keepdims=True)
    
    def update_parameters(self):
        for l in range(1, self.L+1):
            self.W['W' + str(l)] = self.W['W' + str(l)] - self.alpha * self.dW['dW' + str(l)]
            self.b['b' + str(l)] = self.b['b' + str(l)] - self.alpha * self.db['db' + str(l)]
            
    def train(self, X_train, y_train, max_iters):
        self.m = X_train.shape[-1]
        self.costs = []
        for _ in range(max_iters):
            self.forward_propagation(X_train)
            self.backward_propagation(y_train)
            self.update_parameters()
            self.costs.append(self.get_cost(y_train))
        print(f'Training done! Loss after {max_iters} iterations: {self.costs[-1]}')
        
    def predict(self, X_test):
        self.forward_propagation(X_test)
        y_pred = self.sigmoid(self.A['A' + str(self.L)])
        y_pred[y_pred >= 0.5] = 1
        y_pred[y_pred < 0.5] = 0
        return y_pred
    
    def plot_cost(self):
        plt.figure(dpi=100)
        plt.plot(self.costs)
        plt.show()

其中layer_dims[n_x, n_h1, n_h2, ..., n_y]。我确实试过测试forward_propagationbackward_propagation,它们确实工作得很好。我还尝试了不同类型的参数初始化,但没有成功。每当我预测新的结果时,它们都得到相同的概率。我的代码怎么了?我错过了什么

X_train的形状是(n_x, m),而y_train的形状是(n_y, m)


Tags: selflayerdbreturndefnptrainda
1条回答
网友
1楼 · 发布于 2024-05-18 19:23:56

请注意,您已经在正向传播和预测中应用了sigmoid激活。这将是一个问题,因为2个嵌套sigmoid的输出将始终不小于0.5

相关问题 更多 >

    热门问题