Skip to content

Commit e57a2d7

Browse files
feat(diff): get_diff function returning human-readable diffs (#137)
1 parent 2b3fd44 commit e57a2d7

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

sql_compare/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ def compare(first_sql: str, second_sql: str) -> bool:
4141
return first_sql_statements == second_sql_statements
4242

4343

44+
def get_diff(
45+
first_sql: str,
46+
second_sql: str,
47+
) -> list[list[list[str]]]:
48+
"""Show the difference between two SQL schemas, ignoring differences due to column order and other non-significant SQL changes."""
49+
first_set = {Statement(t) for t in sqlparse.parse(first_sql)}
50+
second_set = {Statement(t) for t in sqlparse.parse(second_sql)}
51+
first_diffs = sorted([stmt.str_tokens for stmt in first_set - second_set])
52+
second_diffs = sorted([stmt.str_tokens for stmt in second_set - first_set])
53+
54+
return [first_diffs, second_diffs]
55+
56+
4457
@dataclasses.dataclass
4558
class Token:
4659
"""Wrapper around `sqlparse.sql.Token`."""
@@ -74,6 +87,13 @@ def is_separator(self) -> bool:
7487
and self.token.normalized == ",",
7588
)
7689

90+
@property
91+
def str_tokens(self) -> list[str]:
92+
"""Return the token value."""
93+
if self.hash.strip():
94+
return [self.hash]
95+
return []
96+
7797

7898
@dataclasses.dataclass
7999
class TokenList:
@@ -128,6 +148,11 @@ def statement_type(self) -> str:
128148

129149
return Statement.UNKNOWN_TYPE
130150

151+
@property
152+
def str_tokens(self) -> list[str]:
153+
"""Return the reconstructed SQL statement from tokens as a list of strings."""
154+
return [t.hash for t in self.tokens if not t.ignore]
155+
131156

132157
class Statement(TokenList):
133158
"""SQL statement."""
@@ -152,6 +177,11 @@ def statement_type(self) -> str:
152177
# Only one keyword (e.g.: SELECT, INSERT, DELETE, etc.)
153178
return keywords[0]
154179

180+
@property
181+
def str_tokens(self) -> list[str]:
182+
"""Return the reconstructed SQL statement from tokens as a list of strings."""
183+
return [t for token in self.tokens for t in token.str_tokens]
184+
155185

156186
class UnorderedTokenList(TokenList):
157187
"""Unordered token list."""

tests/test_sql_compare.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,133 @@ def test_compare_neq(first_sql: str, second_sql: str) -> None:
184184
def test_statement_type(sql: str, expected_type: str) -> None:
185185
statement = sql_compare.Statement(sqlparse.parse(sql)[0])
186186
assert statement.statement_type == expected_type
187+
188+
189+
@pytest.mark.parametrize(
190+
("first_sql", "second_sql", "expected_diff"),
191+
[
192+
(
193+
"CREATE TABLE foo (id INT PRIMARY KEY)",
194+
"CREATE TABLE foo (id INT UNIQUE)",
195+
[
196+
[["CREATE", "TABLE", "foo", "(", "id", "INT", "PRIMARY KEY"]],
197+
[["CREATE", "TABLE", "foo", "(", "id", "INT", "UNIQUE"]],
198+
],
199+
),
200+
(
201+
"CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')",
202+
"CREATE TYPE public.colors AS ENUM ('BLUE', 'GREEN', 'RED')",
203+
[[], []],
204+
),
205+
(
206+
"CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')",
207+
"CREATE TYPE public.colors AS ENUM ('YELLOW', 'BLUE', 'RED')",
208+
[
209+
[
210+
[
211+
"CREATE",
212+
"TYPE",
213+
"public",
214+
".",
215+
"colors",
216+
"AS",
217+
"ENUM",
218+
"(",
219+
"'BLUE'",
220+
",",
221+
"'GREEN'",
222+
",",
223+
"'RED'",
224+
],
225+
],
226+
[
227+
[
228+
"CREATE",
229+
"TYPE",
230+
"public",
231+
".",
232+
"colors",
233+
"AS",
234+
"ENUM",
235+
"(",
236+
"'BLUE'",
237+
",",
238+
"'RED'",
239+
",",
240+
"'YELLOW'",
241+
],
242+
],
243+
],
244+
),
245+
(
246+
"""
247+
CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'REJECTED');
248+
CREATE TABLE users (id INT, name VARCHAR(100), status public.status);
249+
CREATE INDEX user_status_idx ON users (status);
250+
""",
251+
"""
252+
CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'ARCHIVED');
253+
CREATE TABLE logs (id INT, message TEXT);
254+
CREATE TABLE users (id INT, name VARCHAR(100), status public.status);
255+
CREATE INDEX user_status_idx ON users (status);
256+
""",
257+
[
258+
[
259+
[
260+
"CREATE",
261+
"TYPE",
262+
"public",
263+
".",
264+
"status",
265+
"AS",
266+
"ENUM",
267+
"(",
268+
"'APPROVED'",
269+
",",
270+
"'PENDING'",
271+
",",
272+
"'REJECTED'",
273+
";",
274+
],
275+
],
276+
[
277+
[
278+
"CREATE",
279+
"TABLE",
280+
"logs",
281+
"(",
282+
"id",
283+
"INT",
284+
",",
285+
"message",
286+
"TEXT",
287+
";",
288+
],
289+
[
290+
"CREATE",
291+
"TYPE",
292+
"public",
293+
".",
294+
"status",
295+
"AS",
296+
"ENUM",
297+
"(",
298+
"'APPROVED'",
299+
",",
300+
"'ARCHIVED'",
301+
",",
302+
"'PENDING'",
303+
";",
304+
],
305+
],
306+
],
307+
),
308+
],
309+
)
310+
def test_get_diff(
311+
first_sql: str,
312+
second_sql: str,
313+
expected_diff: list[list[list[str]]],
314+
) -> None:
315+
result = sql_compare.get_diff(first_sql, second_sql)
316+
assert result == expected_diff

0 commit comments

Comments
 (0)