-
Notifications
You must be signed in to change notification settings - Fork 355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Modelzoo] Add serving for DIEN, DeepFM and WDL. #319
base: main
Are you sure you want to change the base?
Conversation
|
for #232 |
modelzoo/BST/pb_to_pbtxt.py
Outdated
@@ -0,0 +1,13 @@ | |||
from tensorflow.python.saved_model import loader_impl | |||
from tensorflow.python.lib.io import file_io | |||
from tensorflow.python.platform import tf_logging as logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useless file.
modelzoo/BST/prepare_savedmodel.py
Outdated
@@ -0,0 +1,738 @@ | |||
import time | |||
import argparse | |||
import tensorflow as tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要修改原文件,在原文件上添加代码即可。
modelzoo/BST/start_serving.cc
Outdated
#include "serving/processor/serving/processor.h" | ||
#include "serving/processor/serving/predict.pb.h" | ||
|
||
static const char* model_config = "{ \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
除fake数据之外,需要根据实际的train data生成serving data,然后能够根据serving data进行process
CLA required. |
I have finished exporting savedmodel, extracting data from the input file to generate requests and serving for DIEN DeepFM WDL under modelzoo/features/EmbeddingVariable, please check. |
cat_voc = os.path.join(data_location, "cat_voc.pkl") | ||
|
||
def prepare_data(input, target, maxlen=None, return_neg=False): | ||
# x: a list of sentences |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DeepRec中一般使用两个空格锁进,这些文件都稍微改一下
|
||
|
||
f.close() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
多余的空格删除
::tensorflow::eas::ArrayShape array_shape; | ||
::tensorflow::eas::ArrayDataType dtype_f = | ||
::tensorflow::eas::ArrayDataType::DT_FLOAT; | ||
int num_elem = (int)cur_vector.size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用(int)强转,会隐式转换的,或者定义为 size_t num_elem也可以
int num_elem = (int)cur_vector.size(); | ||
|
||
array_shape.add_dim(1); | ||
if((int)cur_vector.size() < 0){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if((int)cur_vector.size() < 0) ,cur_vector.size()不会小于0的,类型是size_t本身是正值。
|
||
return input; | ||
} | ||
array_shape.add_dim((int)cur_vector.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cur_vector.size()前面已经赋值了num_elem
@@ -415,6 +415,9 @@ def main(tf_config=None, server=None): | |||
|
|||
if tf_config: | |||
print('train steps : %d' % train_steps) | |||
print("-----------") | |||
print("-----------") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useless code
@@ -0,0 +1,58 @@ | |||
## How to use prepare_savedmodel.py to get savedmodel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件名叫README.md就可以
|
||
} | ||
|
||
::tensorflow::eas::ArrayProto get_proto_f(float char_input,int dim,::tensorflow::eas::ArrayDataType type){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数重复
|
||
while (record != NULL) { | ||
// only 1 label and 39 feature | ||
if (j >= 40) break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
struct input_format39 inputs; | ||
inputs.I1 = (float)(atof(all_elems[start_idx])); | ||
inputs.I2 = (float)(atof(all_elems[start_idx+1])); | ||
inputs.I3 = (float)(atof(all_elems[start_idx+2])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个搞一个循环吧
先删除掉不必要的修改。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
先把所有的文件整理一下
b52d097
to
576652e
Compare
## How to use prepare_savedmodel.py to get savedmodel | ||
|
||
- Current support model: \ | ||
BST, DBMTL, DeepFM, DIEN, DIN, DLRM, DSSM, ESMM, MMoE, SimpleMultiTask, WDL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now support DIEN, DeepFM and and WDL.
maxlen, | ||
data_location=data_location) | ||
|
||
f = open("./test_data.csv","w") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_data.csv 怎么得到的呢,可以在readme中写清楚
|
||
with tf.Session() as sess1: | ||
|
||
model = Model_DIN_V2_Gru_Vec_attGru_Neg(n_uid, n_mid, n_cat, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个Model_DIN_V2_Gru_Vec_attGru_Neg是哪里import的?
input.set_dtype(dtype_f); | ||
|
||
switch(dtype_f){ | ||
case 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
两个case中大部分代码都在重复的,可以只在
input.add_float_val((float)atof(cur_vector.back()));
input.add_int_val((int)atoi(cur_vector.back()));
代码加上if判断。
temp_ptrs.clear(); | ||
|
||
// traverse current line | ||
record = strtok(line, delim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下面这个可以搞成类似 split(...) 函数吗,现在这样写比较hack。
|
||
// input setting | ||
::tensorflow::eas::ArrayProto I1 = get_proto_cc(&inputs.I1_13[0],1,dtype_f); | ||
::tensorflow::eas::ArrayProto I2 = get_proto_cc(&inputs.I1_13[1],1,dtype_f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
搞一个数组存储?
int main(int argc, char** argv) { | ||
|
||
// PLEASE EDIT THIS LINE!!!! | ||
char filepath[] = "/home/deeprec/DeepRec/modelzoo/features/EmbeddingVariable/WDL/test.csv"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件如何生成的呢
For every model listed above, there is a prepare_savedmodel.py. To run this script please firstly ensure you have gotten the checkpoint file from training. To use prepare_savedmodel.py, please use: | ||
|
||
``` | ||
cd [modelfolder] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
training的文件位置需要写清楚,用户可能找不到具体的training的文件所在。
或者你把training文件copy到你的目录下也可以。
int main(int argc, char** argv) { | ||
|
||
// PLEASE EDIT THIS LINE!!!! | ||
char filepath[] = "/home/deeprec/DeepRec/modelzoo/features/EmbeddingVariable/WDL/test.csv"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件如何生成的呢
::tensorflow::eas::ArrayDataType::DT_INT64; | ||
|
||
// input setting | ||
::tensorflow::eas::ArrayProto I1 = get_proto_cc(&inputs.I1_9[0],1,dtype_f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
搞一个数组存储
请问下,这个代码还是没有合并嘛? 现在仍然只有训练,没有推理部署,文档也没说 |
DeepRec 本身具备推理的能力,这个PR只是这三个model 的推理例子,你可以参考 |
I have finished exporting savedmodel for models in modelzoo. Under Deeprec/modelzoo there is a Get_SavedModel.md for new user to learn how to export and where the result model will locate.