Files
Obsidian-Main/20.01. Programming/numpy/numpy axis 運算.md

3.2 KiB
Raw Blame History

numpy有很多運算都可以指定axis例如x.sum(axis=0)或是x.max(axis=0)之類。 axis在2軸像是 [[1, 2], [3, 4]] 可以理解成x方向或是y方向。但是在更多軸的情況下就很難這樣理解了。

我的理解方式是把axis當作「第幾層」。例如x.sum(axis=0)就是把「第0層」之下的東西都加起來例如說有一個array a長這樣:

array([[[1. , 2. , 3. ],
        [4. , 5. , 6. ],
        [7. , 8. , 9. ]],

       [[0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9]]])

axis=0

那麼np.sum(a, axis=0)就是把「第0層」之下的東西都加起來a的shape是(2, 3, 3)所以第0層之下就是有2個3x3的array,也就是

[[1. , 2. , 3. ],
 [4. , 5. , 6. ],
 [7. , 8. , 9. ]]

[[0.1, 0.2, 0.3],
 [0.4, 0.5, 0.6],
 [0.7, 0.8, 0.9]]

要加起來也就是:

np.array([[1. , 2. , 3. ],
          [4. , 5. , 6. ],
          [7. , 8. , 9. ]]) + 
np.array([[0.1, 0.2, 0.3],
          [0.4, 0.5, 0.6],
          [0.7, 0.8, 0.9]])

答案跟np.sum(a, axis=0)是一樣的。

axis=1

那麼np.sum(a, axis=1)也就是把「第1層」之下的東西都加起來a的shape是(2, 3, 3)所以「第1層」有2個分別是

[[1. , 2. , 3. ],
 [4. , 5. , 6. ],
 [7. , 8. , 9. ]]

[[0.1, 0.2, 0.3],
 [0.4, 0.5, 0.6],
 [0.7, 0.8, 0.9]]

這2個各自會產生各自的結果先看第一個。我們要把「第1層」之下的東西都加起來「第1層」之下的東西就是

[1. , 2. , 3. ],
[4. , 5. , 6. ],
[7. , 8. , 9. ]

我要把他們加起來,也就是[1. , 2. , 3. ] + [4. , 5. , 6. ] + [7. , 8. , 9. ] = [12., 15., 18.]

再看第二個我們要把「第1層」之下的東西都加起來「第1層」之下的東西就是

[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]

我要把他們加起來,也就是[0.1, 0.2, 0.3] + [0.4, 0.5, 0.6] + [0.7, 0.8, 0.9] = [1.2, 1.5, 1.8]

所以np.sum(a, axis=1)的答案就是:

[[12., 15., 18.],
 [1.2, 1.5, 1.8]]

axis=2

那麼np.sum(a, axis=2)也就是把「第2層」之下的東西都加起來a的shape是(2, 3, 3)所以「第1層」有2個分別是

[[1. , 2. , 3. ],
 [4. , 5. , 6. ],
 [7. , 8. , 9. ]]

[[0.1, 0.2, 0.3],
 [0.4, 0.5, 0.6],
 [0.7, 0.8, 0.9]]

而這2個第1層又各自有3個的第2層分別是

[##第0層
    [##第1層-0
        [1. , 2. , 3. ]  ##第2層-0 <-- 裡面要加起來
        [4. , 5. , 6. ]  ##第2層-1 <-- 裡面要加起來
        [7. , 8. , 9. ]  ##第2層-2 <-- 裡面要加起來
    ],
    [##第1層-1
        [0.1, 0.2, 0.3]  ##第2層-0 <-- 裡面要加起來
        [0.4, 0.5, 0.6]  ##第2層-1 <-- 裡面要加起來
        [0.7, 0.8, 0.9]  ##第2層-2 <-- 裡面要加起來
    ]
]

總共有6個加起來之後就變成

[
    [
        [1. , 2. , 3. ] # 1+2+3 = 6
        [4. , 5. , 6. ] # = 15
        [7. , 8. , 9. ] # = 24
    ],
    [
        [0.1, 0.2, 0.3] # 0.1+0.2+0.3 = 0.6
        [0.4, 0.5, 0.6] # = 1.5
        [0.7, 0.8, 0.9] # = 2.4
    ]
]

所以np.sum(a, axis=2)的答案就是:

[[ 6. , 15., 24. ],
 [ 0.6, 1.5, 2.4]]