|
| 1 | +"""Unit tests for the _convenience module.""" |
| 2 | + |
| 3 | +import onnx |
| 4 | + |
| 5 | +import unittest |
| 6 | + |
| 7 | +import onnx_ir as ir |
| 8 | +from onnx_ir._convenience import insert_nodes_in_value |
| 9 | + |
| 10 | + |
| 11 | +def _create_model(model_text: str) -> ir.Model: |
| 12 | + model = onnx.parser.parse_model(model_text) |
| 13 | + return ir.serde.deserialize_model(model) |
| 14 | + |
| 15 | + |
| 16 | +class ConvenienceTest(unittest.TestCase): |
| 17 | + def test_insert_nodes_in_value(self): |
| 18 | + # Main graph |
| 19 | + input = ir.Input("input") |
| 20 | + node_A = ir.node("op_A", [input]) |
| 21 | + node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")]) |
| 22 | + node_C = ir.node("op_C", node_B.outputs) |
| 23 | + |
| 24 | + # New sequence to insert |
| 25 | + input_2 = ir.Input("input_2") |
| 26 | + node_M = ir.node("op_M", [input_2]) |
| 27 | + node_N = ir.node("op_N", node_M.outputs) |
| 28 | + |
| 29 | + # Insert nodes in B |
| 30 | + insert_nodes_in_value(node_B.outputs[0], [node_M, node_N]) |
| 31 | + self.assertEqual(len(node_B.outputs), 1) |
| 32 | + self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M") |
| 33 | + self.assertEqual(len(node_C.inputs), 1) |
| 34 | + self.assertEqual(node_C.inputs[0].producer().op_type, "op_N") |
| 35 | + self.assertEqual(node_C.inputs[0].name, "B") |
| 36 | + |
| 37 | + def test_insert_nodes_in_value_in_graph(self): |
| 38 | + ir_model = _create_model( |
| 39 | + """ |
| 40 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 41 | + agraph (float[N] x) => (float[N] z) { |
| 42 | + two = Constant<value_float=2.0>() |
| 43 | + a, b = SplitNode(x) |
| 44 | + z = MergeNode(a, b, two) |
| 45 | + } |
| 46 | + """ |
| 47 | + ) |
| 48 | + |
| 49 | + # Sequence to insert. |
| 50 | + # Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]]. |
| 51 | + i1, i2 = ir.Input("i1"), ir.Input("i2") |
| 52 | + a = ir.node("op_1", [i1, i2]) |
| 53 | + b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2) |
| 54 | + c = ir.node("op_3", [i2, b.outputs[0]]) |
| 55 | + |
| 56 | + # Insert nodes in SplitNode.outputs |
| 57 | + target_node = ir_model.graph[1] |
| 58 | + insert_nodes_in_value(target_node.outputs, [a, b, c]) |
| 59 | + |
| 60 | + # Check target_node outputs have been renamed |
| 61 | + new_i1, new_i2 = target_node.outputs |
| 62 | + self.assertEqual(new_i1.name, "i1") |
| 63 | + self.assertEqual(new_i2.name, "i2") |
| 64 | + |
| 65 | + # Check i1 and i2 have new users |
| 66 | + self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")) |
| 67 | + self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")) |
| 68 | + |
| 69 | + # Check outputs have been correctly renamed as previous values |
| 70 | + self.assertEqual(b.outputs[1].name, "a") |
| 71 | + self.assertEqual(c.outputs[0].name, "b") |
| 72 | + |
| 73 | + # Check nodes have been inserted in the graph |
| 74 | + self.assertEqual(len(ir_model.graph), 6) |
| 75 | + |
| 76 | + def test_insert_nodes_in_input(self): |
| 77 | + ir_model = _create_model( |
| 78 | + """ |
| 79 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 80 | + agraph (float[N] x) => (float[N] z) { |
| 81 | + two = Constant<value_float=2.0>() |
| 82 | + z = Add(x, two) |
| 83 | + } |
| 84 | + """ |
| 85 | + ) |
| 86 | + |
| 87 | + # Sequence to insert. |
| 88 | + x = ir.Input("new_x") |
| 89 | + node = ir.node("Mul", [x, x]) |
| 90 | + |
| 91 | + # Insert nodes in graph.inputs |
| 92 | + insert_nodes_in_value(ir_model.graph[1].inputs[0], [node]) |
| 93 | + self.assertEqual(node.outputs[0].name, "x") |
| 94 | + |
| 95 | + # Check input has been renamed |
| 96 | + self.assertEqual(ir_model.graph.inputs[0].name, "new_x") |
| 97 | + |
| 98 | + # Finally, check new graph is valid |
| 99 | + proto = ir.to_proto(ir_model) |
| 100 | + onnx.checker.check_model(proto, full_check=True) |
| 101 | + |
| 102 | + def test_insert_nodes_in_output(self): |
| 103 | + ir_model = _create_model( |
| 104 | + """ |
| 105 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 106 | + agraph (float[N] x) => (float[N] z) { |
| 107 | + two = Constant<value_float=2.0>() |
| 108 | + z = Add(x, two) |
| 109 | + } |
| 110 | + """ |
| 111 | + ) |
| 112 | + |
| 113 | + # Sequence to insert. |
| 114 | + x = ir.Input("new_z") |
| 115 | + node = ir.node("Mul", [x, x]) |
| 116 | + |
| 117 | + # Insert nodes in graph.inputs |
| 118 | + insert_nodes_in_value(ir_model.graph.outputs[0], [node]) |
| 119 | + self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z") |
| 120 | + |
| 121 | + # Check output name is preserved |
| 122 | + self.assertEqual(ir_model.graph.outputs[0].name, "z") |
| 123 | + |
| 124 | + def test_value_error_for_wrong_number_of_points(self): |
| 125 | + ir_model = _create_model( |
| 126 | + """ |
| 127 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 128 | + agraph (float[N] x) => (float[N] z) { |
| 129 | + two = Constant<value_float=2.0>() |
| 130 | + a, b = SplitNode(x) |
| 131 | + z = MergeNode(a, b, two) |
| 132 | + } |
| 133 | + """ |
| 134 | + ) |
| 135 | + node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")]) |
| 136 | + with self.assertRaisesRegex(ValueError, "The number of values and inputs"): |
| 137 | + insert_nodes_in_value(ir_model.graph[0].outputs, [node]) |
| 138 | + |
| 139 | + with self.assertRaisesRegex(ValueError, "The number of values and outputs"): |
| 140 | + insert_nodes_in_value(ir_model.graph[1].outputs, [node]) |
| 141 | + |
| 142 | + |
| 143 | +if __name__ == "__main__": |
| 144 | + unittest.main() |
0 commit comments