Add: CNN network test (basic pipeline complete)

Attached: training results
leaf 2022-06-09 13:03:13 +08:00
parent ea6ee46971
commit 5f9d6f5abc
2 changed files with 26 additions and 42 deletions

BIN  CNN.pkl  (new file; binary content not shown)
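CNN.pkl is the trained model checkpoint that gesture_recognition() below loads with torch.load("CNN.pkl") and reads cnn.out_label from. That usage implies the whole module object was pickled rather than just a state_dict, presumably along these lines (the save side is not shown in this commit):

# Assumed save side, not part of this diff: pickling the full module keeps
# custom attributes such as out_label available after torch.load().
torch.save(cnn, "CNN.pkl")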

demo.py  (68 changes)

@@ -20,21 +20,21 @@ class CNN(nn.Module):
     def __init__(self):
         super(CNN, self).__init__()
         self.out_label = []
-        self.conv1 = nn.Sequential(  # input shape (1, 21, 3)
-            nn.Conv2d(
-                in_channels=1,    # input height
-                out_channels=16,  # n_filters
-                kernel_size=5,    # filter size
-                stride=1,         # filter movement/step
-                padding=2,        # to keep the conv2d output the same height/width: padding = (kernel_size - 1) / 2 when stride = 1
-            ),                    # output shape (16, 28, 28)
-            nn.ReLU(),            # activation
-            nn.MaxPool2d(kernel_size=1),  # downsample within a 2x2 window, output shape (16, 14, 14)
-        )
-        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
-            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
-            nn.ReLU(),  # activation
-            nn.MaxPool2d(3),  # output shape (32, 7, 7)
-        )
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=16,
+                kernel_size=5,
+                stride=1,
+                padding=2,
+            ),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=1),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(16, 32, 5, 1, 2),
+            nn.ReLU(),
+            nn.MaxPool2d(3),
+        )
         self.med = nn.Linear(32 * 7 * 1, 500)
         self.out = nn.Linear(500, 10)  # fully connected layer, output 10 classes
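For reference on the numbers in this hunk: treating the landmark input as a 1-channel 21x3 image (21 MediaPipe hand landmarks x x/y/z), padding=2 with a 5x5 kernel and stride 1 preserves the 21x3 size, MaxPool2d(kernel_size=1) is a no-op, and MaxPool2d(3) reduces 21x3 to 7x1, which is where 32 * 7 * 1 in self.med comes from. A minimal shape check (not part of the commit, variable names are ours):

import torch
import torch.nn as nn

# Push a dummy (batch, channel, 21, 3) landmark tensor through the same layer stack.
x = torch.zeros(1, 1, 21, 3)          # 21 landmarks x (x, y, z)
x = nn.Conv2d(1, 16, 5, 1, 2)(x)      # padding 2 keeps spatial size -> (1, 16, 21, 3)
x = nn.MaxPool2d(kernel_size=1)(x)    # kernel size 1: no downsampling -> (1, 16, 21, 3)
x = nn.Conv2d(16, 32, 5, 1, 2)(x)     # -> (1, 32, 21, 3)
x = nn.MaxPool2d(3)(x)                # floor(21/3)=7, floor(3/3)=1 -> (1, 32, 7, 1)
print(x.shape)                        # torch.Size([1, 32, 7, 1]); 32*7*1 = 224 features into nn.Linear(32 * 7 * 1, 500)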
@@ -109,14 +109,14 @@ class HandDetector:
         y_list = []
         bbox_info = []
         self.lmList = []
+        h, w, c = img.shape
         if self.results.multi_hand_landmarks:
             my_hand = self.results.multi_hand_landmarks[hand_no]
             for _, lm in enumerate(my_hand.landmark):
-                h, w, c = img.shape
                 px, py = int(lm.x * w), int(lm.y * h)
                 x_list.append(px)
                 y_list.append(py)
-                self.lmList.append([px, py])
+                self.lmList.append([lm.x, lm.y, lm.z])
                 if draw:
                     cv2.circle(img, (px, py), 5, (255, 0, 255), cv2.FILLED)
             x_min, x_max = min(x_list), max(x_list)
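After this hunk, the detector still uses the pixel coordinates px, py for drawing and the bounding box, but self.lmList now stores the normalized MediaPipe values (lm.x, lm.y, lm.z), so one detected hand yields exactly the 21x3 matrix the CNN above expects. A small illustration of the two coordinate spaces (all values made up):

# Illustrative only: one normalized landmark and its pixel projection.
h, w = 480, 640                        # assumed frame size
lm_x, lm_y, lm_z = 0.52, 0.37, -0.01   # normalized MediaPipe landmark

px, py = int(lm_x * w), int(lm_y * h)  # (332, 177): used for cv2.circle and the bbox
entry = [lm_x, lm_y, lm_z]             # appended to self.lmList; 21 such rows form the (21, 3) CNN input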
@@ -174,8 +174,8 @@ class HandDetector:
 class Main:
     def __init__(self):
-        self.EPOCH = 20
-        self.BATCH_SIZE = 10
+        self.EPOCH = 50
+        self.BATCH_SIZE = 5
         self.LR = 10e-5
         self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
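The commit bumps training to 50 epochs with batches of 5; LR = 10e-5 is the same value as 1e-4. The training loop itself (train_cnn) is not part of this diff, so the following is only a sketch of how these hyperparameters would typically be wired up; the optimizer, loss function, and placeholder dataset are assumptions, not the repo's code:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from demo import CNN                          # the CNN class defined in this file

EPOCH, BATCH_SIZE, LR = 50, 5, 10e-5          # matches the new Main.__init__ values
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(100, 1, 21, 3)                # placeholder landmark samples
y = torch.randint(0, 10, (100,))              # placeholder labels for 10 classes
loader = DataLoader(TensorDataset(x, y), batch_size=BATCH_SIZE, shuffle=True)

cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # Adam is an assumption
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for bx, by in loader:
        bx, by = bx.to(device), by.to(device)
        logits = cnn(bx)                      # assumes forward() returns the (batch, 10) logits
        loss = loss_func(logits, by)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()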
@@ -238,6 +238,8 @@ class Main:
     def gesture_recognition(self):
         self.detector = HandDetector()
+        cnn = torch.load("CNN.pkl")
+        out_label = cnn.out_label
         while True:
             frame, img = self.camera.read()
             img = self.detector.find_hands(img)
@@ -245,31 +247,13 @@
             if lm_list:
                 x_1, y_1 = bbox["bbox"][0], bbox["bbox"][1]
-                x1, x2, x3, x4, x5 = self.detector.fingers_up()
-                if (x2 == 1 and x3 == 1) and (x4 == 0 and x5 == 0 and x1 == 0):
-                    cv2.putText(img, "2_TWO", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif (x2 == 1 and x3 == 1 and x4 == 1) and (x1 == 0 and x5 == 0):
-                    cv2.putText(img, "3_THREE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif (x2 == 1 and x3 == 1 and x4 == 1 and x5 == 1) and (x1 == 0):
-                    cv2.putText(img, "4_FOUR", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif x1 == 1 and x2 == 1 and x3 == 1 and x4 == 1 and x5 == 1:
-                    cv2.putText(img, "5_FIVE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif x2 == 1 and x1 == 0 and (x3 == 0, x4 == 0, x5 == 0):
-                    cv2.putText(img, "1_ONE", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif x1 == 1 and x2 == 1 and (x3 == 0, x4 == 0, x5 == 0):
-                    cv2.putText(img, "8_EIGHT", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif x1 == 1 and x5 == 1 and (x3 == 0, x4 == 0, x5 == 0):
-                    cv2.putText(img, "6_SIX", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
-                                (0, 0, 255), 3)
-                elif x1 and (x2 == 0, x3 == 0, x4 == 0, x5 == 0):
-                    cv2.putText(img, "GOOD!", (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
+                data = torch.Tensor(lm_list)
+                data = data.unsqueeze(0)
+                data = data.unsqueeze(0)
+                test_output = cnn(data)
+                result = torch.max(test_output, 1)[1].data.cpu().numpy()[0]
+                cv2.putText(img, str(out_label[result]), (x_1, y_1), cv2.FONT_HERSHEY_PLAIN, 3,
                             (0, 0, 255), 3)
             cv2.imshow("camera", img)
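The hard-coded finger-counting rules are replaced by CNN inference: gesture_recognition loads the pickled model, reshapes the landmark list into the network's expected input, and draws the predicted label. Accessing cnn.out_label after torch.load("CNN.pkl") implies the whole module object was saved, as noted above. A shape walk-through of the new path (values illustrative; cnn and out_label as loaded in the diff):

import torch

lm_list = [[0.5, 0.5, 0.0]] * 21       # 21 landmarks x (x, y, z) from find_position
data = torch.Tensor(lm_list)           # (21, 3)
data = data.unsqueeze(0)               # (1, 21, 3)
data = data.unsqueeze(0)               # (1, 1, 21, 3) = (batch, channel, 21, 3), what Conv2d expects
# test_output = cnn(data)              # (1, 10) logits
# result = torch.max(test_output, 1)[1].item()   # index of the best class -> out_label[result]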
@@ -282,6 +266,6 @@ class Main:
 if __name__ == '__main__':
     Solution = Main()
-    # Solution.gesture_recognition()
     Solution.load_datasets()
     Solution.train_cnn()
+    Solution.gesture_recognition()