주뇽's 저장소

[Code] NGCF(Neural Graph Collaborative Filtering) 본문

GNN

[Code] NGCF(Neural Graph Collaborative Filtering)

뎁쭌 2023. 8. 19. 13:10
728x90
반응형

CF + GNN 모델 조사

그래프 구성의 차이

일반적인 GNN모델들은 단순하게 사용자-아이템 이분 그래프에 직접 GNN을 적용하면 다음과 같은 문제

  1. 학습하기에 표현이 충분하지 않음
  2. 대규모 그래프의 경유 높은 계산 비용 발생

GCCF(Graph Convolutional Collaborative Filtering)

 기존의 사용자와 아이템 간의 상호작용 행렬을 분해하여 임베딩 하는 방식의 MF 모델을 개선하기 위하여 나온 모델로 그래프 구조를 고려하여 사용자와 아이템 간의 상호작용을 모델링한다. GCCF에서는 그래프를 구성할 때, 사용자와 아이템을 노드로 표현하고, 이들 간의 상호작용을 엣지로 표현합니다. 이 그래프를 바탕으로 GCN을 적용하여 사용자와 아이템의 임베딩을 학습하고, 이를 기반으로 새로운 사용자와 아이템에 대한 예측을 수행한다.

사용자 아이템 상호작용
노드 노드 엣지

 GCCF는 여러 개의 레이어로 이루어진 GCN을 사용하여 그래프 구조를 반영한.다ㄷ이 때, GCN의 각 레이어에서는 아래 레이어에서 정보를 전파하여 현재 레이어의 임베딩을 업데이트한다. 이를 통해 모델은 상호작용이 일어난 사용자와 아이템의 정보뿐만 아니라, 이들 간의 관계 정보도 함께 고려하여 임베딩을 학습한다.

 

Multi-GCCF (복수의 도메인에서의 추천을 위한 모델)

 이전에는 도메인 간의 차이 때문에 각 도메인마다 개별적인 모델을 만들어야 했으나, Multi-GCCF는 여러 도메인 간의 유사성을 고려하여 한번에 학습이 가능하다.

 Multi-GCCF 모델은 GCCF 모델에서 발전된 모델로, 다수의 도메인에서의 유저-아이템 평가행렬을 이용하여 학습한다. 이때 모든 도메인이 같은 잠재공간(latent space)에서의 임베딩을 공유하게 된다.

  • 각 모델에서 학습한 잠재공간을 가중치를 부여하여 더하고, 최종적으로 평균을 내어 공유하는 방법

Multi-GCCF는 아래와 같은 특징을 가집니다.

  1. 다수의 도메인에서의 평가행렬을 고려하여 하나의 모델로 학습할 수 있다.
  2. 공통 잠재공간에서의 임베딩을 공유하게 되어, 서로 다른 도메인에서의 유저-아이템 간의 관계를 고려하여 추천을 수행할 수 있다.
  3. 다양한 종류의 데이터(텍스트, 이미지 등)를 적절하게 통합하여 추천에 반영할 수 있다.

하지만 Multi-GCCF 모델의 학습이 상대적으로 더 어려우며, 도메인의 수가 많을수록 성능 저하가 발생할 가능성이 있다 따라서 모델 구성 및 하이퍼파라미터 설정에 대한 연구가 필요하다.

 

DGCF(Deep Graph Collaborative Filtering)

 사용자와 아이템을 독립적으로 임베딩한 후 이를 조합하여 예측하는 방식과 달리, 사용자와 아이템이 모두 그래프 상의 노드로 표현되며 이를 바탕으로 예측을 수행한다.

 DGCF 모델은 사용자-아이템 그래프와 아이템-아이템 그래프를 모두 활용한다. 사용자-아이템 그래프는 사용자가 아이템을 구매한 기록을 가지고 노드를 형성하며, 아이템-아이템 그래프는 같은 사용자가 여러 아이템을 구매한 경우 해당 아이템들 간의 연결을 형성하여 그래프를 생성합니다.

 DGCF 모델은 두 개의 임베딩 행렬과 연결 예측 함수로 이루어져 있다. 첫 번째 임베딩 행렬은 사용자를, 두 번째 임베딩 행렬은 아이템을 나타냅니다. 임베딩 행렬은 초기에 무작위로 초기화되며, 학습 과정에서 각 노드의 임베딩 벡터가 업데이트된다. 연결 예측 함수는 두 개의 임베딩 벡터를 입력으로 받아 두 벡터의 내적을 통해 예측 점수를 계산한다.

 DGCF 모델에서는 GNN을 사용하여 사용자-아이템 그래프와 아이템-아이템 그래프를 학습한다. GNN은 각 노드의 임베딩 벡터를 입력으로 받아 해당 노드와 이웃한 노드들의 정보를 고려한 새로운 임베딩 벡터를 출력한다. 이 때, DGCF 모델에서는 사용자와 아이템에 대해 서로 다른 GNN을 사용한다.