@@ -138,7 +138,6 @@ class InitializeKvVariableOp : public OpKernel {
138
138
&false_positive_probability_));
139
139
OP_REQUIRES_OK (c, c->GetAttr (" l2_weight_threshold" ,
140
140
&l2_weight_threshold_));
141
- OP_REQUIRES_OK (c, c->GetAttr (" layout" , &layout_));
142
141
OP_REQUIRES_OK (c, c->GetAttr (" default_value_dim" , &default_value_dim_));
143
142
OP_REQUIRES_OK (c, c->GetAttr (" slot_num" , &slot_num_));
144
143
OP_REQUIRES_OK (c, c->GetAttr (" record_freq" , &record_freq_));
@@ -156,6 +155,20 @@ class InitializeKvVariableOp : public OpKernel {
156
155
filter_freq_ = 0 ;
157
156
}
158
157
158
+ if ((filter_freq_ != 0 && max_element_size_ == 0 )
159
+ || steps_to_live_ != 0 || record_freq_
160
+ || record_version_ || storage_type > 5 ) {
161
+ if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5 )) {
162
+ layout_ = " normal" ;
163
+ } else {
164
+ layout_ = " normal_contiguous" ;
165
+ }
166
+ } else {
167
+ layout_ = " light" ;
168
+ }
169
+
170
+ CHECK (block_num_ == 1 || layout_ != " normal_contiguous" );
171
+
159
172
if (steps_to_live_ == kEmbeddingVarUseDB ||
160
173
steps_to_live_ == kInitializableEmbeddingVarUseDB ) {
161
174
LOG (INFO) << " hashmap use db" ;
@@ -199,7 +212,6 @@ class InitializeKvVariableOp : public OpKernel {
199
212
std::string opname = handle_self.name ();
200
213
201
214
EmbeddingVar<TKey, TValue>* ev = nullptr ;
202
- CHECK (block_num_ == 1 || layout_ != " normal_contiguous" );
203
215
204
216
if (handle_self.name () == handle_primary.name () &&
205
217
handle_self.container () == handle_primary.container ()) {
@@ -642,7 +654,6 @@ class KvResourceImportV2Op: public AsyncOpKernel {
642
654
&false_positive_probability_));
643
655
OP_REQUIRES_OK (c, c->GetAttr (" l2_weight_threshold" ,
644
656
&l2_weight_threshold_));
645
- OP_REQUIRES_OK (c, c->GetAttr (" layout" , &layout_));
646
657
OP_REQUIRES_OK (c, c->GetAttr (" max_freq" , &max_freq_));
647
658
OP_REQUIRES_OK (c, c->GetAttr (" default_value_dim" ,
648
659
&default_value_dim_));
@@ -656,6 +667,20 @@ class KvResourceImportV2Op: public AsyncOpKernel {
656
667
OP_REQUIRES_OK (c, c->GetAttr (" record_freq" , &record_freq_));
657
668
OP_REQUIRES_OK (c, c->GetAttr (" record_version" , &record_version_));
658
669
670
+ if ((filter_freq_ != 0 && max_element_size_ == 0 )
671
+ || steps_to_live_ != -1 || record_freq_
672
+ || record_version_ || storage_type > 5 ) {
673
+ if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5 )) {
674
+ layout_ = " normal" ;
675
+ } else {
676
+ layout_ = " normal_contiguous" ;
677
+ }
678
+ } else {
679
+ layout_ = " light" ;
680
+ }
681
+
682
+ CHECK (block_num_ == 1 || layout_ != " normal_contiguous" );
683
+
659
684
TF_CHECK_OK (ReadBoolFromEnvVar (" TF_ENABLE_EV_ASYNC_RESTORE" , true ,
660
685
&ev_async_restore_));
661
686
}
0 commit comments