주뇽's 저장소

5. Graph Neural Networks(1) CS224W: Machine Learning with Graphs 정리 본문

GNN/CS224

5. Graph Neural Networks(1) CS224W: Machine Learning with Graphs 정리

뎁쭌 2024. 3. 26. 13:15
728x90
반응형

 

https://web.stanford.edu/class/cs224w

목차

1. 그래프 데이터의 특성과 도전 과제
- 그래프 신경망(GNN)의 기본 아이디어
- GNN의 계산 그래프와 집계 함수
- GNN 모델 파라미터 학습
2. 그래프 합성곱 신경망(GCN)
3. GNN과 기존 신경망 아키텍처와의 비교

👉 1. 그래프 데이터의 특성과 도전 과제

그래프 데이터는 이미지나 자연어 데이터와는 다음과 같은 차이점이 있어 전통적인 딥러닝 모델을 바로 적용하기 어렵다:

1. 노드의 수와 연결 구조가 불규칙함 (non-Euclidean)
2. 노드의 순서가 없음 (permutation invariant)
3. 노드마다 이웃의 수와 구조가 다름 (variable neighborhood)

https://web.stanford.edu/class/cs224w/slides/03-GNN1.pdf

 

 

1. 그래프 신경망(GNN)의 기본 아이디어

 

GNN은 노드의 임베딩을 학습할 때 노드의 이웃 정보를 반복적으로 전파하고 집계한다는 아이디어에서 출발한다. 구체적으로:

- 각 노드의 초기 임베딩은 입력 피처로 설정
- 층을 쌓아 올리면서 단계적으로 이웃 정보 집계
- 집계된 정보와 자신의 정보를 결합해 새로운 노드 임베딩 계산
- 여러 층을 거치며 넓은 범위의 이웃 정보가 취합됨

https://web.stanford.edu/class/cs224w/slides/03-GNN1.pdf

 

https://web.stanford.edu/class/cs224w/slides/03-GNN1.pdf

 

 

 

2. GNN의 계산 그래프와 집계 함수

 

GNN은 입력 그래프의 노드들에 각각 계산 그래프(computation graph)를 정의하고, 이를 통해 노드 임베딩을 계산한다.
GNN의 층 별 연산은 다음 두 단계로 이루어진다:

 



1. 이웃 집계(neighborhood aggregation) 이웃 집계 단계에서는 각 노드 v의 k번째 층 집계 벡터 a_v^(k)를 계산한다. 이 벡터는 노드 v의 이웃 노드들의 이전 층(k-1) 임베딩 벡터 h_u^(k-1)를 모아서 집계 함수 AGGREGATE^(k)를 적용한 결과이다.


- 집계 함수는 각 이웃 노드의 임베딩 벡터를 입력으로 받아, 그 정보를 요약하는 역할을 한다. 대표적인 집계 함수로는 평균(mean), 합(sum), 최대 풀링(max-pooling) 등이 있다.

이웃 집계 단계를 통해 노드 v는 자신의 직접 연결된 이웃들의 정보를 종합할 수 있게 된다.

2. 집계 정보 변환(feature transformation) 집계 정보 변환 단계에서는 이웃 집계 벡터 a_v^(k)와 노드 자신의 이전 층 임베딩 벡터 h_v^(k-1)를 연결(concatenation)한 후, 학습 가능한 가중치 행렬 W^(k)를 곱하고 비선형 활성화 함수 σ를 적용하여 노드 v의 k번째 층 임베딩 벡터 h_v^(k)를 계산한다.

 

3. GNN 모델 파라미터 학습

 

GNN의 학습 과정은 기존 딥러닝 모델과 유사하다. 주어진 손실 함수를 최소화하는 방향으로 층별 가중치 행렬(W,B)를 경사 하강법으로 업데이트한다.

 

Wk : 이웃 집계를 위한 가중치 행렬

Bk : 자기 벡터를 변환하기 위한 가중치 행렬

 

이 가중치 행렬은 여러 노드에서 공유된다! 즉, 네트워크의 모든 노드가 동일한 변환 행렬을 사용한다.

  • 파라미터 효율성: 모든 노드가 같은 가중치 행렬을 사용하므로 학습해야 할 파라미터의 수가 줄어다. 이는 모델의 공간 복잡도를 낮추고 학습을 더 효율적으로 만든다.
  • 귀납적 학습: 가중치 행렬이 노드 independent하므로, 한번 학습된 모델은 이전에 보지 못한 새로운 노드나 심지어 새로운 그래프에도 적용될 수 있다. 예를 들어, 우리 반에 전학생 민준이가 왔고 민준이의 친구 관계가 주어지면, 이미 학습된 가중치 행렬 W를 사용하여 민준이의 특성을 바로 계산할 수 있다.
  • 일반화 능력: 노드 independent한 가중치 행렬은 그래프 구조에 내재된 일반적인 패턴을 학습할 수 있다. 이는 모델의 일반화 능력을 향상시키고, 과적합의 위험을 줄여준다.

👉 그래프 합성곱 신경망 (Graph Convolutional Networks)

GCN은 GNN의 대표적인 구현체 중 하나로, 그래프 라플라시안 행렬을 변형한 필터로 근방 노드를 집계하는 방식이다. GCN의 층 별 연산은 다음과 같이 행렬 형태로 간단히 표현 된다: 

GCN은 그래프 구조를 인접 행렬로 명시적으로 나타내고 효율적인 행렬 연산을 통해 빠른 학습이 가능하다는 장점이 있다.

 

EX)
우리 학교에는 민수, 영희, 철수, 지민이라는 4명의 학생이 있고. 이들의 친구 관계는 다음과 같다.

- 민수의 친구: 영희, 철수
- 영희의 친구: 민수, 지민 
- 철수의 친구: 민수
- 지민이의 친구: 영희


1단계: 인접 행렬 A 만들기

  민수 영희 철수 지민
민수 0 1 1 0
영희 1 0 0 1
철수 1 0 0 0
지민 0 1 0 0



2단계: 차수 행렬 D와 그 역행렬 D^-1 만들기

- 차수행렬 D(대각 원소에 자신과 연결된 친구 수)

  민수 영희 철수 지민
민수 2 0 0 0
영희 0 2 0 0
철수 0 0 1 0
지민 0 0 0 1



- 역행렬 (D의 대각선 요소를 역수로 바꾼 행렬)

  민수 영희 철수 지민
민수 1/2 0 0 0
영희 0 1/2 0 0
철수 0 0 1 0
지민 0 0 0 1



3단계: H^(k) 행렬 
H^(k)는 k번째 층에서 각 학생의 임베딩 벡터를 나타낸다. 예를 들어 H^(k)가 다음과 같다고 해 보자(임베딩 차원 : 2)

민수 영희 철수 지민
1.0 0.5 1.2 0.3
0.7 1.1 0.9 1.4



4단계: AH^(k) 계산하기
(AH^(k))[i]는 i번 학생의 친구들의 임베딩 벡터를 모두 더한 것

민수 영희 철수  지민
1.7 1.3 1.0 0.5
2.0 2.3 0.7 1.1

- 민수와 연결된 친구 : 영희, 철수

- 영희의 임베딩 벡터 [0.5, 1.1]

- 철수의 임베딩 벡터 [1.2, 0.9]

업데이트 된 민수의 임베딩 벡터  =  [1.7, 2.0] ( 친구들)

 

예를 들어, 민수의 경우 (AH^(k))[민수] = H^(k)[영희] + H^(k)[철수] = [0.5, 1.1] + [1.2, 0.9] = [1.7, 2.0] 이 된다.

(영희를 갱신할 때 방금 갱신된 민수의 임베딩 벡터를 사용하는것이 아닌 이전 레이어에서의 임베딩을 사용해야 한다!)

    
5단계: D^-1AH^(k) 계산하기
마지막으로 D^-1과 AH^(k)를 곱하면 D^-1AH^(k)를 얻을 수 있. 이는 각 학생의 임베딩 벡터를 그 학생의 친구 수로 나눈 것과 같다.

 

위에서 구한 민수의 새로운 특성을 민수의 친구 수(2)로 나눈다. 민수의 최종 특성 = [2.7, 2.7] / 2 = [1.35, 1.35]

민수 영희 철수 지민
0.86 0.65 1.0 0.5
1 1.15 0.7 1.1


이 새로운 임베딩 벡터는 그 학생의 원래 임베딩 벡터와 친구들의 임베딩 벡터의 평균을 합친 것과 같다. 이를 통해 친구 관계라는 그래프 구조의 정보가 학생들의 임베딩 벡터에 반영된다.

 

👉 GNN과 기존 신경망 아키텍처와의 비교

GNN은 기존의 CNN, RNN, Transformer 등 주요 신경망 아키텍처를 그래프 도메인으로 일반화한 것으로 볼 수 있다.

  • CNN: 고정된 크기의 격자(grid) 위에서 sliding window로 지역 정보를 집계 vs. GNN: 불규칙한 그래프 위에서 가변 크기의 이웃을 adaptive하게 집계

  • Transformer: self-attention을 통해 모든 토큰이 서로 attend vs. GNN: 노드의 계산 그래프가 Transformer의 fully-connected attention과 유사

요컨대, GNN은 기존 신경망의 핵심 아이디어를 차용하면서도 그래프 고유의 특성을 반영할 수 있도록 일반화된 강력한 프레임워크라 할 수 있다.

📝 정리

  • 그래프 데이터는 노드의 순서가 없고 연결 구조가 불규칙하여 기존 딥러닝 모델을 직접 적용하기 어려움
  • GNN은 노드의 이웃 정보를 반복적으로 집계하고 변환하여 노드 임베딩을 학습
  • GNN의 계산은 이웃 집계와 집계 정보 변환의 두 단계로 이루어짐
  • GNN의 학습은 손실 함수를 최소화하는 방향으로 층별 가중치 행렬을 업데이트
  • GNN의 가중치 행렬은 그래프의 모든 노드에서 공유되어 파라미터 효율성, 귀납적 학습, 일반화 능력 등의 장점이 있음
  • GCN은 GNN의 대표적 구현체로, 그래프 라플라시안 행렬을 변형한 필터로 근방 노드를 효율적으로 집계하며 행렬 연산으로 계산이 간단히 표현됨
  • GNN은 CNN, Transformer 등 기존 신경망 아키텍처를 그래프 도메인으로 일반화한 유연하고 강력한 프레임워크