주뇽's 저장소

9. GNN 학습(1) Prediction-head CS224W: Machine Learning with Graphs 정리 본문

GNN/CS224

9. GNN 학습(1) Prediction-head CS224W: Machine Learning with Graphs 정리

뎁쭌 2024. 4. 12. 17:55
728x90
반응형

https://web.stanford.edu/class/cs224w

목차

1. GNN 학습(1) Prediction-head

- Node-Level

- Edge-Level

- Graph-Level

https://web.stanford.edu/class/cs224w/slides/05-GNN3.pdf

👉 1. GNN 학습(1) Prediction-head

https://web.stanford.edu/class/cs224w/slides/05-GNN3.pdf

 

 

GNN의 출력은 노드 임베딩의 집합 {𝐡v ∈ ℝd, ∀𝑣 ∈ 𝐺} 이다. 이 노드 임베딩들을 가지고 최종적인 prediction을 생성하기 위해서는 prediction head 모듈이 필요하며 Prediction head는 노드, 에지, 그래프 레벨의 예측 작업에 따라 다르게 설계된다.

 

- Node-Level


방법: 

  • 가장 간단한 방법은 노드 임베딩에 대해 선형변환을 수행하는 것이다.

 

  • 노드 분류 문제의 경우 선형변환 결과에 Softmax를 적용해 multi-class 확률값을 얻고, 
  • 노드 회귀 문제라면 그대로 yv를 사용하면 된다.

- Edge-Level

입력: 에지 (u,v)를 구성하는 두 노드의 임베딩 hu ∈ ℝd, hv ∈ ℝd

방법1: 

  • 두 노드 임베딩을 concatenate하여 2d 차원 벡터를 만들고, 여기에 MLP를 적용한다.
  • yuv = MLP(concat(hu, hv))

방법2: 

  • 두 노드 임베딩 hu, hv에 대해 각각 선형변환 후 내적(dot product)을 취한다
  • 이는 에지 존재 여부를 예측하는 링크 예측 문제에 자주 쓰인다. 
  • yuv = (hu)⊤ W(1)hv, 여기서 W(1) ∈ ℝd×d는 학습 가능한 가중치 행렬

위 두가지 에지 prediction head의 출력 yuv는 k-way 에지 분류 문제의 경우 Softmax를 취해주고, 에지 회귀 문제라면 그대로 사용하면 됩다.

 

 

- Graph-Level

입력: 그래프 G 위의 모든 노드 임베딩 {hv ∈ ℝd, ∀𝑣 ∈ 𝐺}


방법1

  • 가장 간단한 방법은 노드 임베딩에 대해 평균(mean)을 취하는 것이다.
  • hG = Readout({hv}) = Mean({hv}), yG = MLP(hG)

방법2:

  •  최댓값(max)이나 합(sum)을 취하는 것도 가능하다. 
  • hG = Readout({hv}) = Max({hv}) 혹은 Sum({hv})

마지막으로 그래프 representation hG에 MLP를 적용하여 그래프 예측값 yG를 출력한다.

 

📝 정리

GNN의 출력인 노드 임베딩 {𝐡v ∈ ℝd, ∀𝑣 ∈ 𝐺}을 사용하여 최종 예측을 수행하기 위해서는 prediction head가 필요하며, 이는 노드, 에지, 그래프 레벨에 따라 다르게 설계된다.

1. 노드 레벨 prediction head:
   - 노드 임베딩에 선형변환을 적용하여 예측값을 계산
   - 분류 문제: 선형변환 후 Softmax 적용
   - 회귀 문제: 선형변환 결과 그대로 사용

2. 에지 레벨 prediction head:
   - 방법1: 두 노드 임베딩을 concatenate 후 MLP 적용
   - 방법2: 두 노드 임베딩에 선형변환 후 내적(dot product) 계산
   - 분류 문제: 출력에 Softmax 적용
   - 회귀 문제: 출력 그대로 사용

3. 그래프 레벨 prediction head:
   - 그래프의 모든 노드 임베딩을 집계(Readout)하여 그래프 representation을 계산
   - 방법1: 노드 임베딩의 평균(mean) 사용
   - 방법2: 노드 임베딩의 최댓값(max) 또는 합(sum) 사용
   - 방법3: Attention을 활용한 그래프 pooling
   - 그래프 representation에 MLP를 적용하여 최종 예측값 계산

노드, 에지, 그래프 레벨 각각에 적합한 prediction head를 설계하는 것이 GNN 모델 개발에 있어 중요한 부분이다. 다양한 방법들 중에서 해결하고자 하는 문제의 특성을 고려하여 적절한 방법을 선택해야 한다.