1.1 Installing mmdetection
1.1.1 System requirements
According to the official mmdetection documentation (https://mmdetection.readthedocs.io/en/latest/INSTALL.html), the system requirements are:
- Linux (Windows is not officially supported)
- Python 3.5+
- PyTorch 1.1 or higher
- CUDA 9.0 or higher
- NCCL 2
- GCC 4.9 or higher
- mmcv
My system environment:
- CentOS 7.2
- Python 3.7
- PyTorch 1.1.0
- CUDA 10.0
- NCCL 2
- GCC 7.4.0
- mmcv
mmcv depends on the following packages:
- addict
- numpy
- pyyaml
- six
My environment was missing addict and pyyaml. I downloaded their source packages from https://pypi.org/ and installed them offline.
Appendix: a cheat sheet of commands for checking the versions of deep learning software, libraries and tools:
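The cheat sheet itself is not reproduced here. The following minimal Python sketch is not the original table. It reports the versions that matter for the requirements listed above. It assumes each package exposes a __version__ attribute and reports None for anything that cannot be imported. The default package list is an illustrative choice.

```python
import importlib
import platform


def collect_versions(packages=("numpy", "torch", "mmcv", "mmdet", "pycocotools")):
    """Map each package name to its __version__, or None if it cannot be imported."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except Exception:  # broken installs can raise more than ImportError
            versions[name] = None
            continue
        # Packages without a __version__ attribute are reported as None.
        versions[name] = getattr(module, "__version__", None)
    # CUDA/cuDNN details are only available when PyTorch imports cleanly.
    try:
        import torch
        versions["torch.version.cuda"] = torch.version.cuda
        versions["cudnn"] = torch.backends.cudnn.version()
        versions["cuda_available"] = torch.cuda.is_available()
    except Exception:
        pass
    return versions


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print("{:20s} {}".format(key, value))
```

nvcc --version and gcc --version remain the usual shell-level checks for the CUDA toolkit and compiler.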
1.1.2 Installing mmdetection
The official installation instructions are below. They suit an online install on a good network connection:
# Clone the mmdetection repository
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
# Install build requirements and then install mmdetection.
pip install -r requirements/build.txt
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
pip install -v -e . # or "python setup.py develop"
Here I install offline instead:
(1) Check mmdetection's build dependencies. My environment already satisfies them:
cat requirements/build.txt
# These must be installed before building mmdetection
numpy
torch>=1.1
(2) Install cocoapi:
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py install
(3) Install mmdetection. The dependency check showed that terminaltables was missing, so I installed it offline from source.
After installation, open Python and verify that the import succeeds:
from mmdet.apis import init_detector, inference_detector, show_result
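Step (1) above compares requirements/build.txt against the installed packages by eye. The sketch below automates that check under simplifying assumptions:
- Only single specifiers such as torch>=1.1 are understood.
- Versions are compared by their leading numeric components.
- Each import name is assumed to equal its package name, which holds for numpy and torch.

The helper names are hypothetical.

```python
import importlib
import operator
import re

# Matches "numpy", "torch>=1.1", "foo==1.0.0" (simple single specifiers only).
_REQ_RE = re.compile(
    r"^\s*([A-Za-z0-9_.\-]+)\s*(?:(>=|==|<=|>|<)\s*([0-9][0-9A-Za-z.]*))?\s*$")
_OPS = {">=": operator.ge, "==": operator.eq, "<=": operator.le,
        ">": operator.gt, "<": operator.lt}


def parse_requirements(text):
    """Return (name, op, version) tuples; op and version are None if unpinned."""
    reqs = []
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and blank lines
        match = _REQ_RE.match(line)
        if line and match:
            reqs.append(match.groups())
    return reqs


def version_tuple(version):
    """'1.1.0.post2' -> (1, 1, 0); local suffixes such as '+cu118' are dropped."""
    parts = []
    for piece in version.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def satisfies(installed, op, required):
    """Compare versions after padding both to the same length with zeros."""
    if op is None:
        return True
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return _OPS[op](a, b)


def check_requirements(text):
    """Map each requirement to (installed version or None, satisfied?)."""
    report = {}
    for name, op, required in parse_requirements(text):
        try:
            installed = importlib.import_module(name).__version__
        except Exception:  # not installed, or no __version__ attribute
            report[name] = (None, False)
            continue
        report[name] = (installed, satisfies(installed, op, required))
    return report
```

For example, check_requirements(open('requirements/build.txt').read()) returns a dict mapping each listed package to its installed version and whether the specifier is met.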
1.2 Training on a custom dataset: CatDog
1.2.1 Preparing the dataset
Create a data folder and symlink the dataset root into it. The annotated dataset uses the COCO data format:
cd mmdetection
mkdir data
export COCO_ROOT=/data1/Projects/datasets/coco
ln -s $COCO_ROOT data
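Before training, it helps to confirm that the annotation file really is in COCO format and that its categories are dog and cat. Below is a minimal sketch. It assumes a standard COCO instances JSON with images, annotations and categories keys. The function names, and the path in the usage note, are hypothetical.

```python
import json


def summarize_coco(ann):
    """Check the top-level COCO keys and count annotations per category name."""
    missing = [k for k in ("images", "annotations", "categories") if k not in ann]
    if missing:
        raise ValueError("missing COCO keys: {}".format(missing))
    id_to_name = {c["id"]: c["name"] for c in ann["categories"]}
    counts = {name: 0 for name in id_to_name.values()}
    for a in ann["annotations"]:
        # A KeyError here means an annotation uses an undeclared category_id.
        counts[id_to_name[a["category_id"]]] += 1
    return counts


def load_and_summarize(path):
    with open(path) as f:
        return summarize_coco(json.load(f))
```

Usage, with a hypothetical file name: load_and_summarize('data/coco/annotations/instances_train2017.json') should list only dog and cat with non-zero counts.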
Create mmdet/datasets/cat_dog.py with the following content:
from .coco import CocoDataset
from .registry import DATASETS
@DATASETS.register_module
class CatDog(CocoDataset):
    CLASSES = ('dog', 'cat')
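In mmdet/datasets/__init__.py, import the new class so that its registration decorator runs. If the file defines __all__, also add 'CatDog' to it: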
from .cat_dog import CatDog
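The import in __init__.py matters because @DATASETS.register_module only records CatDog when its module is imported. A config that names the dataset can only be built after that import has run. The toy Registry below is a hypothetical, simplified illustration of this mechanism, not mmcv's actual implementation. The ann_file argument is likewise illustrative.

```python
class Registry:
    """Hypothetical, simplified registry mapping class names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        # Used as a bare decorator, mirroring @DATASETS.register_module above.
        if cls.__name__ in self._module_dict:
            raise KeyError("{} is already registered in {}".format(cls.__name__, self.name))
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        return self._module_dict.get(key)


DATASETS = Registry("dataset")


def build_dataset(cfg):
    """Instantiate the class named by cfg['type'], passing the other keys as kwargs."""
    args = dict(cfg)
    cls = DATASETS.get(args.pop("type"))
    if cls is None:
        raise KeyError("dataset type not registered; was its module imported?")
    return cls(**args)


# Decorating at import time is what makes the name "CatDog" resolvable later.
@DATASETS.register_module
class CatDog:
    CLASSES = ("dog", "cat")

    def __init__(self, ann_file):
        self.ann_file = ann_file
```

Without the import line, build_dataset({'type': 'CatDog', ...}) would fail with the "not registered" error, because the decorator never ran.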
1.2.2 Modifying the faster_rcnn model config
Download the pretrained ResNet-50 model and place it in the checkpoints subdirectory of $TORCH_HOME:
export TORCH_HOME=/data1/Projects/pretrained_models
echo $TORCH_HOME
mkdir -p /data1/Projects/pretrained_models/checkpoints/
mv resnet50-19c8e357.pth /data1/Projects/pretrained_models/checkpoints/
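The hex suffix in resnet50-19c8e357.pth appears to follow the torchvision convention of embedding the leading digits of the file's SHA-256 hash in its name. If so, an offline copy can be checked for corruption. Below is a minimal sketch assuming that convention. The function names are hypothetical.

```python
import hashlib
import os


def checkpoint_path(torch_home, filename="resnet50-19c8e357.pth"):
    """Where the checkpoint was moved above: $TORCH_HOME/checkpoints/<filename>."""
    return os.path.join(torch_home, "checkpoints", filename)


def hash_prefix_matches(path, chunk_size=1 << 20):
    """Compare the file's SHA-256 digest with the hex suffix after the last '-'.

    Assumes the naming convention '<name>-<sha256 prefix>.pth'.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    prefix = stem.rsplit("-", 1)[-1]
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in chunks so large checkpoints are not read into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(prefix)
```

Usage: hash_prefix_matches(checkpoint_path(os.environ['TORCH_HOME'])) should return True for an intact download.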
Faster R-CNN is used as the detection model. Copy configs/faster_rcnn_r50_fpn_1x.py to configs/cat_dog_faster_rcnn_r50_fpn_1x.py. Then add the following to cat_dog_faster_rcnn_r50_fpn_1x.py so that the pretrained ResNet-50 is loaded from $TORCH_HOME:
# model settings
import os
os.environ['TORCH_HOME'] = '/data1/Projects/pretrained_models'
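In the config file configs/cat_dog_faster_rcnn_r50_fpn_1x.py, switch to the CatDog dataset (dataset_type = 'CatDog'). Then set the number of classes in bbox_head to 3, covering dog, cat and background: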
num_classes=3,
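The copied config can also be edited by a script instead of by hand. This sketch assumes the base config contains a dataset_type = '...' assignment and integer num_classes=... entries, as the v1.x COCO configs appear to. Check the result before training, since every num_classes occurrence is rewritten. The function names are hypothetical.

```python
import re


def make_catdog_config(base_text, num_classes=3, dataset_type="CatDog"):
    """Return base_text with the dataset type and every num_classes value replaced."""
    text = re.sub(r"dataset_type\s*=\s*'[^']*'",
                  "dataset_type = '{}'".format(dataset_type), base_text)
    text = re.sub(r"num_classes\s*=\s*\d+",
                  "num_classes={}".format(num_classes), text)
    return text


def write_catdog_config(src="configs/faster_rcnn_r50_fpn_1x.py",
                        dst="configs/cat_dog_faster_rcnn_r50_fpn_1x.py"):
    with open(src) as f:
        text = f.read()
    with open(dst, "w") as f:
        f.write(make_catdog_config(text))
```

The TORCH_HOME lines shown earlier still need to be added to the generated file.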
CatDog is a COCO-style dataset, so the COCO-specific code is modified as well:
- Define the categories: in mmdetection/mmdet/datasets/coco.py, change the CLASSES tuple to the classes of your own dataset:
CLASSES = ('dog', 'cat')
- In mmdetection/mmdet/core/evaluation/class_names.py, change the classes returned by coco_classes. These are the class names displayed in the result images at test time:
def coco_classes():
    return ['dog', 'cat']
1.2.3 Training the model
Train on a single GPU, using --work_dir to set the directory where model checkpoints are saved:
python tools/train.py configs/cat_dog_faster_rcnn_r50_fpn_1x.py \
--gpus 1 \
--work_dir './work_dirs/cat_dog_faster_rcnn_r50_fpn_1x'
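The test scripts below load latest.pth from this work directory. If that file is unavailable, for example after copying the directory to a filesystem without symlink support, the newest epoch checkpoint can be picked instead. This sketch assumes checkpoints are named epoch_N.pth, which is how we understand the mmcv runner names them. It compares epochs numerically, so epoch_10 wins over epoch_9. The function name is hypothetical.

```python
import os
import re


def latest_epoch_checkpoint(work_dir):
    """Return the path of the epoch_N.pth with the largest N in work_dir, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(work_dir):
        match = re.fullmatch(r"epoch_(\d+)\.pth", name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(work_dir, name)
    return best_path
```

Usage: latest_epoch_checkpoint('./work_dirs/cat_dog_faster_rcnn_r50_fpn_1x') can be passed as checkpoint_file in place of latest.pth.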
1.2.4 Testing on images
1.2.4.1 Testing a single image
import numpy as np
import cv2
from mmdet.apis import init_detector, inference_detector
threshold = 0.9  # minimum confidence score for a detection to be kept
config_file = './configs/cat_dog_faster_rcnn_r50_fpn_1x.py'
checkpoint_file = './work_dirs/cat_dog_faster_rcnn_r50_fpn_1x/latest.pth'
# Build the detector from the config file and the checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Run inference on a single image
img_path = '/data1/Projects/datasets/cat_dog_single/cat.12176.jpg'
result = inference_detector(model, img_path)
# result holds one (N, 5) array per class: x1, y1, x2, y2, score
bboxes = np.vstack(result)
labels = [
    np.full(bbox.shape[0], i, dtype=np.int32)
    for i, bbox in enumerate(result)
]
labels = np.concatenate(labels)
# Keep only the detections above the confidence threshold
scores = bboxes[:, -1]
inds = scores > threshold
bboxes = bboxes[inds, :]
labels = labels[inds]
# Count the detections per class name
class_names = model.CLASSES
cat_dog_dict = {}
for label in labels:
    cat_dog_dict[class_names[label]] = cat_dog_dict.get(class_names[label], 0) + 1
print(cat_dog_dict)
for k, v in cat_dog_dict.items():
    print('{0} has {1} {2}(s)'.format(img_path, v, k))
# Draw the kept boxes and save the result; OpenCV expects integer pixel coordinates
img = cv2.imread(img_path)
for bbox in bboxes:
    left_top = (int(bbox[0]), int(bbox[1]))
    right_bottom = (int(bbox[2]), int(bbox[3]))
    cv2.rectangle(img, left_top, right_bottom, color=(0, 255, 0))
cv2.imwrite('/data1/Projects/datasets/cat_dog_single/res_cat.12176.jpg', img)
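The stacking, thresholding and counting steps in this script can be condensed into one function. This is a sketch, not part of mmdet's API, and the function name is hypothetical. It assumes result is the per-class list of [x1, y1, x2, y2, score] rows returned for Faster R-CNN above. It works with NumPy arrays or plain lists.

```python
from collections import Counter


def count_detections(result, class_names, score_thr=0.9):
    """Count detections per class whose score is strictly above score_thr."""
    counts = Counter()
    for class_idx, dets in enumerate(result):
        for det in dets:
            # The last column of each row is the confidence score.
            if det[-1] > score_thr:
                counts[class_names[class_idx]] += 1
    return dict(counts)
```

In the script above, count_detections(result, model.CLASSES, threshold) gives the same dictionary as cat_dog_dict.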
1.2.4.2 Testing multiple images
import glob
import os
import numpy as np
from mmdet.apis import init_detector, inference_detector, show_result
config_file = 'configs/cat_dog_faster_rcnn_r50_fpn_1x.py'
checkpoint_file = 'work_dirs/cat_dog_faster_rcnn_r50_fpn_1x/latest.pth'
score_thr = 0.9
# Build the detector from the config file and the checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Test a single image
# img = '/data1/Projects/datasets/cat_dog_single/cat.12176.jpg'
# result = inference_detector(model, img)
# show_result(img, result, model.CLASSES, score_thr=score_thr,
#             out_file='/data1/Projects/datasets/cat_dog_single/res_cat.12176.jpg')
# Test multiple images
# imgs = glob.glob('/data1/Projects/datasets/coco/val2017/*.jpg')
imgs = glob.glob('/data1/Projects/datasets/test/*.jpg')
class_names = model.CLASSES
for img in imgs:
    result = inference_detector(model, img)
    file_name = os.path.basename(img)
    # Draw the bounding boxes onto the image and save it to test_det/
    out_file = os.path.join('/data1/Projects/datasets/test_det', file_name)
    show_result(img, result, class_names, score_thr=score_thr, out_file=out_file)
    # Count the cats and dogs in the image
    bboxes = np.vstack(result)
    labels = np.concatenate([
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(result)
    ])
    # Keep only the labels whose boxes score above the threshold
    inds = bboxes[:, -1] > score_thr
    labels = labels[inds]
    cat_dog_dict = {}
    for label in labels:
        cat_dog_dict[class_names[label]] = cat_dog_dict.get(class_names[label], 0) + 1
    for k, v in cat_dog_dict.items():
        print('{0} has {1} {2}(s)'.format(file_name, v, k))
    print('--------------------')
The output of the multi-image test is shown below:
WeChat official account 「padluo」: sharing the self-cultivation of a data scientist. Since we have met, let's grow together.