-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
train_model.py
217 lines (180 loc) · 8.93 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# encoding:utf-8
import os
import time
import json
import random
from PIL import Image
import numpy as np
import tensorflow as tf
from cnn_models.cnn_model import CNN
class train_model(CNN):
    """Training driver for the captcha-recognition CNN.

    Extends ``CNN`` with directory-based batch loading and a TensorFlow 1.x
    training loop. The network itself (``alexnet_model``), the placeholders
    (``self.X``, ``self.Y``, ``self.keep_prob``), the helpers
    (``convert2gray``, ``text2vec``) and the size attributes
    (``image_height``, ``image_width``, ``max_captcha``, ``char_set_len``,
    ``model_save_dir``) are not defined in this class; they are presumably
    provided by ``CNN`` -- confirm against ``cnn_models/cnn_model.py``.
    """

    def __init__(self, train_img_path, verify_img_path, char_set, model_save_dir, cycle_stop, acc_stop, cycle_save,
                 train_batch_size, test_batch_size):
        """Index the image directories, probe one sample and initialise the base model.

        :param train_img_path: directory containing the training images
        :param verify_img_path: directory containing the verification images
        :param char_set: iterable of captcha characters; each entry is converted with ``str``
        :param model_save_dir: checkpoint path forwarded to ``CNN.__init__``
        :param cycle_stop: maximum number of training iterations
        :param acc_stop: training stops once whole-image accuracy exceeds this value
        :param cycle_save: a checkpoint is written every ``cycle_save`` steps
        :param train_batch_size: number of images per training batch
        :param test_batch_size: stored on the instance; not read by the methods of this class
        :raises Exception: if the probed image is neither a 2-D nor a 3-D array
        """
        self.cycle_stop = cycle_stop  # maximum number of iterations
        self.acc_stop = acc_stop  # stop when this accuracy is exceeded
        self.cycle_save = cycle_save  # checkpoint interval, in steps
        self.train_batch_size = train_batch_size  # training batch size
        self.test_batch_size = test_batch_size  # test batch size
        char_set = [str(i) for i in char_set]

        # Training files, shuffled with a time-based seed so every run sees a different order.
        self.train_img_path = train_img_path
        self.train_images_list = os.listdir(train_img_path)
        random.seed(time.time())
        random.shuffle(self.train_images_list)

        # Verification files.
        self.verify_img_path = verify_img_path
        self.verify_images_list = os.listdir(verify_img_path)

        # Probe the first training image to learn the image dimensions.
        _label, captcha_array = self.gen_captcha_text_image(train_img_path, self.train_images_list[0])
        captcha_shape = captcha_array.shape
        captcha_shape_len = len(captcha_shape)
        if captcha_shape_len == 3:
            image_height, image_width, channel = captcha_shape
            self.channel = channel
        elif captcha_shape_len == 2:
            image_height, image_width = captcha_shape
        else:
            raise Exception("图片转换为矩阵时出错,请检查图片格式")

        # Initialise the base class.
        # NOTE(review): the captcha length is hard-coded to 4 here rather than derived
        # from the probed label -- confirm that this matches the data set.
        super(train_model, self).__init__(image_height, image_width, 4, char_set, model_save_dir)

        # Print a summary of the configuration.
        print("-->图片尺寸: {} X {}".format(image_height, image_width))
        print("-->验证码长度: {}".format(self.max_captcha))
        print("-->验证码共{}类 {}".format(self.char_set_len, char_set))
        # This line reports the training directory, so it is labelled as the training set.
        print("-->使用训练集为 {}".format(train_img_path))
        print("-->使验证集为 {}".format(verify_img_path))

        # Smoke-test the input pipeline with one batch of 100 images.
        print(">>> Start model test")
        batch_x, batch_y = self.get_train_batch(0, size=100)
        print(">>> input batch images shape: {}".format(batch_x.shape))
        print(">>> input batch labels shape: {}".format(batch_y.shape))

    @staticmethod
    def gen_captcha_text_image(img_path, img_name):
        """Load one captcha image and derive its label from the file name.

        The label is the part of the file name after the last underscore,
        with the file extension removed.

        :param img_path: directory that contains the image
        :param img_name: file name of the image inside ``img_path``
        :return: tuple (str, numpy.array) -- the label and the image as an array
        """
        # Strip whatever extension the file has, not only ".jpg".
        label = os.path.splitext(img_name.split("_")[-1])[0]
        img_file = os.path.join(img_path, img_name)
        # The context manager releases the file handle as soon as the pixels are copied.
        with Image.open(img_file) as captcha_image:
            captcha_array = np.array(captcha_image)
        return label, captcha_array

    def get_train_batch(self, n, size=64):
        """Build the ``n``-th training batch.

        Batch indices at or beyond the number of full batches wrap around.

        :param n: zero-based batch index
        :param size: number of images in the batch
        :return: tuple (batch_x, batch_y) of arrays shaped
                 ``[size, image_height * image_width]`` and
                 ``[size, max_captcha * char_set_len]``
        :raises Exception: if the training set holds fewer than ``size`` images
        """
        batch_x = np.zeros([size, self.image_height * self.image_width])
        batch_y = np.zeros([size, self.max_captcha * self.char_set_len])

        max_batch = len(self.train_images_list) // size
        if max_batch < 1:
            raise Exception("训练集图片数量需要大于每批次训练的图片数量")
        if n >= max_batch:
            n = n % max_batch

        start = n * size
        this_batch = self.train_images_list[start:start + size]

        for i, img_name in enumerate(this_batch):
            label, image_array = self.gen_captcha_text_image(self.train_img_path, img_name)
            image_array = self.convert2gray(image_array)  # convert to grayscale
            batch_x[i, :] = image_array.flatten() / 255  # flatten to 1-D and scale by 1/255
            batch_y[i, :] = self.text2vec(label)  # label vector produced by text2vec
        return batch_x, batch_y

    def get_verify_batch(self, size=50):
        """Build a batch of ``size`` randomly chosen verification images.

        Images are drawn (with replacement) from the verification directory.
        If that directory is empty, the training images are used instead so
        that the periodic accuracy report keeps working.

        :param size: number of images in the batch
        :return: tuple (batch_x, batch_y), shaped like the output of ``get_train_batch``
        """
        batch_x = np.zeros([size, self.image_height * self.image_width])
        batch_y = np.zeros([size, self.max_captcha * self.char_set_len])

        if self.verify_images_list:
            img_dir, candidates = self.verify_img_path, self.verify_images_list
        else:
            # No verification images available: fall back to the training set.
            img_dir, candidates = self.train_img_path, self.train_images_list

        verify_images = [random.choice(candidates) for _ in range(size)]

        for i, img_name in enumerate(verify_images):
            label, image_array = self.gen_captcha_text_image(img_dir, img_name)
            image_array = self.convert2gray(image_array)  # convert to grayscale
            batch_x[i, :] = image_array.flatten() / 255  # flatten to 1-D and scale by 1/255
            batch_y[i, :] = self.text2vec(label)  # label vector produced by text2vec
        return batch_x, batch_y

    def train(self):
        """Build the training graph and run the optimisation loop.

        Runs at most ``cycle_stop`` iterations. Every 10 steps the accuracy is
        measured on a verification batch of 64 images; training stops early
        once the whole-image accuracy exceeds ``acc_stop``. A checkpoint is
        written every ``cycle_save`` steps (500 if ``cycle_save`` is unset or
        zero), on early stop, and once more after the loop ends.
        """
        y_pred = self.alexnet_model()
        print(">>>input batch predict shape:{}".format(y_pred.shape))
        print(">>>End model test")

        # Loss.
        with tf.name_scope('cost'):
            cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_pred, labels=self.Y))
        # Plain gradient descent.
        with tf.name_scope('train'):
            optimizer = tf.train.GradientDescentOptimizer(0.000008).minimize(cost)

        # Compare the arg-max of every character slot in the prediction and in the label.
        predict = tf.reshape(y_pred, [-1, self.max_captcha, self.char_set_len])
        max_idx_p = tf.argmax(predict, 2)
        max_idx_l = tf.argmax(tf.reshape(self.Y, [-1, self.max_captcha, self.char_set_len]), 2)
        correct_pred = tf.equal(max_idx_p, max_idx_l)
        with tf.name_scope('char_acc'):
            # Fraction of individual characters predicted correctly.
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        with tf.name_scope('image_acc'):
            # Fraction of images in which every character is correct.
            accuracy_image_count = tf.reduce_mean(tf.reduce_min(tf.cast(correct_pred, tf.float32), axis=1))

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # Fall back to the historical interval when no usable value was configured.
        save_every = self.cycle_save if self.cycle_save else 500

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.7
        with tf.Session(config=config) as sess:
            sess.run(init)
            # Resume from an existing checkpoint when possible; otherwise start fresh.
            if os.path.exists(self.model_save_dir):
                try:
                    saver.restore(sess, self.model_save_dir)
                except Exception as err:
                    # Best effort: keep the freshly initialised weights, but say why.
                    print('model文件夹为空')
                    print("restore skipped: {}".format(err))

            writer = tf.summary.FileWriter("logs/", sess.graph)
            step = 1
            for i in range(self.cycle_stop):
                batch_x, batch_y = self.get_train_batch(i, self.train_batch_size)
                # One gradient-descent update.
                _, cost_ = sess.run([optimizer, cost],
                                    feed_dict={self.X: batch_x, self.Y: batch_y, self.keep_prob: 0.75})
                if step % 10 == 0:  # report accuracy every 10 steps
                    batch_x_t, batch_y_t = self.get_verify_batch(64)
                    # Both metrics share the same feed, so evaluate them in a single run.
                    acc_char, acc_img = sess.run(
                        [accuracy, accuracy_image_count],
                        feed_dict={self.X: batch_x_t, self.Y: batch_y_t, self.keep_prob: 1.})
                    print("Iter:{0}, Minibatch Loss= {1},".format(step, cost_) +
                          "Accuracy_char= {:.5f},Accuracy_img = {:.5f}".format(acc_char, acc_img))
                    if acc_img > self.acc_stop:
                        saver.save(sess, self.model_save_dir)
                        print("准确率得到标准,保存模型")
                        break
                if step % save_every == 0:
                    saver.save(sess, self.model_save_dir)
                    print('保存模型:{}'.format(step))
                step += 1
            saver.save(sess, self.model_save_dir)
            print("Optimization Finished!")

            # Final accuracy on a larger verification batch.
            batch_x_t, batch_y_t = self.get_verify_batch(128)
            print("Testing Accuracy:", step,
                  sess.run(accuracy, feed_dict={self.X: batch_x_t, self.Y: batch_y_t, self.keep_prob: 1.}))
            writer.close()
def main():
    """Read ``conf/config.json``, optionally hide the GPUs, and train the model."""
    with open("conf/config.json", 'r') as config_file:
        settings = json.load(config_file)

    # Every option is looked up before anything else happens, in this order,
    # so a missing key is reported before any side effect takes place.
    option_names = (
        "train_image_dir",
        "test_image_dir",
        "model_save_dir",
        "cycle_stop",
        "acc_stop",
        "cycle_save",
        "enable_gpu",
        'train_batch_size',
        'test_batch_size',
        'char_set',
    )
    (train_dir, verify_dir, save_dir, max_cycles, target_acc, save_every,
     use_gpu, train_bs, test_bs, charset) = [settings[name] for name in option_names]

    if not use_gpu:
        # These environment variables hide every CUDA device, forcing CPU execution.
        for variable, value in (("CUDA_DEVICE_ORDER", "PCI_BUS_ID"), ("CUDA_VISIBLE_DEVICES", "-1")):
            os.environ[variable] = value

    trainer = train_model(train_dir, verify_dir, charset, save_dir, max_cycles, target_acc, save_every,
                          train_bs, test_bs)
    trainer.train()  # start training the model
    # trainer.recognize_captcha()  # example: recognise an image


if __name__ == '__main__':
    main()