Skip to content
1 change: 1 addition & 0 deletions scripts/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_imports():

# Math ops.
_ = tfq.math.inner_product
_ = tfq.math.inner_product_hessian

# Noisy simulation ops.
_ = tfq.noise.expectation
Expand Down
12 changes: 12 additions & 0 deletions tensorflow_quantum/core/ops/math_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cc_binary(
srcs = [
"tfq_inner_product.cc",
"tfq_inner_product_grad.cc",
"tfq_inner_product_hessian.cc",
],
copts = select({
":windows": [
Expand Down Expand Up @@ -62,6 +63,7 @@ cc_binary(
# cirq cc proto
"//tensorflow_quantum/core/ops:parse_context",
"//tensorflow_quantum/core/ops:tfq_simulate_utils",
"//tensorflow_quantum/core/src:adj_hessian_util",
"//tensorflow_quantum/core/src:adj_util",
"//tensorflow_quantum/core/src:circuit_parser_qsim",
"//tensorflow_quantum/core/src:util_qsim",
Expand Down Expand Up @@ -100,3 +102,13 @@ py_test(
"//tensorflow_quantum/python:util",
],
)

py_test(
name = "inner_product_hessian_test",
srcs = ["inner_product_hessian_test.py"],
python_version = "PY3",
deps = [
":inner_product_op_py",
"//tensorflow_quantum/python:util",
],
)
3 changes: 2 additions & 1 deletion tensorflow_quantum/core/ops/math_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
# ==============================================================================
"""Module for tfq.core.ops.math_ops.*"""

from tensorflow_quantum.core.ops.math_ops.inner_product_op import inner_product
from tensorflow_quantum.core.ops.math_ops.inner_product_op import (
inner_product, inner_product_hessian)
10 changes: 5 additions & 5 deletions tensorflow_quantum/core/ops/math_ops/inner_product_grad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class InnerProductAdjGradTest(tf.test.TestCase, parameterized.TestCase):
"""Tests tfq_inner_product_grad."""

def test_inner_product_grad_inputs(self):
"""Makes sure that inner_product_adj_grad fails on bad inputs."""
"""Makes sure that inner_product_grad fails on bad inputs."""
n_qubits = 5
batch_size = 5
n_other_programs = 3
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_inner_product_grad_inputs(self):
])
def test_correctness_with_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Tests that inner_product works with symbols."""
"""Tests that inner_product_grad works with symbols."""
symbol_names = ['alpha', 'beta', 'gamma']
n_params = len(symbol_names)
qubits = cirq.GridQubit.rect(1, n_qubits)
Expand All @@ -242,7 +242,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,

other_batch = [
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
for i in range(batch_size)
for _ in range(batch_size)
]

symbol_values_array = np.array(
Expand Down Expand Up @@ -312,15 +312,15 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
])
def test_correctness_without_symbols(self, n_qubits, batch_size,
inner_dim_size):
"""Tests that inner_product_adj_grad works without symbols."""
"""Tests that inner_product_grad works without symbols."""
qubits = cirq.GridQubit.rect(1, n_qubits)
circuit_batch, _ = \
util.random_circuit_resolver_batch(
qubits, batch_size)

other_batch = [
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
for i in range(batch_size)
for _ in range(batch_size)
]

programs = util.convert_to_tensor(circuit_batch)
Expand Down
Loading