Skip to content

Commit 5d7a113

Browse files
TensorFlow Lattice Authorsmmilanifard
authored andcommitted
Internal change
PiperOrigin-RevId: 559208084 Change-Id: Ic9ee1d2f8243012daf2cf4d862037026d3c35d2b
1 parent a864edb commit 5d7a113

26 files changed

+2213
-964
lines changed

examples/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
load("//third_party/bazel_rules/rules_python/python:py_binary.bzl", "py_binary")
17+
1618
licenses(["notice"])
1719

1820
package(

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# This version number should always be that of the *next* (unreleased) version.
2828
# Immediately after uploading a package to PyPI, you should increment the
2929
# version number and push to gitHub.
30-
__version__ = "2.0.11"
30+
__version__ = "2.0.12"
3131

3232
if "--release" in sys.argv:
3333
sys.argv.remove("--release")
@@ -45,7 +45,6 @@
4545
"scikit-learn",
4646
"matplotlib",
4747
"graphviz",
48-
"dm-sonnet",
4948
]
5049

5150
# Part of the visualization code uses colabtools and IPython libraries. These

tensorflow_lattice/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ py_library(
2828
srcs = [
2929
"__init__.py",
3030
"layers/__init__.py",
31-
"sonnet_modules/__init__.py",
3231
],
3332
srcs_version = "PY2AND3",
3433
deps = [
3534
"//tensorflow_lattice/python:aggregation_layer",
3635
"//tensorflow_lattice/python:categorical_calibration_layer",
3736
"//tensorflow_lattice/python:categorical_calibration_lib",
3837
"//tensorflow_lattice/python:cdf_layer",
38+
"//tensorflow_lattice/python:conditional_cdf",
39+
"//tensorflow_lattice/python:conditional_pwl_calibration",
3940
"//tensorflow_lattice/python:configs",
4041
"//tensorflow_lattice/python:estimators",
4142
"//tensorflow_lattice/python:kronecker_factored_lattice_layer",
@@ -50,7 +51,6 @@ py_library(
5051
"//tensorflow_lattice/python:premade_lib",
5152
"//tensorflow_lattice/python:pwl_calibration_layer",
5253
"//tensorflow_lattice/python:pwl_calibration_lib",
53-
"//tensorflow_lattice/python:pwl_calibration_sonnet_module",
5454
"//tensorflow_lattice/python:rtl_layer",
5555
"//tensorflow_lattice/python:test_utils",
5656
"//tensorflow_lattice/python:utils",

tensorflow_lattice/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from __future__ import absolute_import
2121

2222
import tensorflow_lattice.layers
23-
2423
from tensorflow_lattice.python import aggregation_layer
2524
from tensorflow_lattice.python import categorical_calibration_layer
2625
from tensorflow_lattice.python import categorical_calibration_lib
2726
from tensorflow_lattice.python import cdf_layer
27+
from tensorflow_lattice.python import conditional_cdf
28+
from tensorflow_lattice.python import conditional_pwl_calibration
2829
from tensorflow_lattice.python import configs
2930
from tensorflow_lattice.python import estimators
3031
from tensorflow_lattice.python import kronecker_factored_lattice_layer
@@ -39,9 +40,6 @@
3940
from tensorflow_lattice.python import premade_lib
4041
from tensorflow_lattice.python import pwl_calibration_layer
4142
from tensorflow_lattice.python import pwl_calibration_lib
42-
from tensorflow_lattice.python import pwl_calibration_sonnet_module
4343
from tensorflow_lattice.python import test_utils
4444
from tensorflow_lattice.python import utils
4545
from tensorflow_lattice.python import visualization
46-
47-
import tensorflow_lattice.sonnet_modules

tensorflow_lattice/python/BUILD

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
load("//third_party/bazel_rules/rules_python/python:py_library.bzl", "py_library")
17+
load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test")
18+
1619
package(
1720
default_visibility = [
1821
"//tensorflow_lattice:__subpackages__",
@@ -434,19 +437,6 @@ py_library(
434437
],
435438
)
436439

437-
py_library(
438-
name = "pwl_calibration_sonnet_module",
439-
srcs = ["pwl_calibration_sonnet_module.py"],
440-
srcs_version = "PY2AND3",
441-
deps = [
442-
":pwl_calibration_lib",
443-
":utils",
444-
# absl/logging dep,
445-
# sonnet dep,
446-
# tensorflow:tensorflow_no_contrib dep,
447-
],
448-
)
449-
450440
py_test(
451441
name = "pwl_calibration_test",
452442
size = "large",
@@ -457,14 +447,13 @@ py_test(
457447
deps = [
458448
":parallel_combination_layer",
459449
":pwl_calibration_layer",
460-
":pwl_calibration_sonnet_module",
461450
":test_utils",
462451
":utils",
463452
# absl/logging dep,
464453
# absl/testing:parameterized dep,
465454
# numpy dep,
466-
# sonnet dep,
467455
# tensorflow dep,
456+
# tensorflow:tensorflow_no_contrib dep,
468457
],
469458
)
470459

@@ -514,8 +503,6 @@ py_library(
514503
":visualization",
515504
# absl/logging dep,
516505
# numpy dep,
517-
# sonnet dep,
518-
# tensorflow dep,
519506
],
520507
)
521508

@@ -528,6 +515,42 @@ py_library(
528515
],
529516
)
530517

518+
py_library(
519+
name = "conditional_pwl_calibration",
520+
srcs = ["conditional_pwl_calibration.py"],
521+
deps = [
522+
# numpy dep,
523+
# tensorflow:tensorflow_no_contrib dep,
524+
],
525+
)
526+
527+
py_library(
528+
name = "conditional_cdf",
529+
srcs = ["conditional_cdf.py"],
530+
deps = [
531+
# tensorflow:tensorflow_no_contrib dep,
532+
],
533+
)
534+
535+
py_test(
536+
name = "conditional_cdf_test",
537+
srcs = ["conditional_cdf_test.py"],
538+
deps = [
539+
":conditional_cdf",
540+
# absl/testing:parameterized dep,
541+
# tensorflow:tensorflow_no_contrib dep,
542+
],
543+
)
544+
545+
py_test(
546+
name = "conditional_pwl_calibration_test",
547+
srcs = ["conditional_pwl_calibration_test.py"],
548+
deps = [
549+
":conditional_pwl_calibration",
550+
# tensorflow:tensorflow_no_contrib dep,
551+
],
552+
)
553+
531554
py_test(
532555
name = "utils_test",
533556
srcs = ["utils_test.py"],

tensorflow_lattice/python/aggregation_layer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,14 @@ def call(self, x):
7676
def get_config(self):
7777
"""Standard Keras get_config() method."""
7878
config = super(Aggregation, self).get_config().copy()
79-
config.update({'model': tf.keras.utils.serialize_keras_object(self.model)})
79+
config.update(
80+
{'model': tf.keras.utils.legacy.serialize_keras_object(self.model)}
81+
)
8082
return config
8183

8284
@classmethod
8385
def from_config(cls, config, custom_objects=None):
84-
model = tf.keras.utils.deserialize_keras_object(
85-
config.pop('model'), custom_objects=custom_objects)
86+
model = tf.keras.utils.legacy.deserialize_keras_object(
87+
config.pop('model'), custom_objects=custom_objects
88+
)
8689
return cls(model, **config)

tensorflow_lattice/python/categorical_calibration_layer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,11 @@ def get_config(self):
249249
"output_max": self.output_max,
250250
"monotonicities": self.monotonicities,
251251
"kernel_initializer":
252-
keras.initializers.serialize(self.kernel_initializer),
252+
keras.initializers.serialize(
253+
self.kernel_initializer, use_legacy_format=True),
253254
"kernel_regularizer":
254-
[keras.regularizers.serialize(r) for r in self.kernel_regularizer],
255+
[keras.regularizers.serialize(r, use_legacy_format=True)
256+
for r in self.kernel_regularizer],
255257
"default_input_value": self.default_input_value,
256258
"split_outputs": self.split_outputs,
257259
} # pyformat: disable

tensorflow_lattice/python/cdf_layer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def get_config(self):
255255
"sparsity_factor":
256256
self.sparsity_factor,
257257
"kernel_initializer":
258-
tf.keras.initializers.serialize(self.kernel_initializer),
258+
tf.keras.initializers.serialize(
259+
self.kernel_initializer, use_legacy_format=True),
259260
}
260261
config.update(super(CDF, self).get_config())
261262
return config

0 commit comments

Comments
 (0)