Selective Federated Continual Learning in Multi-Modality Edge AI
본 포스트는 이화여자대학교 2025 캡스톤과디자인창업프로젝트 스타트 프로젝트를 위해 작성되었습니다.
과제 목표
본 연구는 'Pick-a-back: Selective Device-to-Device Knowledge Transfer in Federated Continual Learning' 를 Muti-modal, 다양한 Size의 데이터로 확장하는 후속연구입니다.
Continual Learning은 새로운 데이터를 학습하면서 기존에 학습한 지식을 잊지 않고 계속해서 학습을 이어가는 방식입니다.
이때 새로운 데이터를 지속적으로 학습하는 과정에서, 기존에 학습한 지식을 잃어버리는 Catastrophic forgetting 문제가 발생합니다
Federated Learning은 분산된 데이터를 각 장치가 보유한 상태에서 협력적으로 모델을 학습하는 방법입니다.
데이터가 중앙 서버에 모이지 않고 각 장치에 남아있어 데이터가 장치 간에 공유되지 않고도 모델을 학습할 수 있습니다.
두 개념을 결합한 federated continual learning은 각 장치가 학습한 모델을 인접 장치로부터 선택적으로 활용하는 방식입니다.
장치 간에 직접 데이터를 공유하지 않고도 연속학습을 수행하면서 새로운 지식을 축적할 수 있습니다.
Pick-a-back은 각 장치에서 모델을 학습시킨 후 그 모델 간의 유사도를 ModelDiff로 비교하였을 때, 유사도가 가장 높은 모델을 가져와 지식전이 하였을 때, 기존 모델보다 높은 성능을 보임과 동시에 Catastrophic Forgetting 문제를 해결할 수 있음을 보이는 연구입니다.
하지만 이러한 기존 연구에는 한계가 있었습니다:
- 실제 작업을 다룰 때는 다양한 모달리티와 multiple한 task 세팅이 존재합니다. 그러나 이전 연구에서는 이런 설정이 고려되지 않았기 때문에 이질적인 data와, multimodality, multiple task 환경에서 edge간 학습할 수 있는 효율적인 continual learning 알고리즘이 부재합니다.
- Decentralized 환경에서 raw data를 공유함으로써 높은 communication cost와 computation cost가 발생할 수 있습니다.
이에 저희는 더 범용적인 Cross-modal 상황에서 사용 가능한 Pick-a-back 알고리즘, MM-Pick-a-back을 만들었습니다.
Pick-a-back (선행연구) | MM-Pick-a-back (우리 연구) | |
모달리티 | Image | Multi-modal (Image, Text) |
Task | Image Classification | Image Classification, Text Classification |
연구 목적 | 학습된 데이터 패턴을 기준으로 선택적 지식 전이를 통해 개별 학습자의 성능을 개선 | 데이터의 모달리티 전이 방향에 따른 학습 성능 향상 패턴을 분석 후, 선택적 지식 전이를 통해 개별 학습자의 성능을 개선 |
소프트웨어 구조
MM-Pick-a-back의 구조는 다음과 같습니다.
자세한 기능별 용도와 사용법은 아래에서 설명하겠습니다.
구조별 기능 및 구현
0-1. Data Config
Cross-modal 실험을 위해, 동일한 분류 라벨을 공유하는 (Image, Text, Label)로 이루어진 데이터를 선정하였습니다:
- CUB-200 2011
- MSCOCO 2017
- Oxford 102 Flowers
해당 데이터들을 모달리티 별로 (Image / Text) 나눈 후,
각 모달리티의 데이터를 각각 5개의 subclass를 갖는 N개의 superclass로 나누었습니다.
이렇게 완성한 superclass-모달리티 의 분류 데이터셋에 각각 Task 번호를 부여하고, 각각 모델을 할당하여 Edge Model 상황을 상정하였습니다.
0-2. Model Config
본 연구에서는 다양한 모달리티의 데이터를 효과적으로 다룰 수 있는 하나의 모델로 모든 Edge Device를 학습시키는 것이 핵심이었습니다.
이에 저희는 이미지, 텍스트, 비디오, 오디오 등 다양한 I/O를 범용적으로 다룰 수 있는 모델 - PerceiverIO를 선택하였습니다.
1. Baseline
실험의 비교군입니다.
지속학습의 일종인 Packnet을 실행하여 Edge 모델을 학습시킵니다.
bash run_pickaback/baseline.sh
2. Without-Backbone
여기서부터 실험군입니다.
각 디바이스의 로컬 데이터로 모델을 학습시켜, Backbone을 구성합니다.
이 때, 지속학습의 일종인 CPG 를 사용합니다.
bash run_pickaback/wo_backbone.sh
3. Select Pruning Ratio
CPG로 완성한 각 디바이스의 backbone에서 점진적으로 Sparsity를 늘려가며 최적의 pruning rate을 골라,
지속학습을 위한 최적의 가지치기 ratio를 선택합니다.
이는 모델 지속학습 과정에서 Catastrophic Forgetting을 해결할 수 있는 방안입니다.
bash run_pickaback/select_pruning_ratio_of_backbone.sh
4. Find Backbone
이제, ModelDiff를 사용하여, 각 모델의 DDV 기반 유사도를 분석합니다.
이 때, 선행연구와 다른 문제가 발생합니다.
선행연구에서는 동일한 크기, 모달리티의 input을 사용하기 때문에, 서로 다른 디바이스의 데이터라도 어렵지 않게 각 모델에 입력할 수 있었습니다.
하지만, Image-Text는 서로 size, 입력형식이 모두 다른 데이터이므로 서로 다른 디바이스의 데이터를 모델에 입력해줄 수 없습니다.
이에 저희는 PerceiveIO의 Encoder layer에 존재하는 Key-value projection을 한 차례 거친 데이터를 서로 교환하여 모델에 입력하는 방식을 선택하였습니다.
이 방법은 서로 다른 크기나 모달리티의 데이터를 모델에 호환시켜줄 뿐 아니라, Raw Data가 아닌 한차례 인코딩 된 projection을 보내기 때문에 보안에 훨씬 강한 모습을 보일 것입니다.
bash run_pickaback/find_backbone_MM.sh
5. Transfer Key-Value Projection
Backbone을 고른 후, 그 Backbone모델을 가져온 후, 기존 Edge 모델의 key-value projection layer와 CPG 학습을 위한 piggymask를 backbone 모델에 이식 및 활성화를 해주어 최종적으로 로컬 데이터에 호환되는 backbone을 재구성합니다.
bash run_pickaback/transfer_kv.sh
6. With-Backbone
재구성한 Backbone을 로컬 데이터로 fine-tuning합니다.
마찬가지로 CPG를 사용하여 지속학습을 수행합니다.
bash run_pickaback/w_backbone_MM.sh
결과
Key-value projeciton을 교환했을 때에도, Pick-a-back이 유효함을 알 수 있었습니다.
또한, 단일 모달리티에서 backbone을 골랐을 때 뿐만 아니라, Cross-modal에서 backbone을 골랐을 때 역시 성능이 향상함을 확인하였습니다.
연구 레포지토리에서 자세한 코드를 확인해보실 수 있습니다.
https://github.com/EWHA-Tespa/MM-Pick-a-back
GitHub - EWHA-Tespa/MM-Pick-a-back: Multi-Modal Pick-a-back
Multi-Modal Pick-a-back. Contribute to EWHA-Tespa/MM-Pick-a-back development by creating an account on GitHub.
github.com