How to load saved SDM weights properly to reproduce embeddings? #29

Open
qiong-sportsbet opened this issue Jun 18, 2020 · 7 comments
Labels
question Further information is requested

Comments

@qiong-sportsbet

Describe the question
After I saved the SDM weights and loaded them in another process, the model produced different user embeddings.

How do I save the SDM model and then load it so that the same embeddings are reproduced?

Operating environment:

  • python version [e.g. 3.7.3]
  • tensorflow version [e.g. 2.2.0,]
  • deepmatch version [GPU e.g. 0.1.3,]
qiong-sportsbet added the question label Jun 18, 2020
@wangzhegeek
Collaborator

Refer to examples/run_sdm.py.

@qiong-sportsbet
Author

qiong-sportsbet commented Jun 18, 2020

@wangzhegeek Thanks for your quick reply. In my case, I trained the model in one notebook and obtained the user and item embeddings. I then loaded the model in another notebook and got the same item embeddings but completely different user embeddings. I suspect that some hidden state of the RNN layer is lost, based on my reading of this thread:
keras-team/keras#11335
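
To check whether the reloaded embeddings really differ beyond floating-point noise, one option is to dump both sets and compare them elementwise. The following is a minimal sketch in plain Python, assuming the embeddings have been converted to nested lists (for example via `user_embs.tolist()`). The helper names `max_abs_diff` and `rows_that_differ` are hypothetical, the tolerance is an arbitrary choice, and the numbers are made up for illustration.

```python
def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equally shaped 2-D lists."""
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        raise ValueError("embedding matrices must have the same shape")
    return max(
        (abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)),
        default=0.0,
    )


def rows_that_differ(a, b, atol=1e-5):
    """Indices of rows (users or items) whose embeddings differ by more than atol."""
    return [
        i for i, (ra, rb) in enumerate(zip(a, b))
        if any(abs(x - y) > atol for x, y in zip(ra, rb))
    ]


# Made-up numbers purely for illustration.
before = [[0.10, 0.20], [0.30, 0.40]]  # embeddings from the training process
after = [[0.10, 0.20], [0.35, 0.40]]   # embeddings from the reloaded model

print(max_abs_diff(before, after))
print(rows_that_differ(before, after))  # [1]
```

Running the same comparison on the item embeddings, which are reported to match, gives a baseline for how much numerical noise to expect.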

@qiong-sportsbet
Author

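I first trained the model and saved it:

# Imports assumed here, following examples/run_sdm.py
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from deepmatch.models import SDM
from deepmatch.utils import sampledsoftmaxloss

K.set_learning_phase(True)

if tf.__version__ >= '2.0.0':
    tf.compat.v1.disable_eager_execution()

# user_feature_columns, item_feature_columns, embedding_dim and the
# train/test inputs are defined earlier in the notebook (not shown).
model = SDM(user_feature_columns,
            item_feature_columns,
            history_feature_list=['events', 'event_types', 'event_classes', 'event_categories'],
            units=embedding_dim,
            num_sampled=100, )

optimizer = optimizers.Adam(lr=0.001, clipnorm=5.0)

model.compile(optimizer=optimizer, loss=sampledsoftmaxloss)  # "binary_crossentropy")

history = model.fit(train_model_input,
                    train_label,
                    batch_size=512,
                    epochs=1,
                    verbose=1,
                    validation_data=(test_model_input, test_label),
                    )

K.set_learning_phase(False)

model.save('/tmp/saved_model.h5')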

Then I loaded the model:
from tensorflow.python.keras.models import load_model
from deepmatch.layers import custom_objects

loaded_model = load_model('/tmp/saved_model.h5', custom_objects)

Then I used the loaded model to get the embeddings:
from tensorflow.python.keras.models import Model

test_user_model_input = test_model_input
all_item_model_input = {"event_idx": betting_event_mapping['betting_event_idx'].values, }

user_embedding_model = Model(inputs=loaded_model.user_input, outputs=loaded_model.user_embedding)
item_embedding_model = Model(inputs=loaded_model.item_input, outputs=loaded_model.item_embedding)

user_embs = user_embedding_model.predict(test_user_model_input, batch_size=2 ** 12)
item_embs = item_embedding_model.predict(all_item_model_input, batch_size=2 ** 12)

print(user_embs.shape)
print(item_embs.shape)

This raised an AttributeError:
AttributeError: 'Model' object has no attribute 'user_input'
AttributeError Traceback (most recent call last)
in
3 all_item_model_input = {"event_idx": betting_event_mapping['betting_event_idx'].values, }
4
----> 5 user_embedding_model = Model(inputs=loaded_model.user_input, outputs=loaded_model.user_embedding)
6 item_embedding_model = Model(inputs=loaded_model.item_input, outputs=loaded_model.item_embedding)
7

AttributeError: 'Model' object has no attribute 'user_input'
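
One plausible reading of this error, offered as a sketch rather than a confirmed diagnosis: `load_model` rebuilds a generic Keras `Model` from the saved architecture and weights, so attributes attached to the Python object after construction (here presumably `user_input`, `user_embedding`, `item_input` and `item_embedding`) are not part of what gets saved. The toy example below uses no Keras at all. `ToyModel`, `build_toy_sdm`, `save` and `load` are hypothetical names that only mimic the pattern.

```python
import json


class ToyModel:
    """Stand-in for a Keras Model: only `config` and `weights` are serialized."""

    def __init__(self, config, weights=None):
        self.config = config
        self.weights = weights if weights is not None else [0.0] * config["units"]


def build_toy_sdm(units):
    """Hypothetical builder that attaches extra attributes, as SDM(...) appears to do."""
    model = ToyModel({"units": units})
    # Extra attribute set on the Python object only; it is not part of config.
    model.user_input = ["user_id", "hist_events"]
    return model


def save(model):
    """Serialize config and weights only (loosely analogous to model.save)."""
    return json.dumps({"config": model.config, "weights": model.weights})


def load(blob):
    """Rebuild a plain ToyModel from the serialized data (loosely analogous to load_model)."""
    data = json.loads(blob)
    return ToyModel(data["config"], data["weights"])


trained = build_toy_sdm(units=4)
trained.weights = [0.1, 0.2, 0.3, 0.4]
blob = save(trained)

# Route 1: generic load. The weights survive, the extra attribute does not.
loaded = load(blob)
print(hasattr(loaded, "user_input"))  # False

# Route 2: rebuild with the original builder, then copy the saved weights in.
rebuilt = build_toy_sdm(units=4)
rebuilt.weights = json.loads(blob)["weights"]
print(rebuilt.user_input, rebuilt.weights == trained.weights)
```

In Keras terms, route 2 would correspond to calling the same `SDM(...)` constructor again and then `load_weights` on the result. That should bring back the `user_input` and `user_embedding` attributes, but whether it also reproduces the user embeddings is not established in this thread.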

@nasWang

nasWang commented Jul 17, 2020

I have the same problem. Has it been solved?

@YuRong-Lin

The problem also occurs when using save_weights and load_weights.

@YuRong-Lin

Describe the question
After I saved the SDM weights and loaded them in another process, the model produced different user embeddings.

How do I save the SDM model and then load it so that the same embeddings are reproduced?

Operating environment:

  • python version [e.g. 3.7.3]
  • tensorflow version [e.g. 2.2.0,]
  • deepmatch version [GPU e.g. 0.1.3,]

Have you found a solution to this problem yourself?

@HighingLIN

This problem still exists and remains unsolved to this day.

5 participants