Skip to content

Commit 0f551e6

Browse files
committed
Introduce insert_nodes_before_value (#62)
Convenience function to insert a set of nodes in value(s).
1 parent e13a398 commit 0f551e6

File tree

3 files changed

+252
-0
lines changed

3 files changed

+252
-0
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"replace_all_uses_with",
1515
"create_value_mapping",
1616
"replace_nodes_and_values",
17+
"insert_nodes_in_value",
1718
]
1819

1920
from collections.abc import Mapping, Sequence
@@ -385,3 +386,108 @@ def replace_nodes_and_values(
385386
# insert new nodes after the index node
386387
graph_or_function.insert_after(insertion_point, new_nodes)
387388
graph_or_function.remove(old_nodes, safe=True)
389+
390+
391+
def _find_inputs_outputs(
392+
nodes: Sequence[_core.Node],
393+
) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]:
394+
"""Find the values that are considered as inputs and outputs in a sequence of nodes."""
395+
# Search the unique inputs/outputs in new_nodes, keeping the order.
396+
all_inputs = dict.fromkeys(sum((node.inputs for node in nodes), ()))
397+
all_outputs = dict.fromkeys(sum((node.outputs for node in nodes), ()))
398+
# A value is considered as input if it is not any output.
399+
inputs = tuple(val for val in all_inputs if val not in all_outputs)
400+
# A value is considered as output if it is not any input.
401+
outputs = tuple(val for val in all_outputs if val not in all_inputs)
402+
return inputs, outputs
403+
404+
405+
def insert_nodes_in_value(
    values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node]
) -> None:
    """Insert a sequence of nodes into the provided value(s).

    This allows inserting a list of LINKED nodes (over the same context) at
    a specific point in the graph.

    For example, suppose we have the following graph::

        input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output

    We want to insert [node_M, node_N] at B value::

        >>> import onnx_ir as ir
        >>> graph_input = ir.Input("input")
        >>> node_A = ir.node("op_A", [graph_input])
        >>> B = ir.Value(name="B")
        >>> node_B = ir.node("op_B", node_A.outputs, outputs=[B])
        >>> node_C = ir.node("op_C", node_B.outputs)
        >>> # Create a new sequence to insert
        >>> input_2 = ir.Input("input_2")
        >>> node_M = ir.node("op_M", [input_2])
        >>> node_N = ir.node("op_N", node_M.outputs)
        >>> # Insert nodes in B
        >>> insert_nodes_in_value(node_B.outputs, [node_M, node_N])
        >>> len(node_B.outputs)
        1
        >>> node_B.outputs[0].consumers()[0].op_type
        'op_M'
        >>> len(node_C.inputs)
        1
        >>> node_C.inputs[0].producer().op_type
        'op_N'
        >>> node_C.inputs[0].name
        'B'

    When values is a sequence, the set of nodes must have the same number
    of inputs and outputs, then they are zipped into pairs: first value is
    replaced with the first input/output, and so on.

    Each original value takes the name of the matching input of ``new_nodes``,
    while the matching output of ``new_nodes`` takes over the original value's
    name, type and shape.

    Args:
        values: The value(s) where to insert the nodes.
        new_nodes: The nodes to insert in the graph.

    Raises:
        ValueError: If the number of values differs from the number of inputs
            or from the number of outputs found in ``new_nodes``.
    """
    if not isinstance(values, Sequence):
        values = (values,)

    # Search the unique inputs/outputs in new_nodes, keeping the order.
    inputs, outputs = _find_inputs_outputs(new_nodes)

    # Sanity check.
    if len(values) != len(inputs):
        raise ValueError(
            f"The number of values and inputs ({inputs}) in new_nodes must match."
        )
    if len(values) != len(outputs):
        raise ValueError(
            f"The number of values and outputs ({outputs}) in new_nodes must match."
        )

    # Propagate relevant info.
    for val, in_val, out_val in zip(values, inputs, outputs):
        # Propagate relevant info from value to out_value.
        # TODO(Rama): Perhaps this should be a separate utility function.
        out_val.type = val.type
        out_val.shape = val.shape
        out_val.name = val.name
        # Propagate relevant info from value to in_value.
        # TODO(Rama): Perhaps this should be a separate utility function.
        in_val.type = val.type
        in_val.shape = val.shape
        # Rename each value, following each input.
        val.name = in_val.name

    # Insert the new nodes in two steps (the order matters: step 2 would
    # otherwise be undone by step 1):
    # 1. Reconnect the users of values to the outputs
    replace_all_uses_with(values, outputs)
    # 2. Reconnect the users of inputs to values
    replace_all_uses_with(inputs, values)

    # Update graph if there is one. The graph is looked up on the last value;
    # the emptiness guard avoids an IndexError when there is nothing to insert.
    if values and (graph := values[-1].graph) is not None:
        # Update graph/function outputs if the node generates output
        _update_graph_or_function_outputs(graph, values, outputs)

        # Insert new nodes if there is a graph; they are appended and the
        # graph is then re-sorted so they end up in a valid position.
        graph.extend(new_nodes)
        graph.sort()
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Unit tests for the _convenience module."""
2+
3+
import onnx
4+
5+
import unittest
6+
7+
import onnx_ir as ir
8+
from onnx_ir._convenience import insert_nodes_in_value
9+
10+
11+
def _create_model(model_text: str) -> ir.Model:
    """Build an IR model from a model written in ONNX textual syntax."""
    return ir.serde.deserialize_model(onnx.parser.parse_model(model_text))
14+
15+
16+
class ConvenienceTest(unittest.TestCase):
    """Tests for ``insert_nodes_in_value``."""

    def test_insert_nodes_in_value(self):
        """Nodes can be spliced into a value that does not belong to a graph."""
        # Main graph
        graph_input = ir.Input("input")
        node_A = ir.node("op_A", [graph_input])
        node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")])
        node_C = ir.node("op_C", node_B.outputs)

        # New sequence to insert
        input_2 = ir.Input("input_2")
        node_M = ir.node("op_M", [input_2])
        node_N = ir.node("op_N", node_M.outputs)

        # Insert nodes in B
        insert_nodes_in_value(node_B.outputs[0], [node_M, node_N])
        self.assertEqual(len(node_B.outputs), 1)
        self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M")
        self.assertEqual(len(node_C.inputs), 1)
        self.assertEqual(node_C.inputs[0].producer().op_type, "op_N")
        self.assertEqual(node_C.inputs[0].name, "B")

    def test_insert_nodes_in_value_in_graph(self):
        """Several values of a graph node are replaced pairwise, in order."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                a, b = SplitNode(x)
                z = MergeNode(a, b, two)
            }
            """
        )

        # Sequence to insert.
        # Its free inputs are [i1, i2] and its free outputs are
        # [b.outputs[1], c.outputs[0]].
        i1, i2 = ir.Input("i1"), ir.Input("i2")
        a = ir.node("op_1", [i1, i2])
        b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2)
        c = ir.node("op_3", [i2, b.outputs[0]])

        # Insert nodes in SplitNode.outputs
        target_node = ir_model.graph[1]
        insert_nodes_in_value(target_node.outputs, [a, b, c])

        # Check target_node outputs have been renamed
        new_i1, new_i2 = target_node.outputs
        self.assertEqual(new_i1.name, "i1")
        self.assertEqual(new_i2.name, "i2")

        # Check i1 and i2 have new users
        self.assertEqual(
            tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")
        )
        self.assertEqual(
            tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")
        )

        # Check outputs have been correctly renamed as previous values
        self.assertEqual(b.outputs[1].name, "a")
        self.assertEqual(c.outputs[0].name, "b")

        # Check nodes have been inserted in the graph
        self.assertEqual(len(ir_model.graph), 6)

    def test_insert_nodes_in_input(self):
        """Inserting at a graph input renames the input and keeps the model valid."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                z = Add(x, two)
            }
            """
        )

        # Sequence to insert.
        x = ir.Input("new_x")
        node = ir.node("Mul", [x, x])

        # Insert nodes in graph.inputs
        insert_nodes_in_value(ir_model.graph[1].inputs[0], [node])
        self.assertEqual(node.outputs[0].name, "x")

        # Check input has been renamed
        self.assertEqual(ir_model.graph.inputs[0].name, "new_x")

        # Finally, check new graph is valid
        proto = ir.to_proto(ir_model)
        onnx.checker.check_model(proto, full_check=True)

    def test_insert_nodes_in_output(self):
        """Inserting at a graph output preserves the graph output name."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                z = Add(x, two)
            }
            """
        )

        # Sequence to insert.
        x = ir.Input("new_z")
        node = ir.node("Mul", [x, x])

        # Insert nodes in graph.outputs
        insert_nodes_in_value(ir_model.graph.outputs[0], [node])
        self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z")

        # Check output name is preserved
        self.assertEqual(ir_model.graph.outputs[0].name, "z")

    def test_value_error_for_wrong_number_of_points(self):
        """A mismatch between values and free inputs/outputs raises ValueError."""
        ir_model = _create_model(
            """
            <ir_version: 10, opset_import: [ "" : 17]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant<value_float=2.0>()
                a, b = SplitNode(x)
                z = MergeNode(a, b, two)
            }
            """
        )
        # op_M has two free inputs and one free output.
        node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")])
        # Constant has a single output: one value vs. two inputs.
        with self.assertRaisesRegex(ValueError, "The number of values and inputs"):
            insert_nodes_in_value(ir_model.graph[0].outputs, [node])

        # SplitNode has two outputs: two values vs. one output.
        with self.assertRaisesRegex(ValueError, "The number of values and outputs"):
            insert_nodes_in_value(ir_model.graph[1].outputs, [node])
141+
142+
143+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

src/onnx_ir/convenience.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"replace_all_uses_with",
1111
"replace_nodes_and_values",
1212
"create_value_mapping",
13+
"insert_nodes_in_value",
1314
]
1415

1516
from onnx_ir._convenience import (
@@ -18,6 +19,7 @@
1819
create_value_mapping,
1920
replace_all_uses_with,
2021
replace_nodes_and_values,
22+
insert_nodes_in_value,
2123
)
2224

2325
# NOTE: Do not implement any other functions in this module.

0 commit comments

Comments
 (0)