增加:数据集读入、CNN网络训练
This commit is contained in:
parent
0b0dcb27f3
commit
ea6ee46971
15
Datasets/README.md
Normal file
15
Datasets/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
## GestureData 手势数据 v1.0
|
||||||
|
# 文件格式:
|
||||||
|
每个数据集(npz文件)包含:
|
||||||
|
1个标签label(手势标签,整个数据集都是这一个标签);
|
||||||
|
500组数据data(每组数据是21*3,即21个点的3维数据,就是demo.py-find_position()中的lm.x, lm.y, lm.z);
|
||||||
|
左右手区分handtype(0为左手,1为右手);
|
||||||
|
画布大小shape(一般都是720*1280,对应demo.py-find_position()中的w, h)。
|
||||||
|
|
||||||
|
# 注意事项:
|
||||||
|
1. 在使用之前建议先熟悉npz文件的读写与使用(很简单的);
|
||||||
|
2. 数据集shape类最后会保存一个[0, 0],其他都是正常的[720, 1280];
|
||||||
|
3. 左右手不建议使用,因为面向屏幕的手心手背就可以导致程序的误判。
|
||||||
|
|
||||||
|
# 更新说明:
|
||||||
|
1. 保存了0~9的手势。
|
BIN
Datasets/eight.npz
Normal file
BIN
Datasets/eight.npz
Normal file
Binary file not shown.
BIN
Datasets/five.npz
Normal file
BIN
Datasets/five.npz
Normal file
Binary file not shown.
BIN
Datasets/four.npz
Normal file
BIN
Datasets/four.npz
Normal file
Binary file not shown.
BIN
Datasets/nine.npz
Normal file
BIN
Datasets/nine.npz
Normal file
Binary file not shown.
BIN
Datasets/one.npz
Normal file
BIN
Datasets/one.npz
Normal file
Binary file not shown.
BIN
Datasets/seven.npz
Normal file
BIN
Datasets/seven.npz
Normal file
Binary file not shown.
BIN
Datasets/six.npz
Normal file
BIN
Datasets/six.npz
Normal file
Binary file not shown.
BIN
Datasets/three.npz
Normal file
BIN
Datasets/three.npz
Normal file
Binary file not shown.
BIN
Datasets/two.npz
Normal file
BIN
Datasets/two.npz
Normal file
Binary file not shown.
BIN
Datasets/zero.npz
Normal file
BIN
Datasets/zero.npz
Normal file
Binary file not shown.
100
demo.py
100
demo.py
@ -9,6 +9,43 @@
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
|
||||||
|
class CNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(CNN, self).__init__()
|
||||||
|
self.out_label = []
|
||||||
|
self.conv1 = nn.Sequential( # input shape (1, 21, 3)
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=1, # input height
|
||||||
|
out_channels=16, # n_filters
|
||||||
|
kernel_size=5, # filter size
|
||||||
|
stride=1, # filter movement/step
|
||||||
|
padding=2, # 如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1
|
||||||
|
), # output shape (16, 28, 28)
|
||||||
|
nn.ReLU(), # activation
|
||||||
|
nn.MaxPool2d(kernel_size=1), # 在 2x2 空间里向下采样, output shape (16, 14, 14)
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
|
||||||
|
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
|
||||||
|
nn.ReLU(), # activation
|
||||||
|
nn.MaxPool2d(3), # output shape (32, 7, 7)
|
||||||
|
)
|
||||||
|
self.med = nn.Linear(32 * 7 * 1, 500)
|
||||||
|
self.out = nn.Linear(500, 10) # fully connected layer, output 10 classes
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = x.view(x.size(0), -1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
|
||||||
|
x = self.med(x)
|
||||||
|
output = self.out(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class HandDetector:
|
class HandDetector:
|
||||||
@ -137,11 +174,68 @@ class HandDetector:
|
|||||||
|
|
||||||
class Main:
|
class Main:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.detector = None
|
self.EPOCH = 20
|
||||||
|
self.BATCH_SIZE = 10
|
||||||
|
self.LR = 10e-5
|
||||||
|
self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
self.camera = cv2.VideoCapture(0, cv2.CAP_DSHOW)
|
||||||
self.camera.set(3, 1280)
|
self.camera.set(3, 1280)
|
||||||
self.camera.set(4, 720)
|
self.camera.set(4, 720)
|
||||||
|
|
||||||
|
self.datasets_dir = "Datasets"
|
||||||
|
self.train_loader = None
|
||||||
|
self.out_label = [] # CNN网络输出后数字标签转和字符串标签的映射关系
|
||||||
|
|
||||||
|
self.detector = None
|
||||||
|
|
||||||
|
def load_datasets(self):
|
||||||
|
train_data = []
|
||||||
|
train_label = []
|
||||||
|
|
||||||
|
for file in Path(self.datasets_dir).rglob("*.npz"):
|
||||||
|
data = np.load(str(file))
|
||||||
|
train_data.append(data["data"])
|
||||||
|
label_number = np.ones(len(data["data"]))*len(self.out_label)
|
||||||
|
train_label.append(label_number)
|
||||||
|
self.out_label.append(data["label"])
|
||||||
|
train_data = torch.Tensor(np.concatenate(train_data, axis=0))
|
||||||
|
train_data = train_data.unsqueeze(1)
|
||||||
|
train_label = torch.tensor(np.concatenate(train_label, axis=0)).long()
|
||||||
|
|
||||||
|
dataset = TensorDataset(train_data, train_label)
|
||||||
|
self.train_loader = DataLoader(dataset, batch_size=self.BATCH_SIZE, shuffle=True)
|
||||||
|
|
||||||
|
def train_cnn(self):
|
||||||
|
cnn = CNN().to(self.DEVICE)
|
||||||
|
optimizer = torch.optim.Adam(cnn.parameters(), self.LR) # optimize all cnn parameters
|
||||||
|
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
|
||||||
|
|
||||||
|
for epoch in range(self.EPOCH):
|
||||||
|
for step, (data, target) in enumerate(self.train_loader):
|
||||||
|
# 分配 batch data, normalize x when iterate train_loader
|
||||||
|
data, target = data.to(self.DEVICE), target.to(self.DEVICE)
|
||||||
|
output = cnn(data) # cnn output
|
||||||
|
loss = loss_func(output, target) # cross entropy loss
|
||||||
|
optimizer.zero_grad() # clear gradients for this training step
|
||||||
|
loss.backward() # backpropagation, compute gradients
|
||||||
|
optimizer.step() # apply gradients
|
||||||
|
if (step + 1) % 100 == 0: # 输出结果
|
||||||
|
if (step + 1) % 100 == 0: # 输出结果
|
||||||
|
print(
|
||||||
|
"\r[Epoch: %d] [%d/%d (%0.f %%)][Loss: %f]"
|
||||||
|
% (
|
||||||
|
epoch,
|
||||||
|
step * len(data),
|
||||||
|
len(self.train_loader.dataset),
|
||||||
|
100. * step / len(self.train_loader),
|
||||||
|
loss.item()
|
||||||
|
), end="")
|
||||||
|
|
||||||
|
cnn.out_label = self.out_label
|
||||||
|
torch.save(cnn, 'CNN.pkl')
|
||||||
|
print("训练结束")
|
||||||
|
|
||||||
def gesture_recognition(self):
|
def gesture_recognition(self):
|
||||||
self.detector = HandDetector()
|
self.detector = HandDetector()
|
||||||
while True:
|
while True:
|
||||||
@ -188,4 +282,6 @@ class Main:
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
Solution = Main()
|
Solution = Main()
|
||||||
Solution.gesture_recognition()
|
# Solution.gesture_recognition()
|
||||||
|
Solution.load_datasets()
|
||||||
|
Solution.train_cnn()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user