티스토리 뷰

RecSys

[RecSys] SASRec 코드 리뷰

오탱 2023. 7. 12. 16:55

https://github.com/pmixer/SASRec.pytorch

 

GitHub - pmixer/SASRec.pytorch: PyTorch(1.6+) implementation of https://github.com/kang205/SASRec

PyTorch(1.6+) implementation of https://github.com/kang205/SASRec - GitHub - pmixer/SASRec.pytorch: PyTorch(1.6+) implementation of https://github.com/kang205/SASRec

github.com

SASRec모델의 Pytorch 버전 코드를 리뷰한 글입니다!

 

1. Data_partition : utils.py의 data_partition 함수

 

cf) 

→ 우리가 사용하는 Steam.txt 데이터 형태

→ 1열 사용자 id, 2열 사용자가 구매한 게임 id / 시간순서별로 기록되어 있는 것!

 

 

 

 

 

 

각 유저를 key로, 아이템 리스트를 value로 하는 딕셔너리 만들고, train/valid/ test 데이터셋으로 나눠주기

2. Input Data 만들기 : utils.py의 random_neq, sample_function 함수

random_neq를 이용해 각 batch 사이즈별로 들어갈 데이터 랜덤하게 골라줌!

과정1 에서 각 유저당 아이템 sequence 데이터를 만들었음. 각 유저별로 구매한 아이템 개수가 다를 수 있기 때문에 각 유저별 시퀀스를 고정된 크기 max_len으로 통일하고 하나의 텐서로 구성해주기!

seq( embedding layer에 들어갈 input), neg, pos(Prediction Layer에서 loss 계산할때 사용) 3개의 데이터 리스트 생성

 

3. Embedding : model.py의 log2feats 함수 일부

3-1. Item Embedding

seq 텐서(batch_size x max_len) 가 input으로 들어가면, Item_embedding을 거쳐 (batch_size * max_len *d) 사이즈로 변환됨

3-2. Position Embedding

(batch_size * max_len *d) 사이즈 크기의 텐서

→ 그 후 item_embedding이랑 position embedding 더하고, dropout, zero-padding 다시 적용~

→ 여기서 완성된 텐서는 Self-attention의 Key,Value로 활용! Query는 여기다가 Layer normalization 거쳐서 활용됨 → 어텐션 함수에 넣을 준비 완.

 

4. SASRec Block 구조: model.py의 SASRec 클래스 초반 부분

 

input 데이터가 임베딩 레이어 거친 후, (LayerNorm - Attention, Layernorm - PointwiseFeedForward) 로 구성된 블록을 반복적으로 지나감!

 

5. Attention Layer : model.py의 log2feats 뒷부분

Query, Key, Value가 Attention 함수에 들어감. Attention mask 씌워져서 미래 시점 값을 참조하는 것을 막음!

→ 여기서 최종 출력값이 이제 Point-wise-FeedForward로 들어간다

 

6. Point-wise FeedForward : model.py의 PointwiseFeedForward 클래스

- 선형결합 연산이 2번 반복 -> 1d Convolution layer 이용해 구현, 이때 input으로 들어갈 tensor는 전치시켜서 들어가야함! 

- 1d Conv 연산은 시퀀스의 time step 별로 kernel size 만큼 옆으로 이동하면서 가중치를 곱해줌 전치해서 time step을 맞춰줘야함!

 

7. Prediction Layer : model.py의 forward 함수 / main.py , model.py 의 predict 함수

model.py forward() 함수

main.py

→ 앞선 과정을 통해 만들어진 최종적인 tensor가 pos 시퀀스와 neg 시퀀스와 각각 elementary wise 곱이 이루어짐

→ 다음 영화의 관련도 점수가 만들어지고, pos_logit은 최대화, neg_logit은 최소화되는 방향으로 loss 함수 학습이 이루어짐

model.py 의 predict 함수

→ 다음 시점 상품 예측하는 과정,,, log feats 텐서에서 가장 마지막 행만을 가지고 오고,, 유저가 구매하지 않은 상품 m개로 이루어진 벡터와 내적 연산

→ 1* m 차원의 output 반환

→ 여기서 가장 높은 값 가진 인덱스에 해당하는 상품 추천!

 

8. Evaluate : utils.py의 evaluate 함수

→ 저 predictions 값이 최종 관련도 점수 , 여기에 argsort() 이용해서 rank를 매겨줌

NDCG 이용해서 성능 평가! 

 

참고)

https://velog.io/@tobigs-recsys/Code-Review-2018-IEEE-Self-Attentive-Sequential-Recommendation-SASRechttps://velog.io/@rlawhddn1010/SASRec-code

 

[Code Review] (2018, IEEE) Self-Attentive Sequential Recommendation (SASRec)

작성자: 고유경앞선 논문 리뷰 게시물에 이어서 이번에는 SASRec을 Pytorch로 구현한 코드를 리뷰하겠습니다. 구현된 코드는 이 깃헙에서 만나보실 수 있습니다. 참고로 Tensorflow로 구현된 코드는 논

velog.io