Skip to content

Commit cec4417

Browse files
authored
[Embedding] Initialize the layout of ValuePtr in InitializeOp and ImportOp. (#405)
1 parent ba3da62 commit cec4417

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class InitializeKvVariableOp : public OpKernel {
138138
&false_positive_probability_));
139139
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
140140
&l2_weight_threshold_));
141-
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
142141
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim", &default_value_dim_));
143142
OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_));
144143
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
@@ -156,6 +155,20 @@ class InitializeKvVariableOp : public OpKernel {
156155
filter_freq_ = 0;
157156
}
158157

158+
if ((filter_freq_ != 0 && max_element_size_ == 0)
159+
|| steps_to_live_ != 0 || record_freq_
160+
|| record_version_ || storage_type > 5) {
161+
if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5)) {
162+
layout_ = "normal";
163+
} else {
164+
layout_ = "normal_contiguous";
165+
}
166+
} else {
167+
layout_ = "light";
168+
}
169+
170+
CHECK(block_num_ == 1 || layout_ != "normal_contiguous");
171+
159172
if (steps_to_live_ == kEmbeddingVarUseDB ||
160173
steps_to_live_ == kInitializableEmbeddingVarUseDB) {
161174
LOG(INFO) << "hashmap use db";
@@ -199,7 +212,6 @@ class InitializeKvVariableOp : public OpKernel {
199212
std::string opname = handle_self.name();
200213

201214
EmbeddingVar<TKey, TValue>* ev = nullptr;
202-
CHECK(block_num_ == 1 || layout_ != "normal_contiguous");
203215

204216
if (handle_self.name() == handle_primary.name() &&
205217
handle_self.container() == handle_primary.container()) {
@@ -642,7 +654,6 @@ class KvResourceImportV2Op: public AsyncOpKernel {
642654
&false_positive_probability_));
643655
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
644656
&l2_weight_threshold_));
645-
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
646657
OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_));
647658
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim",
648659
&default_value_dim_));
@@ -656,6 +667,20 @@ class KvResourceImportV2Op: public AsyncOpKernel {
656667
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
657668
OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_));
658669

670+
if ((filter_freq_ != 0 && max_element_size_ == 0)
671+
|| steps_to_live_ != -1 || record_freq_
672+
|| record_version_ || storage_type > 5) {
673+
if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5)) {
674+
layout_ = "normal";
675+
} else {
676+
layout_ = "normal_contiguous";
677+
}
678+
} else {
679+
layout_ = "light";
680+
}
681+
682+
CHECK(block_num_ == 1 || layout_ != "normal_contiguous");
683+
659684
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EV_ASYNC_RESTORE", true,
660685
&ev_async_restore_));
661686
}

tensorflow/python/ops/kv_variable_ops.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,6 @@ def _init_from_args(self,
293293
self._storage_path = evconfig.storage_path
294294
self._storage_size = evconfig.storage_size
295295
self._default_value_dim = evconfig.default_value_dim
296-
if (isinstance(evconfig.filter_strategy, variables.CounterFilter) and self._filter_freq != 0) or \
297-
self._steps_to_live not in [0, None] or self._record_version or \
298-
self._storage_type in multi_level_list or self._record_freq:
299-
if self._block_num not in [1, None] and self._storage_type in multi_level_list:
300-
raise ValueError("Dynamic-dimension Embedding and Multi-level EV can't be enabled together")
301-
if self._block_num not in [1, None] or \
302-
(self._filter_freq != 0 and self._storage_type not in multi_level_list):
303-
self._layout = "normal"
304-
else:
305-
self._layout = "normal_contiguous"
306-
else:
307-
self._layout = "light"
308296

309297
if self._primary is None:
310298
self._is_primary = True
@@ -409,7 +397,7 @@ def _init_from_args(self,
409397
false_positive_probability = self._false_positive_probability,
410398
counter_type = self._counter_type,
411399
max_freq = 99999,
412-
layout = self._layout,
400+
layout = "",
413401
storage_type = self._storage_type,
414402
storage_path = self._storage_path,
415403
storage_size = self._storage_size,

tensorflow/python/training/saving/saveable_object_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def restore(self, restored_tensors, unused_restored_shapes):
208208
max_element_size = self.var._max_element_size,
209209
false_positive_probability = self.var._false_positive_probability,
210210
counter_type = self.var._counter_type,
211-
layout = self.var._layout,
211+
layout = "",
212212
storage_type=self.var._storage_type,
213213
storage_path=self.var._storage_path,
214214
storage_size=self.var._storage_size,

tensorflow/python/training/slot_creator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
123123
primary=primary._primary,
124124
slot_num=slot_config.slot_num,
125125
storage_type=primary.storage_type,
126+
storage_path=primary._storage_path,
127+
storage_size=primary._storage_size,
126128
l2_weight_threshold=primary._l2_weight_threshold,
127129
filter_strategy=filter_strategy)
128130
)

0 commit comments

Comments
 (0)