最大化插值矩阵

2024-09-28 23:23:17 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我需要max_(a', m') f(a, m, e, m', a'),我用一个网格V1来近似f。这是一个形状为(nA, nM, nE, nM, nA)的numpy矩阵(附在最后)。你知道吗

我想先插值,然后做最大化。下面是我当前的代码(我粘贴代码以在最后重新创建Grid):

# takes grid indices (first three dimensions) idx and interpolates on V
def interpolateV(idx, V, Grid):
    from scipy.interpolate import interp2d
    f = interp2d(Grid.mGrid, Grid.aGrid, V[idx])
    return f

# (somewhere else:)
s2 = (Grid.nM, Grid.nA, Grid.nE)
v1Max = np.empty(s2)
v1ArgMaxA = np.empty(s2)
v1ArgMaxM = np.empty(s2)
from scipy import optimize
for idx in np.ndindex(V1[..., 0,0].shape):
    V1i = interpolateV(idx, V1, Grid)
    x, f, d = optimize.fmin_l_bfgs_b(lambda x: -V1i(x[0], x[1]), np.array([1, 1]), bounds=[(Grid.aMin, Grid.aMax), (Grid.mMin, Grid.mMax)], approx_grad=True)
    v1Max[idx] = f
    v1ArgMaxA[idx], v1ArgMaxM[idx] = x
# let's compare with standard grid-wise optimization (without interpolation):
temp = V1.max(axis=-1)
# maximize over m
v1Max = temp.max(axis=-1)
# now max over a, given optimal m
v1ArgMaxAGrid = temp.argmax(axis=-1)    

到目前为止,还不错。但是,插值最大化的值相差甚远:

In[51]: v1ArgMaxAGrid[:,:,0]
Out[51]: 
array([[0, 0, 0, 0, 2],
       [0, 0, 0, 0, 2],
       [0, 0, 0, 2, 2],
       [0, 0, 0, 2, 3],
       [0, 0, 0, 2, 3]], dtype=int64)
In[54]: Grid.aGrid[v1ArgMaxAGrid[:,:,0]]
Out[54]: 
array([[ 0.  ,  0.  ,  0.  ,  0.  ,  3.5 ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  3.5 ],
       [ 0.  ,  0.  ,  0.  ,  3.5 ,  3.5 ],
       [ 0.  ,  0.  ,  0.  ,  3.5 ,  5.25],
       [ 0.  ,  0.  ,  0.  ,  3.5 ,  5.25]])
In[52]: v1ArgMaxA[:,:,0]
Out[52]: 
array([[ 0.   ,  0.75 ,  2.25 ,  7.   ,  7.   ],
       [ 0.   ,  1.5  ,  4.247,  7.   ,  7.   ],
       [ 0.75 ,  2.25 ,  7.   ,  7.   ,  7.   ],
       [ 1.5  ,  1.5  ,  7.   ,  7.   ,  7.   ],
       [ 2.25 ,  4.939,  7.   ,  7.   ,  7.   ]])

这里发生了什么;为什么值如此偏离?我做错了吗?你知道吗

在复制粘贴后重新创建GridV1

class Grids(object):
    nE = 2
    nA = 5
    nM = 5

    M = 3
    A = 7

    mMin = 0
    mMax = M
    aMin = 0
    aMax = A

    def __init__(self):
         self.reset();

    def reset(self):
        self.mGrid = np.linspace(self.mMin, self.mMax, self.nM)
        self.aGrid = np.linspace(self.aMin, self.aMax, self.nA)

        self.eGrid = np.array([0.318, 3.149])
        self.transitionE = np.array([[1., 0.],
                              [0., 1.]])
import numpy as np
Grid = Grids() 
V1 = np.array([[[[[  1.19 ,   0.975,  -0.371,  -2.848,  -6.456],
      [ -1.463,  -4.313,  -8.294, -13.407, -19.65 ],
      [ -9.888, -15.377, -21.997, -29.748, -38.63 ],
      [-24.574, -32.701, -41.958, -52.347, -63.866],
      [-45.562, -56.325, -68.218, -81.242, -95.397]],

     [[ 64.724,  64.672,  64.567,  64.358,  54.127],
      [ 64.247,  63.964,  53.759,  52.687,  50.487],
      [ 53.526,  52.078,  49.501,  45.799,  40.969],
      [ 48.389,  44.307,  39.105,  32.769,  25.314],
      [ 37.062,  30.347,  22.52 ,  13.553,   3.47 ]]],


    [[[ 12.624,  12.704,  12.591,   2.602,   1.618],
      [  2.237,   2.011,   0.655,  -1.832,  -5.45 ],
      [ -0.064,  -2.928,  -6.923, -12.049, -18.306],
      [ -8.624, -14.126, -20.759, -28.522, -37.416],
      [-23.488, -31.625, -40.894, -51.293, -62.822]],

     [[ 65.686,  65.695,  65.679,  65.631,  65.537],
      [ 65.401,  65.342,  65.23 ,  65.014,  54.778],
      [ 65.174,  64.881,  54.667,  53.59 ,  51.385],
      [ 54.43 ,  52.973,  50.396,  46.685,  41.855],
      [ 49.228,  45.138,  39.936,  33.594,  26.136]]],


    [[[ 13.681,  13.872,  14.024,  14.117,  14.093],
      [ 13.671,  13.74 ,  13.617,   3.617,   2.624],
      [  3.636,   3.397,   2.027,  -0.474,  -4.106],
      [  1.2  ,  -1.677,  -5.684, -10.823, -17.092],
      [ -7.538, -13.051, -19.694, -27.468, -36.373]],

     [[ 66.553,  66.597,  66.623,  66.631,  66.614],
      [ 66.362,  66.364,  66.342,  66.287,  66.188],
      [ 66.327,  66.259,  66.138,  65.917,  55.676],
      [ 66.077,  65.776,  55.562,  54.476,  52.271],
      [ 55.269,  53.804,  51.227,  47.51 ,  42.677]]],


    [[[ 14.6  ,  14.839,  15.054,  15.242,  15.394],
      [ 14.728,  14.909,  15.05 ,  15.133,  15.098],
      [ 15.07 ,  15.126,  14.989,   4.975,   3.968],
      [  4.9  ,   4.648,   3.265,   0.752,  -2.892],
      [  2.286,  -0.601,  -4.62 ,  -9.769, -16.048]],

     [[ 67.36 ,  67.427,  67.481,  67.52 ,  67.543],
      [ 67.229,  67.266,  67.286,  67.287,  67.265],
      [ 67.288,  67.281,  67.25 ,  67.19 ,  67.085],
      [ 67.231,  67.154,  67.033,  66.803,  56.562],
      [ 66.917,  66.607,  56.393,  55.301,  53.093]]],


    [[[ 15.442,  15.71 ,  15.96 ,  16.191,  16.4  ],
      [ 15.647,  15.875,  16.08 ,  16.257,  16.399],
      [ 16.128,  16.294,  16.422,  16.491,  16.443],
      [ 16.334,  16.377,  16.227,   6.201,   5.182],
      [  5.986,   5.723,   4.33 ,   1.806,  -1.849]],

     [[ 68.123,  68.207,  68.28 ,  68.342,  68.391],
      [ 68.036,  68.096,  68.143,  68.176,  68.194],
      [ 68.155,  68.183,  68.195,  68.19 ,  68.163],
      [ 68.192,  68.176,  68.145,  68.076,  67.971],
      [ 68.07 ,  67.984,  67.864,  67.628,  57.384]]]],



   [[[[ 11.877,   1.81 ,   1.59 ,   0.238,  -2.246],
      [  0.873,  -0.853,  -3.709,  -7.696, -12.814],
      [ -4.928,  -9.292, -14.787, -21.413, -29.17 ],
      [-16.988, -23.99 , -32.123, -41.386, -51.78 ],
      [-35.352, -44.989, -55.758, -67.657, -80.686]],

     [[ 65.151,  65.131,  65.075,  64.966,  64.753],
      [ 64.779,  64.647,  64.36 ,  54.151,  53.076],
      [ 54.24 ,  53.917,  52.465,  49.888,  46.183],
      [ 51.728,  48.771,  44.694,  39.483,  33.153],
      [ 43.026,  37.436,  30.734,  22.892,  13.934]]],


    [[[ 13.101,  13.245,  13.318,  13.2  ,   3.204],
      [ 12.924,   2.847,   2.616,   1.253,  -1.24 ],
      [  2.272,   0.533,  -2.337,  -6.338, -11.47 ],
      [ -3.664,  -8.041, -13.548, -20.187, -27.956],
      [-15.902, -22.915, -31.058, -40.332, -50.737]],

     [[ 66.066,  66.092,  66.098,  66.078,  66.025],
      [ 65.827,  65.8  ,  65.738,  65.622,  65.403],
      [ 65.706,  65.564,  65.268,  55.054,  53.974],
      [ 55.144,  54.812,  53.36 ,  50.774,  47.069],
      [ 52.567,  49.602,  45.525,  40.308,  33.975]]],


    [[[ 14.087,  14.302,  14.487,  14.632,  14.72 ],
      [ 14.148,  14.281,  14.344,  14.216,   4.21 ],
      [ 14.323,   4.232,   3.987,   2.611,   0.104],
      [  3.536,   1.784,  -1.099,  -5.112, -10.256],
      [ -2.578,  -6.965, -12.484, -19.133, -26.912]],

     [[ 66.905,  66.96 ,  66.999,  67.022,  67.025],
      [ 66.742,  66.762,  66.76 ,  66.734,  66.676],
      [ 66.754,  66.718,  66.646,  66.525,  66.301],
      [ 66.609,  66.459,  66.163,  55.94 ,  54.86 ],
      [ 55.983,  55.643,  54.191,  51.599,  47.891]]],


    [[[ 14.969,  15.221,  15.454,  15.663,  15.844],
      [ 15.134,  15.338,  15.513,  15.648,  15.725],
      [ 15.548,  15.667,  15.716,  15.574,   5.554],
      [ 15.587,   5.483,   5.226,   3.837,   1.318],
      [  4.622,   2.859,  -0.034,  -4.058,  -9.213]],

     [[ 67.691,  67.766,  67.829,  67.879,  67.915],
      [ 67.581,  67.629,  67.662,  67.678,  67.676],
      [ 67.669,  67.679,  67.669,  67.637,  67.574],
      [ 67.658,  67.612,  67.541,  67.411,  67.187],
      [ 67.449,  67.29 ,  66.994,  56.765,  55.682]]],


    [[[ 15.786,  16.063,  16.324,  16.569,  16.794],
      [ 16.016,  16.258,  16.48 ,  16.679,  16.85 ],
      [ 16.533,  16.724,  16.884,  17.006,  17.07 ],
      [ 16.811,  16.918,  16.954,  16.8  ,   6.768],
      [ 16.673,   6.559,   6.29 ,   4.891,   2.362]],

     [[ 68.439,  68.529,  68.61 ,  68.679,  68.737],
      [ 68.368,  68.436,  68.492,  68.535,  68.566],
      [ 68.507,  68.546,  68.571,  68.581,  68.574],
      [ 68.572,  68.573,  68.563,  68.523,  68.46 ],
      [ 68.497,  68.443,  68.372,  68.236,  68.009]]]],



   [[[[ 12.453,  12.498,   2.425,   2.198,   0.84 ],
      [  2.083,   1.483,  -0.248,  -3.111,  -7.104],
      [ -1.092,  -4.331,  -8.701, -14.202, -20.834],
      [-10.528, -16.405, -23.412, -31.551, -40.82 ],
      [-26.266, -34.779, -44.422, -55.196, -67.101]],

     [[ 65.555,  65.558,  65.534,  65.474,  65.36 ],
      [ 65.252,  65.179,  65.043,  64.752,  54.54 ],
      [ 64.974,  54.631,  54.304,  52.852,  50.272],
      [ 53.942,  52.11 ,  49.158,  45.072,  39.867],
      [ 47.865,  43.4  ,  37.823,  31.106,  23.273]]],


    [[[ 13.541,  13.722,  13.859,  13.927,  13.802],
      [ 13.5  ,  13.534,   3.451,   3.214,   1.846],
      [  3.482,   2.868,   1.123,  -1.753,  -5.76 ],
      [  0.172,  -3.08 ,  -7.463, -12.976, -19.62 ],
      [ -9.442, -15.329, -22.348, -30.497, -39.776]],

     [[ 66.433,  66.473,  66.495,  66.496,  66.472],
      [ 66.231,  66.227,  66.196,  66.13 ,  66.011],
      [ 66.178,  66.096,  65.952,  65.655,  55.438],
      [ 65.877,  55.526,  55.199,  53.738,  51.158],
      [ 54.781,  52.941,  49.989,  45.897,  40.689]]],


    [[[ 14.475,  14.708,  14.917,  15.095,  15.235],
      [ 14.588,  14.759,  14.886,  14.943,  14.808],
      [ 14.899,  14.92 ,   4.823,   4.572,   3.19 ],
      [  4.746,   4.119,   2.362,  -0.527,  -4.546],
      [  1.258,  -2.005,  -6.398, -11.922, -18.577]],

     [[ 67.247,  67.312,  67.362,  67.398,  67.417],
      [ 67.109,  67.142,  67.158,  67.152,  67.123],
      [ 67.158,  67.145,  67.105,  67.033,  66.909],
      [ 67.082,  66.991,  66.847,  66.541,  56.324],
      [ 66.717,  56.357,  56.03 ,  54.563,  51.98 ]]],


    [[[ 15.325,  15.59 ,  15.836,  16.062,  16.265],
      [ 15.522,  15.744,  15.943,  16.111,  16.24 ],
      [ 15.987,  16.144,  16.257,  16.301,  16.152],
      [ 16.163,  16.171,   6.061,   5.798,   4.404],
      [  5.832,   5.195,   3.426,   0.527,  -3.502]],

     [[ 68.016,  68.098,  68.169,  68.228,  68.274],
      [ 67.924,  67.981,  68.025,  68.054,  68.067],
      [ 68.036,  68.059,  68.066,  68.056,  68.021],
      [ 68.061,  68.039,  68.   ,  67.919,  67.794],
      [ 67.921,  67.822,  67.677,  67.366,  57.146]]],


    [[[ 16.121,  16.406,  16.678,  16.933,  17.171],
      [ 16.372,  16.626,  16.862,  17.078,  17.271],
      [ 16.921,  17.13 ,  17.314,  17.469,  17.585],
      [ 17.251,  17.395,  17.496,  17.527,  17.366],
      [ 17.249,  17.246,   7.126,   6.852,   5.447]],

     [[ 68.749,  68.846,  68.932,  69.008,  69.074],
      [ 68.692,  68.768,  68.832,  68.884,  68.925],
      [ 68.85 ,  68.898,  68.934,  68.957,  68.965],
      [ 68.939,  68.954,  68.961,  68.941,  68.906],
      [ 68.901,  68.87 ,  68.831,  68.744,  68.617]]]],



   [[[[ 12.947,  13.074,  13.112,   3.034,   2.8  ],
      [ 12.693,   2.693,   2.087,   0.35 ,  -2.518],
      [  1.618,  -0.496,  -3.741,  -8.117, -13.624],
      [ -5.192,  -9.944, -15.827, -22.84 , -30.984],
      [-18.306, -25.693, -34.212, -43.861, -54.64 ]],

     [[ 65.941,  65.962,  65.961,  65.932,  65.868],
      [ 65.688,  65.652,  65.575,  65.435,  65.141],
      [ 65.537,  65.364,  55.018,  54.691,  53.236],
      [ 55.031,  54.324,  52.497,  49.536,  45.456],
      [ 51.579,  48.239,  43.787,  38.195,  31.487]]],


    [[[ 13.954,  14.162,  14.337,  14.468,  14.529],
      [ 13.994,  14.11 ,  14.139,   4.049,   3.806],
      [ 14.093,   4.079,   3.459,   1.708,  -1.174],
      [  2.882,   0.755,  -2.502,  -6.891, -12.41 ],
      [ -4.106,  -8.869, -14.762, -21.786, -29.941]],

     [[ 66.789,  66.84 ,  66.876,  66.893,  66.891],
      [ 66.617,  66.631,  66.623,  66.588,  66.519],
      [ 66.614,  66.569,  66.484,  66.339,  66.039],
      [ 66.441,  66.259,  55.913,  55.577,  54.122],
      [ 55.87 ,  55.155,  53.328,  50.361,  46.278]]],


    [[[ 14.847,  15.096,  15.323,  15.525,  15.698],
      [ 15.001,  15.198,  15.363,  15.484,  15.535],
      [ 15.394,  15.496,  15.51 ,   5.407,   5.15 ],
      [ 15.356,   5.33 ,   4.697,   2.934,   0.04 ],
      [  3.968,   1.831,  -1.438,  -5.837, -11.366]],

     [[ 67.582,  67.654,  67.714,  67.761,  67.792],
      [ 67.465,  67.509,  67.538,  67.549,  67.542],
      [ 67.543,  67.548,  67.532,  67.492,  67.417],
      [ 67.518,  67.463,  67.379,  67.224,  66.925],
      [ 67.28 ,  67.09 ,  56.744,  56.402,  54.944]]],


    [[[ 15.672,  15.946,  16.204,  16.444,  16.665],
      [ 15.894,  16.132,  16.349,  16.541,  16.703],
      [ 16.4  ,  16.584,  16.734,  16.842,  16.879],
      [ 16.657,  16.746,  16.749,   6.633,   6.364],
      [ 16.443,   6.405,   5.762,   3.988,   1.083]],

     [[ 68.334,  68.423,  68.501,  68.568,  68.623],
      [ 68.258,  68.324,  68.377,  68.417,  68.443],
      [ 68.391,  68.426,  68.447,  68.453,  68.439],
      [ 68.447,  68.443,  68.427,  68.378,  68.302],
      [ 68.357,  68.294,  68.209,  68.049,  67.747]]],


    [[[ 16.448,  16.742,  17.021,  17.286,  17.535],
      [ 16.719,  16.983,  17.23 ,  17.46 ,  17.67 ],
      [ 17.294,  17.517,  17.72 ,  17.899,  18.048],
      [ 17.664,  17.835,  17.973,  18.068,  18.093],
      [ 17.744,  17.822,  17.813,   7.687,   7.408]],

     [[ 69.055,  69.156,  69.248,  69.331,  69.403],
      [ 69.01 ,  69.092,  69.163,  69.224,  69.273],
      [ 69.184,  69.241,  69.286,  69.32 ,  69.341],
      [ 69.295,  69.321,  69.341,  69.339,  69.325],
      [ 69.286,  69.274,  69.257,  69.203,  69.125]]]],



   [[[[ 13.398,  13.568,  13.688,  13.721,   3.636],
      [ 13.32 ,  13.303,   3.298,   2.685,   0.942],
      [  3.204,   2.215,   0.095,  -3.156,  -7.538],
      [ -0.982,  -4.609,  -9.366, -15.255, -22.274],
      [-11.47 , -17.733, -25.126, -33.65 , -43.305]],

     [[ 66.312,  66.348,  66.364,  66.359,  66.327],
      [ 66.099,  66.088,  66.048,  65.967,  65.825],
      [ 66.025,  65.928,  65.752,  55.405,  55.075],
      [ 65.656,  55.413,  54.711,  52.875,  49.92 ],
      [ 54.168,  51.953,  48.626,  44.159,  38.576]]],


    [[[ 14.347,  14.575,  14.776,  14.945,  15.07 ],
      [ 14.445,  14.605,  14.714,  14.737,   4.642],
      [ 14.72 ,  14.689,   4.669,   4.043,   2.286],
      [  4.468,   3.466,   1.333,  -1.93 ,  -6.324],
      [  0.104,  -3.533,  -8.302, -14.201, -21.23 ]],

     [[ 67.134,  67.195,  67.243,  67.274,  67.288],
      [ 66.988,  67.017,  67.027,  67.015,  66.978],
      [ 67.025,  67.005,  66.956,  66.871,  66.722],
      [ 66.929,  66.822,  66.646,  56.291,  55.961],
      [ 66.496,  56.244,  55.542,  53.7  ,  50.742]]],


    [[[ 15.208,  15.468,  15.71 ,  15.931,  16.128],
      [ 15.394,  15.611,  15.802,  15.961,  16.076],
      [ 15.844,  15.99 ,  16.086,  16.095,   5.986],
      [ 15.983,  15.94 ,   5.908,   5.269,   3.5  ],
      [  5.554,   4.541,   2.398,  -0.876,  -5.281]],

     [[ 67.909,  67.988,  68.057,  68.113,  68.155],
      [ 67.81 ,  67.865,  67.905,  67.93 ,  67.939],
      [ 67.915,  67.934,  67.936,  67.919,  67.875],
      [ 67.929,  67.9  ,  67.851,  67.756,  67.608],
      [ 67.768,  67.653,  67.477,  57.116,  56.783]]],


    [[[ 16.01 ,  16.293,  16.561,  16.813,  17.047],
      [ 16.255,  16.505,  16.736,  16.947,  17.133],
      [ 16.794,  16.997,  17.174,  17.319,  17.42 ],
      [ 17.108,  17.241,  17.324,  17.321,   7.2  ],
      [ 17.07 ,  17.015,   6.972,   6.323,   4.544]],

     [[ 68.647,  68.741,  68.825,  68.899,  68.962],
      [ 68.585,  68.658,  68.719,  68.769,  68.806],
      [ 68.737,  68.782,  68.814,  68.834,  68.836],
      [ 68.818,  68.829,  68.83 ,  68.804,  68.761],
      [ 68.768,  68.731,  68.682,  68.581,  68.431]]],


    [[[ 16.769,  17.069,  17.356,  17.63 ,  17.888],
      [ 17.057,  17.329,  17.587,  17.829,  18.053],
      [ 17.654,  17.89 ,  18.108,  18.305,  18.477],
      [ 18.057,  18.248,  18.412,  18.545,  18.634],
      [ 18.194,  18.316,  18.389,  18.375,   8.243]],

     [[ 69.355,  69.461,  69.559,  69.647,  69.725],
      [ 69.323,  69.41 ,  69.488,  69.555,  69.613],
      [ 69.511,  69.575,  69.628,  69.672,  69.704],
      [ 69.64 ,  69.677,  69.708,  69.719,  69.722],
      [ 69.658,  69.66 ,  69.661,  69.63 ,  69.584]]]]])

Tags: importselfdefnparraymaxgridv1