주뇽's 저장소
9. GNN 학습(2) Hierarchical CS224W: Machine Learning with Graphs 정리 본문
https://web.stanford.edu/class/cs224w
목차
1. GNN 학습(2) Hierarchical
- Hierarchical Pooling
- Hierarchical Aggregate
- Graph-Level
Issue : Graph Level의 예측을 수행할 때 Gloabal pooling은 그래프 정보를 유실시킬 수 있다.
ex):
- Node embeddings for G1 : {-1, -2, 0, 1, 2}
- Node embeddings for G2 : {-10, -20, 0, 10, 20}
G1과 G2 그래프는 완전히 다른 노드 임베딩을 가지고 있고 구조적으로도 완전히 다르다.
만약 여기서 gloabl sum pooling을 사용하게 된다면 결과값은 다음과 같다.
- Prediction for G1 : y_hat(G) = Sum( {-1, -2, 0, 1, 2} ) = 0
- Prediction for G2 : y_hat(G) = Sum( {-10, -20, 0, 10, 20} ) = 0
위와 같이 우리는 G1 그래프와 G2그래프의 다름을 구별할 수 없다!! 이를 해결하기 위해서 Hierarchical Pooling을 사용할 수 있다.
👉 1. GNN 학습(2) Hierarchical
- Hierarchical Pooling
모든 것을 동시에 동시에 집계하지 않고 더 작은 그룰을 집계한 다음 몇 개의 노드를 가져와 집계하는 방법이다.
ex):
- 우리는 처음 2개의 노드를 먼저 집계한 후 마지막 3개 노드를 집계한다.
- ReLU(sum())을 이용하여 최종 결과를 예측한다.
- embeddings for G1 : {-1, -2, 0, 1, 2}
- Round1 :
- Y_hat(a) = ReLU(Sum({-1, -2})) = 0
- Y_hat(b) = ReLu(Sum({0, 1, 2})) = 3
- Round2:
- Y_hat(G) = ReLU(Sum({Y_hat(a), Y_hat(b)}) = 3
- Round1 :
- embeddings for G2 : {-10, -20, 0, 10, 20}
- Round1 :
- Y_hat(a) = ReLU(Sum({-10, -20})) = 0
- Y_hat(b) = ReLu(Sum({0, 10, 20})) = 30
- Round2:
- Y_hat(G) = ReLU(Sum({Y_hat(a), Y_hat(b)}) = 30
- Round1 :
위와 같은 방식으로 임베딩을 계산하면 두 그래프의 차이를 구별할 수 있다. 그렇다면 여기서 생기는 의문점은 어떤 값을 먼저 집계할지, 어떻게 계층적으로 집계할지 알려줄지 결정하는 방법이다.
- Hierarchical Aggregate
그래프는 소위 커뮤니트 구조를 갖는 경향이 있다. 소셜 네트워크라고 생각해본다면 소셜 네트워크 내부에는 긴밀하게 연결된 커뮤니티가 있다. 따라서 이러한 커뮤니티를 미리 감지할 수 있다면 커뮤니티 내부의 노드를 커뮤니티 임베딩으로 집계한 다음 커뮤니티 임베딩을 슈퍼 커뮤니티 임베딩 등으로 계층적으로 집계할 수있다.
각 커뮤니티에 대해 슈퍼 노드를 생성 -하여 최종 값 예측
방법
각각의 레벨에서 독립적인 독립적인 GNN을 이용
- GNN A : 노드 임베딩을 계산
- GNN B : 노드가 속해있는 계층을 계산
GNN A와 B는 각각의 레벨에서 동시에 실행될 수 있다.
그래프 분해 알고리즘 참고 : Hierarchical Graph Representation Learning with Differentiable Pooling
📝 정리
GNN을 활용한 그래프 레벨 예측 과정에서 글로벌 풀링을 사용하면 그래프 정보가 유실될 수 있다.
이를 해결하기 위해 계층적 풀링(Hierarchical Pooling)과 계층적 집계(Hierarchical Aggregate) 기법을 적용할 수 있다.
- 계층적 풀링 : 전체 노드를 한번에 집계하지 않고 작은 그룹으로 나누어 단계적으로 집계하는 방식
- 계층적 집계 : 그래프의 커뮤니티 구조를 파악하여 커뮤니티 내 노드를 먼저 집계한 후 상위 계층으로 올라가며 점진적으로 집계하는 기법
이를 통해 그래프의 구조적 특성을 보존하면서도 효과적인 예측이 가능해진다.