From 2a1e6d5f8f678950ec6a746e52a43c74a58a40bc Mon Sep 17 00:00:00 2001 From: Andrea Boriero Date: Wed, 4 Sep 2024 20:37:42 +0200 Subject: [PATCH 1/2] HHH-18563 Test auto flush foreign key target tables for set clause paths --- .../flush/AutoFlushOnUpdateQueryTest.java | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java new file mode 100644 index 000000000000..ca76ca5e9960 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java @@ -0,0 +1,174 @@ +package org.hibernate.orm.test.flush; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.Id; +import jakarta.persistence.OneToOne; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + + +@DomainModel( + annotatedClasses = { + AutoFlushOnUpdateQueryTest.FruitLogEntry.class, + AutoFlushOnUpdateQueryTest.Fruit.class, + } +) +@SessionFactory +public class AutoFlushOnUpdateQueryTest { + + public static final String FRUIT_NAME = "Apple"; + + @BeforeEach + public void setUp(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + session.persist( new Fruit( FRUIT_NAME ) ); + } + ); + } + + @AfterEach + public void tearDown(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + session.createMutationQuery( "delete from Fruit" ).executeUpdate(); + session.createMutationQuery( "delete from FruitLogEntry" ).executeUpdate(); + } + ); + } + + @Test + public void testFlushIsExecuted(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Fruit fruit = session + .createQuery( + "select f from Fruit f where f.name = :name", + Fruit.class + ).setParameter( "name", FRUIT_NAME ).getSingleResult(); + + FruitLogEntry logEntry = new FruitLogEntry( fruit, "foo" ); + session.persist( logEntry ); + + session.createMutationQuery( "update Fruit f set f.logEntry = :logEntry where f.id = :fruitId" ) + .setParameter( "logEntry", logEntry ) + .setParameter( "fruitId", fruit.getId() ).executeUpdate(); + } + ); + + scope.inTransaction( + session -> { + Fruit fruit = session + .createQuery( + "select f from Fruit f where f.name = :name", + Fruit.class + ).setParameter( "name", FRUIT_NAME ).getSingleResult(); + assertThat( fruit.getLogEntry() ).isNotNull(); + } + ); + } + + @Test + public void testFlushIsExecuted2(SessionFactoryScope scope) { + scope.inTransaction( + session -> { + Fruit fruit = session + .createQuery( + "select f from Fruit f where f.name = :name", + Fruit.class + ).setParameter( "name", FRUIT_NAME ).getSingleResult(); + + FruitLogEntry logEntry = new FruitLogEntry( fruit, "foo" ); + session.persist( logEntry ); + + session.createMutationQuery( "update Fruit f set f.logEntry.id = :logEntryId where f.id = :fruitId" ) + .setParameter( "logEntryId", logEntry.getId() ) + .setParameter( "fruitId", fruit.getId() ).executeUpdate(); + } + ); + + scope.inTransaction( + session -> { + Fruit fruit = session + .createQuery( + "select f from Fruit f where f.name = :name", + Fruit.class + ).setParameter( "name", FRUIT_NAME ).getSingleResult(); + assertThat( fruit.getLogEntry() ).isNotNull(); + } + ); + } + + @Entity(name = "Fruit") + public static class Fruit { + + @Id + @GeneratedValue + private Long id; + + private String name; + + @OneToOne + private FruitLogEntry logEntry; + + public Fruit() { + } + + public Fruit(String name) { + this.name = name; + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } + + public FruitLogEntry getLogEntry() { + return logEntry; + } + } + + @Entity(name = "FruitLogEntry") + public static class FruitLogEntry { + + @Id + @GeneratedValue + private Long id; + + @OneToOne(mappedBy = "logEntry") + private Fruit fruit; + + private String logComments; + + public FruitLogEntry(Fruit fruit, String comment) { + this.fruit = fruit; + this.logComments = comment; + } + + FruitLogEntry() { + } + + public Long getId() { + return id; + } + + public Fruit getFruit() { + return fruit; + } + + public String getLogComments() { + return logComments; + } + } +} From aef004d558a43855b63cb5d752582aa716ca076f Mon Sep 17 00:00:00 2001 From: Christian Beikov Date: Mon, 9 Sep 2024 13:01:00 +0200 Subject: [PATCH 2/2] HHH-18563 Add set clause foreign key target tables to affected tables --- .../dialect/OracleLegacySqlAstTranslator.java | 50 +++++++++- .../dialect/OracleSqlAstTranslator.java | 52 +++++++++- .../internal/cte/CteInsertHandler.java | 2 +- .../sqm/sql/BaseSqmToSqlAstConverter.java | 99 ++++++++++--------- .../BasicValuedPathInterpretation.java | 30 ++++++ .../EmbeddableValuedPathInterpretation.java | 30 ++++++ .../EntityValuedPathInterpretation.java | 29 ++++++ ...atedCompositeValuedPathInterpretation.java | 31 +++++- .../sql/internal/SqmPathInterpretation.java | 6 ++ .../hibernate/sql/ast/SqlAstTranslator.java | 2 + .../sql/ast/spi/AbstractSqlAstTranslator.java | 83 +++++++++++++--- .../flush/AutoFlushOnUpdateQueryTest.java | 4 + 12 files changed, 348 insertions(+), 70 deletions(-) diff --git a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/OracleLegacySqlAstTranslator.java b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/OracleLegacySqlAstTranslator.java index 0727be243100..2cc8ebaf0844 100644 --- a/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/OracleLegacySqlAstTranslator.java +++ b/hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/OracleLegacySqlAstTranslator.java @@ -22,6 +22,7 @@ import org.hibernate.query.common.FetchClauseType; import org.hibernate.query.common.FrameExclusion; import org.hibernate.query.common.FrameKind; +import org.hibernate.query.sqm.sql.internal.SqmPathInterpretation; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator; @@ -55,7 +56,9 @@ import org.hibernate.sql.ast.tree.select.QueryPart; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectClause; +import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.ast.tree.select.SortSpecification; +import org.hibernate.sql.ast.tree.update.Assignable; import org.hibernate.sql.ast.tree.update.Assignment; import org.hibernate.sql.ast.tree.update.UpdateStatement; import org.hibernate.sql.exec.spi.JdbcOperation; @@ -668,13 +671,43 @@ private boolean supportsOffsetFetchClause() { return getDialect().supportsFetchClause( FetchClauseType.ROWS_ONLY ); } + @Override + protected void renderNull(Literal literal) { + if ( getParameterRenderingMode() == SqlAstNodeRenderingMode.NO_UNTYPED ) { + switch ( literal.getJdbcMapping().getJdbcType().getDdlTypeCode() ) { + case SqlTypes.BLOB: + appendSql( "to_blob(null)" ); + break; + case SqlTypes.CLOB: + appendSql( "to_clob(null)" ); + break; + case SqlTypes.NCLOB: + appendSql( "to_nclob(null)" ); + break; + default: + super.renderNull( literal ); + break; + } + } + else { + super.renderNull( literal ); + } + } + @Override protected void visitSetAssignment(Assignment assignment) { + final Assignable assignable = assignment.getAssignable(); + if ( assignable instanceof SqmPathInterpretation ) { + final String affectedTableName = ( (SqmPathInterpretation) assignable ).getAffectedTableName(); + if ( affectedTableName != null ) { + addAffectedTableName( affectedTableName ); + } + } final List columnReferences = assignment.getAssignable().getColumnReferences(); + final Expression assignedValue = assignment.getAssignedValue(); if ( columnReferences.size() == 1 ) { columnReferences.get( 0 ).appendColumnForWrite( this ); appendSql( '=' ); - final Expression assignedValue = assignment.getAssignedValue(); final SqlTuple sqlTuple = SqlTupleContainer.getSqlTuple( assignedValue ); if ( sqlTuple != null ) { assert sqlTuple.getExpressions().size() == 1; @@ -684,7 +717,7 @@ protected void visitSetAssignment(Assignment assignment) { assignedValue.accept( this ); } } - else { + else if ( assignedValue instanceof SelectStatement ) { char separator = OPEN_PARENTHESIS; for ( ColumnReference columnReference : columnReferences ) { appendSql( separator ); @@ -694,5 +727,18 @@ protected void visitSetAssignment(Assignment assignment) { appendSql( ")=" ); assignment.getAssignedValue().accept( this ); } + else { + assert assignedValue instanceof SqlTupleContainer; + final List expressions = ( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions(); + columnReferences.get( 0 ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( 0 ).accept( this ); + for ( int i = 1; i < columnReferences.size(); i++ ) { + appendSql( ',' ); + columnReferences.get( i ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( i ).accept( this ); + } + } } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/OracleSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/dialect/OracleSqlAstTranslator.java index 0ce276707e0b..012894497f18 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/OracleSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/OracleSqlAstTranslator.java @@ -20,6 +20,7 @@ import org.hibernate.query.common.FetchClauseType; import org.hibernate.query.common.FrameExclusion; import org.hibernate.query.common.FrameKind; +import org.hibernate.query.sqm.sql.internal.SqmPathInterpretation; import org.hibernate.sql.ast.Clause; import org.hibernate.sql.ast.SqlAstNodeRenderingMode; import org.hibernate.sql.ast.spi.SqlSelection; @@ -53,7 +54,9 @@ import org.hibernate.sql.ast.tree.select.QueryPart; import org.hibernate.sql.ast.tree.select.QuerySpec; import org.hibernate.sql.ast.tree.select.SelectClause; +import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.ast.tree.select.SortSpecification; +import org.hibernate.sql.ast.tree.update.Assignable; import org.hibernate.sql.ast.tree.update.Assignment; import org.hibernate.sql.ast.tree.update.UpdateStatement; import org.hibernate.sql.exec.spi.JdbcOperation; @@ -625,13 +628,43 @@ private boolean supportsOffsetFetchClause() { return getDialect().supportsFetchClause( FetchClauseType.ROWS_ONLY ); } + @Override + protected void renderNull(Literal literal) { + if ( getParameterRenderingMode() == SqlAstNodeRenderingMode.NO_UNTYPED ) { + switch ( literal.getJdbcMapping().getJdbcType().getDdlTypeCode() ) { + case SqlTypes.BLOB: + appendSql( "to_blob(null)" ); + break; + case SqlTypes.CLOB: + appendSql( "to_clob(null)" ); + break; + case SqlTypes.NCLOB: + appendSql( "to_nclob(null)" ); + break; + default: + super.renderNull( literal ); + break; + } + } + else { + super.renderNull( literal ); + } + } + @Override protected void visitSetAssignment(Assignment assignment) { + final Assignable assignable = assignment.getAssignable(); + if ( assignable instanceof SqmPathInterpretation ) { + final String affectedTableName = ( (SqmPathInterpretation) assignable ).getAffectedTableName(); + if ( affectedTableName != null ) { + addAffectedTableName( affectedTableName ); + } + } final List columnReferences = assignment.getAssignable().getColumnReferences(); + final Expression assignedValue = assignment.getAssignedValue(); if ( columnReferences.size() == 1 ) { columnReferences.get( 0 ).appendColumnForWrite( this ); appendSql( '=' ); - final Expression assignedValue = assignment.getAssignedValue(); final SqlTuple sqlTuple = SqlTupleContainer.getSqlTuple( assignedValue ); if ( sqlTuple != null ) { assert sqlTuple.getExpressions().size() == 1; @@ -641,7 +674,7 @@ protected void visitSetAssignment(Assignment assignment) { assignedValue.accept( this ); } } - else { + else if ( assignedValue instanceof SelectStatement ) { char separator = OPEN_PARENTHESIS; for ( ColumnReference columnReference : columnReferences ) { appendSql( separator ); @@ -649,7 +682,20 @@ protected void visitSetAssignment(Assignment assignment) { separator = COMMA_SEPARATOR_CHAR; } appendSql( ")=" ); - assignment.getAssignedValue().accept( this ); + assignedValue.accept( this ); + } + else { + assert assignedValue instanceof SqlTupleContainer; + final List expressions = ( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions(); + columnReferences.get( 0 ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( 0 ).accept( this ); + for ( int i = 1; i < columnReferences.size(); i++ ) { + appendSql( ',' ); + columnReferences.get( i ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( i ).accept( this ); + } } } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/mutation/internal/cte/CteInsertHandler.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/mutation/internal/cte/CteInsertHandler.java index cd9f67ab17ba..8bc9d0694f1c 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/mutation/internal/cte/CteInsertHandler.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/mutation/internal/cte/CteInsertHandler.java @@ -1299,7 +1299,7 @@ private List getCompatibleAssignments(InsertSelectStatement dmlState final List assignments = conflictClause.getAssignments(); for ( Assignment assignment : assignments ) { for ( ColumnReference targetColumn : dmlStatement.getTargetColumns() ) { - if ( targetColumn.equals( assignment.getAssignable() ) ) { + if ( assignment.getAssignable().getColumnReferences().contains( targetColumn ) ) { if ( compatibleAssignments == null ) { compatibleAssignments = new ArrayList<>( assignments.size() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index bfe6f7a5dad3..cfa00d627b85 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -858,36 +858,57 @@ public List visitSetClause(SqmSetClause setClause) { final SqmExpression assignmentValue = sqmAssignment.getValue(); final SqmParameter assignmentValueParameter = getSqmParameter( assignmentValue ); final Expression pathSqlExpression = assignedPathInterpretation.getSqlExpression(); - final List targetColumnReferences = - pathSqlExpression instanceof SqlTuple sqlTuple - ? sqlTuple.getExpressions() + //noinspection unchecked + final List targetColumnReferences = + pathSqlExpression instanceof SqlTupleContainer sqlTuple + ? (List) sqlTuple.getSqlTuple().getExpressions() : pathSqlExpression.getColumnReference().getColumnReferences(); if ( assignmentValueParameter != null ) { + final ArrayList expressions = new ArrayList<>( targetColumnReferences.size() ); consumeSqmParameter( assignmentValueParameter, assignedPathInterpretation.getExpressionType(), - (index, jdbcParameter) -> addAssignment( - assignments, - aggregateColumnAssignmentHandler, - targetColumnReferences.get( index ), - jdbcParameter - ) + (index, jdbcParameter) -> expressions.add( jdbcParameter ) ); - } - else if ( assignmentValue instanceof SqmLiteralNull ) { - for ( Expression columnReference : targetColumnReferences ) { + if ( pathSqlExpression instanceof SqlTupleContainer ) { addAssignment( assignments, aggregateColumnAssignmentHandler, - columnReference, - new QueryLiteral<>( null, (BasicValuedMapping) columnReference.getExpressionType() ) + (Assignable) assignedPathInterpretation, + targetColumnReferences, + new SqlTuple( expressions, assignedPathInterpretation.getExpressionType() ) ); } + else { + assert expressions.size() == 1; + addAssignment( + assignments, + aggregateColumnAssignmentHandler, + (Assignable) assignedPathInterpretation, + targetColumnReferences, + expressions.get( 0 ) + ); + } + } + else if ( pathSqlExpression instanceof SqlTupleContainer + && assignmentValue instanceof SqmLiteralNull ) { + final ArrayList expressions = new ArrayList<>( targetColumnReferences.size() ); + for ( ColumnReference targetColumnReference : targetColumnReferences ) { + expressions.add( new QueryLiteral<>( null, + (SqlExpressible) targetColumnReference.getExpressionType() ) ); + } + addAssignment( + assignments, + aggregateColumnAssignmentHandler, + (Assignable) assignedPathInterpretation, + targetColumnReferences, + new SqlTuple( expressions, assignedPathInterpretation.getExpressionType() ) + ); } else { addAssignments( (Expression) assignmentValue.accept( this ), - assignedPathInterpretation.getExpressionType(), + assignedPathInterpretation, targetColumnReferences, assignments, aggregateColumnAssignmentHandler @@ -909,35 +930,18 @@ else if ( assignmentValue instanceof SqmLiteralNull ) { private void addAssignments( Expression valueExpression, - ModelPart assignedPathType, - List targetColumnReferences, + SqmPathInterpretation assignedPathInterpretation, + List targetColumnReferences, ArrayList assignments, - AggregateColumnAssignmentHandler assignmentHandler) { - checkAssignment( valueExpression, assignedPathType ); - if ( valueExpression instanceof SqlTuple sqlTuple ) { - addTupleAssignments( targetColumnReferences, assignments, assignmentHandler, sqlTuple ); - } - else if ( valueExpression instanceof EmbeddableValuedPathInterpretation embeddable ) { - addTupleAssignments( targetColumnReferences, assignments, assignmentHandler, embeddable.getSqlTuple() ); - } - else { - for ( Expression columnReference : targetColumnReferences ) { - addAssignment( assignments, assignmentHandler, columnReference, valueExpression ); - } - } - } - - private void addTupleAssignments( - List targetColumnReferences, - ArrayList assignments, - AggregateColumnAssignmentHandler aggregateColumnAssignmentHandler, - SqlTuple sqlTuple) { - final List expressions = sqlTuple.getExpressions(); - assert targetColumnReferences.size() == expressions.size(); - for ( int i = 0; i < targetColumnReferences.size(); i++ ) { - final ColumnReference columnReference = (ColumnReference) targetColumnReferences.get( i ); - addAssignment( assignments, aggregateColumnAssignmentHandler, columnReference, expressions.get( i ) ); - } + AggregateColumnAssignmentHandler aggregateColumnAssignmentHandler) { + checkAssignment( valueExpression, assignedPathInterpretation.getExpressionType() ); + addAssignment( + assignments, + aggregateColumnAssignmentHandler, + (Assignable) assignedPathInterpretation, + targetColumnReferences, + valueExpression + ); } private void checkAssignment(Expression valueExpression, ModelPart assignedPathType) { @@ -958,12 +962,15 @@ private void checkAssignment(Expression valueExpression, ModelPart assignedPathT private void addAssignment( List assignments, AggregateColumnAssignmentHandler aggregateColumnAssignmentHandler, - Expression columnReference, + Assignable assignable, + List targetColumnReferences, Expression valueExpression) { if ( aggregateColumnAssignmentHandler != null ) { - aggregateColumnAssignmentHandler.addAssignment( assignments.size(), (ColumnReference) columnReference ); + for ( ColumnReference targetColumnReference : targetColumnReferences ) { + aggregateColumnAssignmentHandler.addAssignment( assignments.size(), targetColumnReference ); + } } - assignments.add( new Assignment( (ColumnReference) columnReference, valueExpression ) ); + assignments.add( new Assignment( assignable, valueExpression ) ); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/BasicValuedPathInterpretation.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/BasicValuedPathInterpretation.java index 064931de0809..3ae24e4564cb 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/BasicValuedPathInterpretation.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/BasicValuedPathInterpretation.java @@ -10,6 +10,7 @@ import org.hibernate.metamodel.MappingMetamodel; import org.hibernate.metamodel.mapping.BasicValuedModelPart; +import org.hibernate.metamodel.mapping.EntityAssociationMapping; import org.hibernate.metamodel.mapping.EntityMappingType; import org.hibernate.metamodel.mapping.MappingType; import org.hibernate.metamodel.mapping.ModelPart; @@ -31,6 +32,8 @@ import org.hibernate.sql.ast.tree.from.TableReference; import org.hibernate.sql.ast.tree.update.Assignable; +import org.checkerframework.checker.nullness.qual.Nullable; + import static jakarta.persistence.metamodel.Type.PersistenceType.ENTITY; import static org.hibernate.internal.util.NullnessUtil.castNonNull; import static org.hibernate.query.sqm.internal.SqmUtil.getTargetMappingIfNeeded; @@ -124,15 +127,37 @@ else if ( expression instanceof SqlSelectionExpression ) { } private final ColumnReference columnReference; + private final @Nullable String affectedTableName; public BasicValuedPathInterpretation( ColumnReference columnReference, NavigablePath navigablePath, BasicValuedModelPart mapping, TableGroup tableGroup) { + this( columnReference, navigablePath, mapping, tableGroup, determineAffectedTableName( tableGroup, mapping ) ); + } + + private static @Nullable String determineAffectedTableName(TableGroup tableGroup, BasicValuedModelPart mapping) { + final ModelPartContainer modelPart = tableGroup.getModelPart(); + if ( modelPart instanceof EntityAssociationMapping ) { + final EntityAssociationMapping associationMapping = (EntityAssociationMapping) modelPart; + if ( !associationMapping.containsTableReference( mapping.getContainingTableExpression() ) ) { + return associationMapping.getAssociatedEntityMappingType().getMappedTableDetails().getTableName(); + } + } + return null; + } + + public BasicValuedPathInterpretation( + ColumnReference columnReference, + NavigablePath navigablePath, + BasicValuedModelPart mapping, + TableGroup tableGroup, + @Nullable String affectedTableName) { super( navigablePath, mapping, tableGroup ); assert columnReference != null; this.columnReference = columnReference; + this.affectedTableName = affectedTableName; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -143,6 +168,11 @@ public Expression getSqlExpression() { return columnReference; } + @Override + public @Nullable String getAffectedTableName() { + return affectedTableName; + } + @Override public ColumnReference getColumnReference() { return columnReference; diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EmbeddableValuedPathInterpretation.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EmbeddableValuedPathInterpretation.java index 93d5f7076ff5..e311a1de75f2 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EmbeddableValuedPathInterpretation.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EmbeddableValuedPathInterpretation.java @@ -10,6 +10,7 @@ import org.hibernate.metamodel.MappingMetamodel; import org.hibernate.metamodel.mapping.EmbeddableValuedModelPart; +import org.hibernate.metamodel.mapping.EntityAssociationMapping; import org.hibernate.metamodel.mapping.EntityMappingType; import org.hibernate.metamodel.mapping.ModelPartContainer; import org.hibernate.metamodel.model.domain.EntityDomainType; @@ -26,6 +27,8 @@ import org.hibernate.sql.ast.tree.from.TableGroup; import org.hibernate.sql.ast.tree.update.Assignable; +import org.checkerframework.checker.nullness.qual.Nullable; + import static jakarta.persistence.metamodel.Type.PersistenceType.ENTITY; import static org.hibernate.query.sqm.internal.SqmUtil.getTargetMappingIfNeeded; @@ -79,14 +82,36 @@ else if ( lhs.getNodeType() instanceof EntityDomainType ) { } private final SqlTuple sqlExpression; + private final @Nullable String affectedTableName; public EmbeddableValuedPathInterpretation( SqlTuple sqlExpression, NavigablePath navigablePath, EmbeddableValuedModelPart mapping, TableGroup tableGroup) { + this( sqlExpression, navigablePath, mapping, tableGroup, determineAffectedTableName( tableGroup, mapping ) ); + } + + public EmbeddableValuedPathInterpretation( + SqlTuple sqlExpression, + NavigablePath navigablePath, + EmbeddableValuedModelPart mapping, + TableGroup tableGroup, + @Nullable String affectedTableName) { super( navigablePath, mapping, tableGroup ); this.sqlExpression = sqlExpression; + this.affectedTableName = affectedTableName; + } + + private static @Nullable String determineAffectedTableName(TableGroup tableGroup, EmbeddableValuedModelPart mapping) { + final ModelPartContainer modelPart = tableGroup.getModelPart(); + if ( modelPart instanceof EntityAssociationMapping ) { + final EntityAssociationMapping associationMapping = (EntityAssociationMapping) modelPart; + if ( !associationMapping.containsTableReference( mapping.getContainingTableExpression() ) ) { + return associationMapping.getAssociatedEntityMappingType().getMappedTableDetails().getTableName(); + } + } + return null; } @Override @@ -94,6 +119,11 @@ public SqlTuple getSqlExpression() { return sqlExpression; } + @Override + public @Nullable String getAffectedTableName() { + return affectedTableName; + } + @Override public void accept(SqlAstWalker sqlTreeWalker) { sqlExpression.accept( sqlTreeWalker ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EntityValuedPathInterpretation.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EntityValuedPathInterpretation.java index 31753fba8a1e..f54ae000cfdc 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EntityValuedPathInterpretation.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/EntityValuedPathInterpretation.java @@ -19,6 +19,7 @@ import org.hibernate.metamodel.mapping.EntityValuedModelPart; import org.hibernate.metamodel.mapping.MappingModelExpressible; import org.hibernate.metamodel.mapping.ModelPart; +import org.hibernate.metamodel.mapping.ModelPartContainer; import org.hibernate.metamodel.mapping.SelectableConsumer; import org.hibernate.metamodel.mapping.ValuedModelPart; import org.hibernate.metamodel.mapping.internal.EntityCollectionPart; @@ -48,10 +49,12 @@ import org.hibernate.sql.results.graph.Fetchable; import jakarta.persistence.criteria.Selection; +import org.checkerframework.checker.nullness.qual.Nullable; public class EntityValuedPathInterpretation extends AbstractSqmPathInterpretation implements SqlTupleContainer, Assignable { private final Expression sqlExpression; + private final @Nullable String affectedTableName; public static EntityValuedPathInterpretation from( SqmEntityValuedSimplePath sqmPath, @@ -415,8 +418,29 @@ public EntityValuedPathInterpretation( NavigablePath navigablePath, TableGroup tableGroup, EntityValuedModelPart mapping) { + this( sqlExpression, navigablePath, tableGroup, mapping, determineAffectedTableName( tableGroup, mapping ) ); + } + + public EntityValuedPathInterpretation( + Expression sqlExpression, + NavigablePath navigablePath, + TableGroup tableGroup, + EntityValuedModelPart mapping, + @Nullable String affectedTableName) { super( navigablePath, mapping, tableGroup ); this.sqlExpression = sqlExpression; + this.affectedTableName = affectedTableName; + } + + private static @Nullable String determineAffectedTableName(TableGroup tableGroup, EntityValuedModelPart mapping) { + final ModelPartContainer modelPart = tableGroup.getModelPart(); + if ( modelPart instanceof EntityAssociationMapping && mapping instanceof ValuedModelPart ) { + final EntityAssociationMapping associationMapping = (EntityAssociationMapping) modelPart; + if ( !associationMapping.containsTableReference( ( (ValuedModelPart) mapping ).getContainingTableExpression() ) ) { + return associationMapping.getAssociatedEntityMappingType().getMappedTableDetails().getTableName(); + } + } + return null; } @Override @@ -424,6 +448,11 @@ public Expression getSqlExpression() { return sqlExpression; } + @Override + public @Nullable String getAffectedTableName() { + return affectedTableName; + } + @Override public void accept(SqlAstWalker sqlTreeWalker) { sqlExpression.accept( sqlTreeWalker ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/NonAggregatedCompositeValuedPathInterpretation.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/NonAggregatedCompositeValuedPathInterpretation.java index 09507ab2d7ee..626a070a5670 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/NonAggregatedCompositeValuedPathInterpretation.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/NonAggregatedCompositeValuedPathInterpretation.java @@ -4,7 +4,8 @@ */ package org.hibernate.query.sqm.sql.internal; -import org.hibernate.metamodel.mapping.ModelPart; +import org.hibernate.metamodel.mapping.EntityAssociationMapping; +import org.hibernate.metamodel.mapping.ModelPartContainer; import org.hibernate.metamodel.mapping.NonAggregatedIdentifierMapping; import org.hibernate.spi.NavigablePath; import org.hibernate.query.sqm.sql.SqmToSqlAstConverter; @@ -14,6 +15,8 @@ import org.hibernate.sql.ast.tree.expression.SqlTupleContainer; import org.hibernate.sql.ast.tree.from.TableGroup; +import org.checkerframework.checker.nullness.qual.Nullable; + /** * @author Andrea Boriero */ @@ -40,19 +43,34 @@ public static NonAggregatedCompositeValuedPathInterpretation from( ), sqmPath.getNavigablePath(), mapping, - tableGroup + tableGroup, + determineAffectedTableName( tableGroup, mapping ) ); } private final SqlTuple sqlExpression; + private final @Nullable String affectedTableName; private NonAggregatedCompositeValuedPathInterpretation( SqlTuple sqlExpression, NavigablePath navigablePath, - ModelPart mapping, - TableGroup tableGroup) { + NonAggregatedIdentifierMapping mapping, + TableGroup tableGroup, + @Nullable String affectedTableName) { super( navigablePath, mapping, tableGroup ); this.sqlExpression = sqlExpression; + this.affectedTableName = affectedTableName; + } + + private static @Nullable String determineAffectedTableName(TableGroup tableGroup, NonAggregatedIdentifierMapping mapping) { + final ModelPartContainer modelPart = tableGroup.getModelPart(); + if ( modelPart instanceof EntityAssociationMapping ) { + final EntityAssociationMapping associationMapping = (EntityAssociationMapping) modelPart; + if ( !associationMapping.containsTableReference( mapping.getContainingTableExpression() ) ) { + return associationMapping.getAssociatedEntityMappingType().getMappedTableDetails().getTableName(); + } + } + return null; } @Override @@ -60,6 +78,11 @@ public SqlTuple getSqlExpression() { return sqlExpression; } + @Override + public @Nullable String getAffectedTableName() { + return affectedTableName; + } + @Override public void accept(SqlAstWalker sqlTreeWalker) { sqlExpression.accept( sqlTreeWalker ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/SqmPathInterpretation.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/SqmPathInterpretation.java index 02defdcebb3b..5d4b09f63787 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/SqmPathInterpretation.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/internal/SqmPathInterpretation.java @@ -9,6 +9,8 @@ import org.hibernate.query.sqm.tree.domain.SqmPath; import org.hibernate.sql.ast.tree.expression.Expression; +import org.checkerframework.checker.nullness.qual.Nullable; + /** * Interpretation of a {@link SqmPath} as part of the translation to SQL AST. We need specialized handling * for path interpretations because it can (and likely) contains multiple SqlExpressions (entity to its columns, e.g.) @@ -26,4 +28,8 @@ public interface SqmPathInterpretation extends Expression, DomainResultProduc default Expression getSqlExpression() { return this; } + + default @Nullable String getAffectedTableName() { + return null; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlAstTranslator.java index e279b82de452..1439763f9eeb 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/SqlAstTranslator.java @@ -60,5 +60,7 @@ public interface SqlAstTranslator extends SqlAstWalker */ Set getAffectedTableNames(); + void addAffectedTableName(String tableName); + T translate(JdbcParameterBindings jdbcParameterBindings, QueryOptions queryOptions); } diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java index a9fafe1f1edb..1d8ee6988cfb 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java @@ -184,6 +184,7 @@ import org.hibernate.sql.ast.tree.select.SelectClause; import org.hibernate.sql.ast.tree.select.SelectStatement; import org.hibernate.sql.ast.tree.select.SortSpecification; +import org.hibernate.sql.ast.tree.update.Assignable; import org.hibernate.sql.ast.tree.update.Assignment; import org.hibernate.sql.ast.tree.update.UpdateStatement; import org.hibernate.sql.exec.ExecutionException; @@ -439,6 +440,11 @@ public Set getAffectedTableNames() { return affectedTableNames; } + @Override + public void addAffectedTableName(String tableName) { + affectedTableNames.add( tableName ); + } + protected Statement getStatement() { return statementStack.getRoot(); } @@ -1144,11 +1150,18 @@ protected void renderSetClause(List assignments) { } protected void visitSetAssignment(Assignment assignment) { - final List columnReferences = assignment.getAssignable().getColumnReferences(); + final Assignable assignable = assignment.getAssignable(); + if ( assignable instanceof SqmPathInterpretation ) { + final String affectedTableName = ( (SqmPathInterpretation) assignable ).getAffectedTableName(); + if ( affectedTableName != null ) { + addAffectedTableName( affectedTableName ); + } + } + final List columnReferences = assignable.getColumnReferences(); + final Expression assignedValue = assignment.getAssignedValue(); if ( columnReferences.size() == 1 ) { columnReferences.get( 0 ).appendColumnForWrite( this, null ); appendSql( '=' ); - final Expression assignedValue = assignment.getAssignedValue(); final SqlTuple sqlTuple = getSqlTuple( assignedValue ); if ( sqlTuple != null ) { assert sqlTuple.getExpressions().size() == 1; @@ -1158,7 +1171,7 @@ protected void visitSetAssignment(Assignment assignment) { assignedValue.accept( this ); } } - else { + else if ( assignedValue instanceof SelectStatement ) { char separator = OPEN_PARENTHESIS; for ( ColumnReference columnReference : columnReferences ) { appendSql( separator ); @@ -1166,11 +1179,31 @@ protected void visitSetAssignment(Assignment assignment) { separator = COMMA_SEPARATOR_CHAR; } appendSql( ")=" ); - assignment.getAssignedValue().accept( this ); + assignedValue.accept( this ); + } + else { + assert assignedValue instanceof SqlTupleContainer; + final List expressions = ( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions(); + columnReferences.get( 0 ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( 0 ).accept( this ); + for ( int i = 1; i < columnReferences.size(); i++ ) { + appendSql( ',' ); + columnReferences.get( i ).appendColumnForWrite( this, null ); + appendSql( '=' ); + expressions.get( i ).accept( this ); + } } } protected void visitSetAssignmentEmulateJoin(Assignment assignment, UpdateStatement statement) { + final Assignable assignable = assignment.getAssignable(); + if ( assignable instanceof SqmPathInterpretation ) { + final String affectedTableName = ( (SqmPathInterpretation) assignable ).getAffectedTableName(); + if ( affectedTableName != null ) { + addAffectedTableName( affectedTableName ); + } + } final List columnReferences = assignment.getAssignable().getColumnReferences(); final Expression valueExpression; if ( columnReferences.size() == 1 ) { @@ -2028,17 +2061,39 @@ private void renderPredicatedSetAssignments(List assignments, Predic visitSetAssignment( assignment ); } else { - assert assignment.getAssignable().getColumnReferences().size() == 1; - final Expression expression = new CaseSearchedExpression( - (MappingModelExpressible) assignment.getAssignedValue().getExpressionType(), - List.of( - new CaseSearchedExpression.WhenFragment( - predicate, assignment.getAssignedValue() + final Assignable assignable = assignment.getAssignable(); + final Expression assignedValue = assignment.getAssignedValue(); + final Expression expression; + if ( assignable.getColumnReferences().size() == 1 ) { + expression = new CaseSearchedExpression( + (MappingModelExpressible) assignedValue.getExpressionType(), + List.of( new CaseSearchedExpression.WhenFragment( predicate, assignedValue ) ), + assignable.getColumnReferences().get( 0 ) + ); + } + else { + assert assignedValue instanceof SqlTupleContainer; + final List expressions = + ( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions(); + final List tupleExpressions = new ArrayList<>( expressions.size() ); + for ( int i = 0; i < expressions.size(); i++ ) { + tupleExpressions.add( + new CaseSearchedExpression( + (MappingModelExpressible) expressions.get( i ).getExpressionType(), + List.of( new CaseSearchedExpression.WhenFragment( + predicate, + expressions.get( i ) + ) ), + assignable.getColumnReferences().get( i ) ) - ), - assignment.getAssignable().getColumnReferences().get( 0 ) - ); - visitSetAssignment( new Assignment( assignment.getAssignable(), expression ) ); + ); + } + expression = new SqlTuple( + tupleExpressions, + (MappingModelExpressible) assignedValue.getExpressionType() + ); + } + visitSetAssignment( new Assignment( assignable, expression ) ); } } } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java index ca76ca5e9960..6d7bb7a3968f 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/flush/AutoFlushOnUpdateQueryTest.java @@ -1,3 +1,7 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ package org.hibernate.orm.test.flush; import org.hibernate.testing.orm.junit.DomainModel;