带有`情节性'的动画情节`

2024-10-06 08:36:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我想用plotly库绘制MLE算法的收敛过程

要求:

  • 这些点必须按照簇的颜色进行着色,并在每次迭代中相应地更改
  • 应在每次迭代中绘制簇的质心

单个迭代的绘图可由Code 1生成,所需输出如Figure 1所示:

代码1

import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

fig = go.Figure()
for i in range(5):
    fig.add_trace(
        go.Scatter(
            x=A[i:i+3][:, 0],
            y=A[i:i+3][:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=colors[i]
        )
    )
    
for c in clusters:
    fig.add_trace(
        go.Scatter(
            x=[centroids[c-1][0]],
            y=[centroids[c-1][1]],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=colors[c-1],
            marker_symbol='x'
        )
    )
fig.show()

图1

Figure 1

我看过this教程,但似乎您只能在graph_objects.Frame()中绘制一条轨迹,Code 2代表了一个生成所有点的动画散点图的简单示例,其中每个帧绘制来自不同簇和质心的点:

代码2

import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

fig = go.Figure(
    data=[go.Scatter(x=A[:3][:,0], y=A[:3][:,1], mode='markers', name='cluster 1', marker_color=colors[0])],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    frames=[go.Frame(data=[go.Scatter(x=A[:3][:,0], y=A[:3][:,1], mode='markers', name='cluster 2', marker_color=colors[1])]),
            go.Frame(data=[go.Scatter(x=A[3:5][:,0], y=A[3:5][:,1], mode='markers', name='cluster 3', marker_color=colors[2])]),
            go.Frame(data=[go.Scatter(x=A[5:8][:,0], y=A[5:8][:,1], mode='markers', name='cluster 4', marker_color=colors[3])]),
            go.Frame(data=[go.Scatter(x=A[8:][:,0], y=A[8:][:,1], mode='markers', name='cluster 5', marker_color=colors[4])]),
            go.Frame(data=[go.Scatter(x=[centroids[0][0]], y=[centroids[0][1]], mode='markers', name='centroid of cluster 1', marker_color=colors[0], marker_symbol='x')]),
            go.Frame(data=[go.Scatter(x=[centroids[1][0]], y=[centroids[1][1]], mode='markers', name='centroid of cluster 2', marker_color=colors[1], marker_symbol='x')]),
            go.Frame(data=[go.Scatter(x=[centroids[2][0]], y=[centroids[2][1]], mode='markers', name='centroid of cluster 3', marker_color=colors[2], marker_symbol='x')]),
            go.Frame(data=[go.Scatter(x=[centroids[3][0]], y=[centroids[3][1]], mode='markers', name='centroid of cluster 4', marker_color=colors[3], marker_symbol='x')]),
            go.Frame(data=[go.Scatter(x=[centroids[4][0]], y=[centroids[4][1]], mode='markers', name='centroid of cluster 5', marker_color=colors[4], marker_symbol='x')])]
)
fig.show()

为什么代码2不适合我的需要:

  • 我需要将Code 2生成的所有帧绘制在算法每次迭代的单个帧中(即,所需解决方案的每个帧看起来像Figure 1

我尝试过的:

  • 我曾尝试生成一个graph_objects.Figure(),并将其添加到graph_objects.Frame()中,如Code 3所示,但得到了Error 1

代码3:

import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

fig = go.Figure()
for i in range(5):
    fig.add_trace(
        go.Scatter(
            x=A[i:i+3][:, 0],
            y=A[i:i+3][:, 1],
            mode='markers',
            name=f'cluster {i+1}',
            marker_color=colors[i]
        )
    )
    
for c in clusters:
    fig.add_trace(
        go.Scatter(
            x=[centroids[c-1][0]],
            y=[centroids[c-1][1]],
            name=f'centroid of cluster {c}',
            mode='markers',
            marker_color=colors[c-1],
            marker_symbol='x'
        )
    )

animated_fig = go.Figure(
    data=[go.Scatter(x=A[:3][:, 0], y=A[:3][:, 1], mode='markers', name=f'cluster 0', marker_color=colors[0])],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    frames=[go.Frame(data=[fig])]
)

animated_fig.show()

错误1:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-681-11264f38e6f7> in <module>
     43                           args=[None])])]
     44     ),
---> 45     frames=[go.Frame(data=[fig])]
     46 )
     47 

~\Anaconda3\lib\site-packages\plotly\graph_objs\_frame.py in __init__(self, arg, baseframe, data, group, layout, name, traces, **kwargs)
    241         _v = data if data is not None else _v
    242         if _v is not None:
--> 243             self["data"] = _v
    244         _v = arg.pop("group", None)
    245         _v = group if group is not None else _v

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in __setitem__(self, prop, value)
   3973                 # ### Handle compound array property ###
   3974                 elif isinstance(validator, (CompoundArrayValidator, BaseDataValidator)):
-> 3975                     self._set_array_prop(prop, value)
   3976 
   3977                 # ### Handle simple property ###

~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in _set_array_prop(self, prop, val)
   4428         # ------------
   4429         validator = self._get_validator(prop)
-> 4430         val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
   4431 
   4432         # Save deep copies of current and new states

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
   2671 
   2672             if invalid_els:
-> 2673                 self.raise_invalid_elements(invalid_els)
   2674 
   2675             v = to_scalar_or_list(res)

~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in raise_invalid_elements(self, invalid_els)
    298                     pname=self.parent_name,
    299                     invalid=invalid_els[:10],
--> 300                     valid_clr_desc=self.description(),
    301                 )
    302             )

ValueError: 
    Invalid element(s) received for the 'data' property of frame
        Invalid elements include: [Figure({
    'data': [{'marker': {'color': 'red'},
              'mode': 'markers',
              'name': 'cluster 1',
              'type': 'scatter',
              'x': array([-1.30634452, -1.73005459,  0.58746435]),
              'y': array([ 0.15388112,  0.47452796, -1.86354483])},
             {'marker': {'color': 'green'},
              'mode': 'markers',
              'name': 'cluster 2',
              'type': 'scatter',
              'x': array([-1.73005459,  0.58746435, -0.27492892]),
              'y': array([ 0.47452796, -1.86354483, -0.20329897])},
             {'marker': {'color': 'blue'},
              'mode': 'markers',
              'name': 'cluster 3',
              'type': 'scatter',
              'x': array([ 0.58746435, -0.27492892,  0.21002816]),
              'y': array([-1.86354483, -0.20329897,  1.99487636])},
             {'marker': {'color': 'yellow'},
              'mode': 'markers',
              'name': 'cluster 4',
              'type': 'scatter',
              'x': array([-0.27492892,  0.21002816, -0.0148647 ]),
              'y': array([-0.20329897,  1.99487636,  0.73484184])},
             {'marker': {'color': 'magenta'},
              'mode': 'markers',
              'name': 'cluster 5',
              'type': 'scatter',
              'x': array([ 0.21002816, -0.0148647 ,  1.13589386]),
              'y': array([1.99487636, 0.73484184, 2.08810809])},
             {'marker': {'color': 'red', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 1',
              'type': 'scatter',
              'x': [9],
              'y': [6]},
             {'marker': {'color': 'green', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 2',
              'type': 'scatter',
              'x': [0],
              'y': [5]},
             {'marker': {'color': 'blue', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 3',
              'type': 'scatter',
              'x': [8],
              'y': [6]},
             {'marker': {'color': 'yellow', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 4',
              'type': 'scatter',
              'x': [7],
              'y': [1]},
             {'marker': {'color': 'magenta', 'symbol': 'x'},
              'mode': 'markers',
              'name': 'centroid of cluster 5',
              'type': 'scatter',
              'x': [6],
              'y': [2]}],
    'layout': {'template': '...'}
})]

    The 'data' property is a tuple of trace instances
    that may be specified as:
      - A list or tuple of trace instances
        (e.g. [Scatter(...), Bar(...)])
      - A single trace instance
        (e.g. Scatter(...), Bar(...), etc.)
      - A list or tuple of dicts of string/value properties where:
        - The 'type' property specifies the trace type
            One of: ['area', 'bar', 'barpolar', 'box',
                     'candlestick', 'carpet', 'choropleth',
                     'choroplethmapbox', 'cone', 'contour',
                     'contourcarpet', 'densitymapbox', 'funnel',
                     'funnelarea', 'heatmap', 'heatmapgl',
                     'histogram', 'histogram2d',
                     'histogram2dcontour', 'image', 'indicator',
                     'isosurface', 'mesh3d', 'ohlc', 'parcats',
                     'parcoords', 'pie', 'pointcloud', 'sankey',
                     'scatter', 'scatter3d', 'scattercarpet',
                     'scattergeo', 'scattergl', 'scattermapbox',
                     'scatterpolar', 'scatterpolargl',
                     'scatterternary', 'splom', 'streamtube',
                     'sunburst', 'surface', 'table', 'treemap',
                     'violin', 'volume', 'waterfall']

        - All remaining properties are passed to the constructor of
          the specified trace type

        (e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])
  • 通过使用plotly.express模块,我成功地获得了每个帧中的所有点,如Code 3所示,但唯一缺少的是将质心标记为xs

代码3:

import plotly.express as px
import numpy as np
import pandas as pd

A = np.random.randn(200).reshape((100, 2))
iteration = np.array([1, 2, 3, 4, 5]).repeat(20)
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = np.random.randint(1, 6, size=100)
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

df = pd.DataFrame(dict(x1=A[:, 0], x2=A[:, 1], type='point', cluster=pd.Series(clusters, dtype='str'), iteration=iteration))
centroid_df = pd.DataFrame(dict(x1=centroids[:, 0], x2=centroids[:, 1], type='centroid', cluster=[1, 2, 3, 4, 5], iteration=[1, 2, 3, 4, 5]))
df = df.append(centroid_df, ignore_index=True)
px.scatter(df, x="x1", y="x2", animation_frame="iteration", color="cluster", hover_name="cluster", range_x=[-10,10], range_y=[-10,10])

我将非常感谢为达到预期结果所提供的任何帮助。谢谢


Tags: ofnamegodatamodetypenpmarker
1条回答
网友
1楼 · 发布于 2024-10-06 08:36:05

您可以为每帧添加两个跟踪,但显然您也需要在第一个data中定义这两个跟踪。我再次添加前两个轨迹作为一个帧,以便在后续播放中可以看到它们。这里是完整的代码

import plotly.graph_objects as go
import numpy as np

A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']

fig = go.Figure(
    data=[go.Scatter(x=A[:3][:,0],
                     y=A[:3][:,1],
                     mode='markers',
                     name='cluster 1',
                     marker_color=colors[0]),
          go.Scatter(x=[centroids[0][0]],
                     y=[centroids[0][1]],
                     mode='markers',
                     name='centroid of cluster 1',
                     marker_color=colors[0],
                     marker_symbol='x')
         ],
    layout=go.Layout(
        xaxis=dict(range=[-10, 10], autorange=False),
        yaxis=dict(range=[-10, 10], autorange=False),
        title="Start Title",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None]),
                     dict(label="Pause",
                          method="animate",
                          args=[None,
                               {"frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0}}],
                         )])]
    ),
    frames=[
    go.Frame(
    data=[go.Scatter(x=A[:3][:,0],
                     y=A[:3][:,1],
                     mode='markers',
                     name='cluster 1',
                     marker_color=colors[0]),
          go.Scatter(x=[centroids[0][0]],
                     y=[centroids[0][1]],
                     mode='markers',
                     name='centroid of cluster 1',
                     marker_color=colors[0],
                     marker_symbol='x')
         ]),
    go.Frame(
        data=[
            go.Scatter(x=A[:3][:,0],
                       y=A[:3][:,1],
                       mode='markers',
                       name='cluster 2',
                       marker_color=colors[1]),
            go.Scatter(x=[centroids[1][0]],
                       y=[centroids[1][1]],
                       mode='markers',
                       name='centroid of cluster 2',
                       marker_color=colors[1],
                       marker_symbol='x')
        ]),
    go.Frame(
        data=[
            go.Scatter(x=A[3:5][:,0],
                       y=A[3:5][:,1],
                       mode='markers',
                       name='cluster 3',
                       marker_color=colors[2]),
            go.Scatter(x=[centroids[2][0]],
                       y=[centroids[2][1]],
                       mode='markers',
                       name='centroid of cluster 3',
                       marker_color=colors[2],
                       marker_symbol='x')
        ]),
    go.Frame(
        data=[
            go.Scatter(x=A[5:8][:,0],
                       y=A[5:8][:,1],
                       mode='markers',
                       name='cluster 4',
                       marker_color=colors[3]),
        go.Scatter(x=[centroids[3][0]],
                   y=[centroids[3][1]],
                   mode='markers',
                   name='centroid of cluster 4',
                   marker_color=colors[3],
                   marker_symbol='x')]),
    go.Frame(
        data=[
            go.Scatter(x=A[8:][:,0],
                       y=A[8:][:,1],
                       mode='markers',
                       name='cluster 5',
                       marker_color=colors[4]),
            go.Scatter(x=[centroids[4][0]],
                       y=[centroids[4][1]],
                       mode='markers',
                       name='centroid of cluster 5',
                       marker_color=colors[4],
                       marker_symbol='x')
        ]),
    ])
            
fig.show()

enter image description here

相关问题 更多 >