[MMPose] MMPose 원하는 데이터로 학습하기
저번 글에서는 MMPose를 사용하기 위한 환경 세팅을 하는 방법에 대해 정리해 보았다.
https://rahites.tistory.com/326
간단하게 기존에 학습되어 있는 모델을 사용해서 Inference를 하는데에는 demo 코드를 사용하면 되지만, 대다수의 MMPose를 사용하려는 사람들은 본인이 원하는 데이터를 가지고 Pose 모델을 학습해서 사용하길 원한다고 생각한다(물론 나 포함).
따라서 이번에는 본인이 새롭게 구축한, 원하는 Customize 데이터를 가지고 MMPose에 존재하는 Pose Estimation 모델을 학습하는 방법에 대해 정리해 보도록 하겠다.
- 사용한 MMPose 버전: 1.3.2
- 사용한 Format: COCO Format(17 Keypoint)
0. 데이터 셋 세팅
모델 학습에 가장 중요한 것은 학습되는 데이터이다. 따라서 우리는 학습을 원하는 데이터를 MMPose 라이브러리에서 요구하는 형식으로 맞춰줄 필요가 있다. 관련된 내용은 아래 링크에 정리되어 있으며, 여기서 나는 핵심적인 부분만 몇 줄 정리해보도록 하겠다.
https://mmpose.readthedocs.io/en/latest/advanced_guides/customize_datasets.html
- dataset config 파일 생성: configs/_base_/datasets 위치에 custom.py 생성
- 여기서 나는 COCO Format을 그대로 사용했기 때문에 기존에 존재하던 coco.py를 그대로 참고하였다. 이 때 본인이 Keypoint의 개수와 이름을 다르게 정의하고 싶다면 원하는 대로 custom.py를 작성해주면 된다. - training config 파일 생성: config/ 위치에 my_custom_config.py 생성
- 이는 모델 학습에 사용할 config를 custom하여 저장하는 것으로 나의 경우 학습에 사용할 모델인 configs/body_2d_keypoint/top_down_heatmap/coco 아래에 있는 human pose estimation 모델들의 config를 가져와 사용하였다.
my_custom_config.py를 만들 때 주의할 점(물론 파일명은 달라도 된다)
1. __base__ 변수에 들어가는 default_runtime.py 경로를 확인
2. # base dataset settings에 나타나 있는 dataset_type, data_mode, data_root를 확인
3. dataloader 부분에 있는 ann_file, bbox_file, data_prefix 등의 경로를 확인
4. evaluator 부분에 있는 ann_file 경로를 확인
또한 my_custom_config.py에 아래의 변수를 추가하면 tensorboard를 통해 학습 결과를 확인할 수 있다.
https://mmpose.readthedocs.io/en/latest/user_guides/train_and_test.html#visualize-training-process
visualizer = dict(vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend'),
])
텐서보드 확인
tensorboard --logdir ${WORK_DIR}/${TIMESTAMP}/vis_data
최종 config_custom.py 예시
_base_ = ['./_base_/default_runtime.py']
# runtime
train_cfg = dict(max_epochs=210, val_interval=10)
# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))
# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]
# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)
# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
# codec settings
codec = dict(
type='MSRAHeatmap', input_size=(288, 384), heatmap_size=(72, 96), sigma=3)
# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(48, 96)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(48, 96, 192)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(48, 96, 192, 384))),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w48_8xb32-210e_coco-384x288-c161b7de_20220915.pth')
# checkpoint='https://download.openmmlab.com/mmpose/'
# 'pretrain_models/hrnet_w48-8ef0771d.pth'),
), # https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w48_8xb32-210e_coco-384x288-c161b7de_20220915.pth
head=dict(
type='HeatmapHead',
in_channels=48,
out_channels=17,
deconv_out_channels=None,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec),
test_cfg=dict(
flip_test=True,
flip_mode='heatmap',
shift_heatmap=True,
))
# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/{custom_location}/'
# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(type='RandomBBoxTransform'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]
# data loaders
train_dataloader = dict(
batch_size=8,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
# change dataset information
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/train_coco_labels.json',
data_prefix=dict(img='train'),
pipeline=train_pipeline,
metainfo=dict(from_file='configs/_base_/datasets/custom.py')
))
val_dataloader = dict(
batch_size=8,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
# change dataset information
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/test_coco_labels.json',
bbox_file=None,
# bbox_file='data/coco/person_detection_results/'
# 'COCO_val2017_detections_AP_H_56_person.json',
# https://mmpose.readthedocs.io/en/latest/faq.html#data
# bbox json을 올려두면 detection 결과도 같이 평가할 수 있다.
data_prefix=dict(img='val'),
test_mode=True,
pipeline=val_pipeline,
metainfo=dict(from_file='configs/_base_/datasets/custom.py')
))
test_dataloader = val_dataloader
# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/test_coco_labels.json')
test_evaluator = val_evaluator
※ 만드는 데이터셋이 COCO Format이 아닌 경우에는 아래 링크를 참고하여 데이터 셋 클래스를 만들어줘야 한다.
1. Custom 데이터로 학습하기
학습 내용을 정리해 둔 MMPose Docs는 아래를 참고하면 된다.
https://mmpose.readthedocs.io/en/latest/user_guides/train_and_test.html
학습을 시작하기 전, Pretrain된 pt 파일을 받고 싶다면 아래의 코드를 사용하면 된다.
mim download mmpose --config td-hm_hrnet-w48_8xb32-210e_coco-384x288 --dest ./models/
여기서 --dest는 destination으로 pt 파일이 저장될 경로를 나타내고 다운받아지는 모델의 pt 파일은 config py 파일 속 model 변수 dict['init_cfg']에 담겨있는 checkpoint이다.
Backbone만 ImageNet으로 학습된 모델인지 COCO로 Pretrain된 모델인지 잘 확인하고 사용하는 것이 필요하다.
위의 내용이 마무리 되었다면 tools/train.py 파일을 활용하여 모델 학습을 진행할 수 있다.
python tools/train.py {config_file}
# ex.
python tools/train.py configs/hrnetw48_384x288_custom.py --resume --show-dir results/
- --work-dir : 학습 결과가 저장되는 기본 저장 경로는 work_dirs로 해당 경로를 바꾸고 싶을 때 사용한다.
- --show-dir : 학습 도중 validation 파일에 대한 inference 결과를 저장할지 여부를 선택하고 work_dirs 경로 내 지정 경로에 저장해준다.
- --no-validate : 학습 도중 validation 과정을 진행하지 않는다.
- --resume : 특정 checkpoint 명시하면 해당 checkpoint부터 학습을 이어서 진행(특정 checkpoint를 명시할 경우 기존에 학습된 모델의 epoch 횟수에 누적되어 내가 명시한 epoch 수까지 돌아간다. 0번부터 돌리고 싶다면 특정 checkpoint를 사용하지 않아야 한다).
2. Custom 데이터로 학습된 모델을 사용해서 test 진행하기
https://mmpose.readthedocs.io/en/latest/user_guides/train_and_test.html#test-your-model
학습이 정상적으로 완료되었다면 work_dirs 폴더 내에 Epoch 별로 저장된 pt 파일과 best, last epoch에 대한 pt파일이 저장되었을 것이다.
tools/test.py 파일을 사용하면 학습된 pt파일을 가지고 test를 진행할 수 있다.
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [ARGS]
# ex.
python tools/test.py configs/hrnetw48_384x288_custom.py work_dirs/hrnetw48_384x288_custom/best_coco_AP_*.pth --show-dir results/ --dump inference_results/hrnet_custom.pkl
- best_coco_AP_*.pth : 여러 모델에 대한 test 코드를 한번에 돌릴 때 모델별로 몇 Epoch가 Best 성능을 냈는지 모르기에 Wildcard 문자로 표현
- --dump: test시 inference를 실행한 모든 이미지에 대한 예측 결과를 .pkl 파일 형태로 저장
※ 테스트 결과를 보았을 때 BBox 밖으로 Keypoint가 예측되는 이유
https://mmpose.readthedocs.io/en/latest/faq.html#evaluation
이미지를 자르기 위해 bbox를 직접 사용하지 않기 때문이다. 먼저 bbox를 중앙 및 축척으로 변환하고 축척에 요소(1.25)를 곱하여 일부 컨텍스트를 포함하며, 가로/세로 비율이 모델 입력의 비율과 다른 경우(아마도 192/256) bbox를 조정한다.
3. Custom 데이터로 학습된 모델을 사용해서 Inference 진행하기
학습된 모델을 사용해서 원하는 이미지 1장에 대해 Inference를 진행하고 싶다면 demo/image_demo.py 파일을 사용하면 된다.
# ex.
python demo/image_demo.py \
data/{data_location}/train/{data_name}.png \
configs/hrnetw48_384x288_custom.py \
work_dirs/hrnetw48_384x288_custom/best_coco_AP_*.pth \
--out-file {data_name}.png \
--draw-heatmap
- --out-file : 저장하려는 파일명
- --draw-heatmap : 이미지를 시각화할 때 히트맵을 같이 그릴지 여부
- --thickness : 시각화시 keypoint들을 연결해주는 line 굵기
- --radius : 사긱화시 keypoint가 표시되는 원 크기
위에 언급된 내용 외에 디테일한 수정사항은 코드를 살펴보면서 평가지표 및 모델 구조를 변경하여 사용할 수도 있다. 환경 세팅에서는 패키지 버전을 맞추어주는 것이 가장 중요했다면, 학습이나 테스트 과정에서는 데이터 경로나 py 파일 경로를 정확하게 작성해주는 것이 중요하다.