[程式碼筆記] 利用程式碼實現 KNN

 因為前陣子在審視架構的時候,亂報(X)沒讀熟(O),

所以架構被老師質疑,被遣送回實驗室重讀一遍 DGCNN 怎麼刻 KNN 的部分?


以下的 KNN 程式碼 都是 base on DGCNN 的 model.py 裡面的 KNN function。



前情提要:


1. KNN 要幹嘛?

:拿到兩點的歐式距離。


2. 歐式距離是什麼?

吃公式:





我們塞進 KNN function 的資料 x 會長得像是:

[ Batch_size, Channel, Number_of_point ]


其中 Channel 是資料維度的意思,如果是三維的 x, y ,z 座標,

Channel 就會等於 3。


接下來來看程式碼:


inner = -2*torch.matmul(x.transpose(2, 1), x)

因為 torch.matmul 代表矩陣相乘

ineer 出來的資料尺寸就會是

[ Batch_size, Number_of_point , Number_of_point ]


令 P1 = [ x1, y1, z1], P2 =[ x2, y2, z2]

ineer 所代表的數學意義為

[ x1*x2 + y1*y2 + z1*z2]


再來看第二行程式碼:


xx = torch.sum(x**2, dim=1, keepdim=True)


xx 出來的資料尺寸為

Batch_size1 , Number_of_point ]

因為 keepdim=True 的關係,所有 sum 起來的資料會存到第一格。


xx 所代表的數學意義為

[x1^2 + y1^2 + z1^2, ... , xn^2 + yn^2 + zn^2]


最後一行程式碼就是組合起來,變成歐式距離的平方:


pairwise_distance = -xx - inner - xx.transpose(2, 1)

會變成:


就是歐式距離平方的展開。


至於負號是為了後面要做 TopK 所以放上的。

以上 ODO





留言

這個網誌中的熱門文章

[心得] 破解 Google 雲端下載限制(非建立副本)