Skip to content

Commit

Permalink
Use * instead of primary alias in count queries with CTE.
Browse files Browse the repository at this point in the history
Closes #3726
Original pull request: #3730
  • Loading branch information
christophstrobl authored and mp911de committed Jan 13, 2025
1 parent 3e99eee commit 643a3a9
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class HqlCountQueryTransformer extends HqlQueryRenderer {

private final @Nullable String countProjection;
private final @Nullable String primaryFromAlias;
private boolean containsCTE = false;

HqlCountQueryTransformer(@Nullable String countProjection, @Nullable String primaryFromAlias) {
this.countProjection = countProjection;
Expand Down Expand Up @@ -66,6 +67,12 @@ public QueryRendererBuilder visitOrderedQuery(HqlParser.OrderedQueryContext ctx)
return builder;
}

@Override
public QueryTokenStream visitCte(HqlParser.CteContext ctx) {
this.containsCTE = true;
return super.visitCte(ctx);
}

@Override
public QueryRendererBuilder visitFromQuery(HqlParser.FromQueryContext ctx) {

Expand Down Expand Up @@ -189,7 +196,9 @@ public QueryTokenStream visitSelectClause(HqlParser.SelectClauseContext ctx) {
nested.append(QueryTokens.expression(ctx.DISTINCT()));
nested.append(getDistinctCountSelection(visit(ctx.selectionList())));
} else {
nested.append(QueryTokens.token(primaryFromAlias));

// with CTE primary alias fails with hibernate (WITH entities AS (…) SELECT count(c) FROM entities c)
nested.append(containsCTE ? QueryTokens.token("*") : QueryTokens.token(primaryFromAlias));
}
} else {
builder.append(QueryTokens.token(countProjection));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.jpa.repository;

import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assumptions.*;

import jakarta.persistence.EntityManager;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.FilterType;
import org.springframework.context.annotation.ImportResource;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.domain.sample.Role;
import org.springframework.data.jpa.domain.sample.User;
import org.springframework.data.jpa.provider.PersistenceProvider;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
import org.springframework.data.jpa.repository.sample.RoleRepository;
import org.springframework.data.jpa.repository.sample.UserRepository;
import org.springframework.data.repository.CrudRepository;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.transaction.annotation.Transactional;

/**
* Hibernate-specific repository tests.
*
* @author Mark Paluch
*/
@ExtendWith(SpringExtension.class)
@ContextConfiguration()
@Transactional
class HibernateRepositoryTests {

@Autowired UserRepository userRepository;
@Autowired RoleRepository roleRepository;
@Autowired CteUserRepository cteUserRepository;
@Autowired EntityManager em;

PersistenceProvider provider;
User dave;
User carter;
User oliver;
Role drummer;
Role guitarist;
Role singer;

@BeforeEach
void setUp() {
provider = PersistenceProvider.fromEntityManager(em);

assumeThat(provider).isEqualTo(PersistenceProvider.HIBERNATE);
roleRepository.deleteAll();
userRepository.deleteAll();

drummer = roleRepository.save(new Role("DRUMMER"));
guitarist = roleRepository.save(new Role("GUITARIST"));
singer = roleRepository.save(new Role("SINGER"));

dave = userRepository.save(new User("Dave", "Matthews", "[email protected]", singer));
carter = userRepository.save(new User("Carter", "Beauford", "[email protected]", singer, drummer));
oliver = userRepository.save(new User("Oliver August", "Matthews", "[email protected]"));
}

@Test // GH-3726
void testQueryWithCTE() {

Page<UserExcerptDto> result = cteUserRepository.findWithCTE(PageRequest.of(0, 1));
assertThat(result.getTotalElements()).isEqualTo(3);
}

@ImportResource({ "classpath:infrastructure.xml" })
@Configuration
@EnableJpaRepositories(basePackageClasses = HibernateRepositoryTests.class, considerNestedRepositories = true,
includeFilters = @ComponentScan.Filter(
classes = { CteUserRepository.class, UserRepository.class, RoleRepository.class },
type = FilterType.ASSIGNABLE_TYPE))
static class TestConfig {}

interface CteUserRepository extends CrudRepository<User, Integer> {

/*
WITH entities AS (
SELECT
e.id as id,
e.number as number
FROM TestEntity e
)
SELECT new com.example.demo.Result('X', c.id, c.number)
FROM entities c
*/

@Query("""
WITH cte_select AS (select u.firstname as firstname, u.lastname as lastname from User u)
SELECT new org.springframework.data.jpa.repository.UserExcerptDto(c.firstname, c.lastname)
FROM cte_select c
""")
Page<UserExcerptDto> findWithCTE(Pageable page);

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.jpa.repository;

/**
* Hibernate is still a bit picky on records so let's use a class, just in case.
*
* @author Christoph Strobl
*/
public class UserExcerptDto {

private String firstname;
private String lastname;

public UserExcerptDto(String firstname, String lastname) {
this.firstname = firstname;
this.lastname = lastname;
}

public String getFirstname() {
return firstname;
}

public void setFirstname(String firstname) {
this.firstname = firstname;
}

public String getLastname() {
return lastname;
}

public void setLastname(String lastname) {
this.lastname = lastname;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,11 @@ void nullFirstLastSorting() {

assertThat(createQueryFor(original, Sort.unsorted())).isEqualTo(original);

assertThat(createQueryFor(original, Sort.by(Order.desc("lastName").nullsLast())))
.startsWith(original)
.endsWithIgnoringCase("e.lastName DESC NULLS LAST");
assertThat(createQueryFor(original, Sort.by(Order.desc("lastName").nullsLast()))).startsWith(original)
.endsWithIgnoringCase("e.lastName DESC NULLS LAST");

assertThat(createQueryFor(original, Sort.by(Order.desc("lastName").nullsFirst())))
.startsWith(original)
.endsWithIgnoringCase("e.lastName DESC NULLS FIRST");
assertThat(createQueryFor(original, Sort.by(Order.desc("lastName").nullsFirst()))).startsWith(original)
.endsWithIgnoringCase("e.lastName DESC NULLS FIRST");
}

@Test
Expand Down Expand Up @@ -151,6 +149,24 @@ void applyCountToAlreadySortedQuery() {
assertThat(results).isEqualTo("SELECT count(e) FROM Employee e where e.name = :name");
}

@Test // GH-3726
void shouldCreateCountQueryForCTE() {

// given
var original = """
WITH cte_select AS (select u.firstname as firstname, u.lastname as lastname from User u)
SELECT new org.springframework.data.jpa.repository.sample.UserExcerptDto(c.firstname, c.lastname)
FROM cte_select c
""";

// when
var results = createCountQueryFor(original);

// then
assertThat(results).isEqualToIgnoringWhitespace(
"WITH cte_select AS (select u.firstname as firstname, u.lastname as lastname from User u) SELECT count(*) FROM cte_select c");
}

@Test
void multipleAliasesShouldBeGathered() {

Expand Down Expand Up @@ -539,7 +555,7 @@ WITH maxId AS(select max(sr.snapshot.id) snapshotId from SnapshotReference sr
""");

assertThat(countQuery).startsWith("WITH maxId AS (select max(sr.snapshot.id) snapshotId from SnapshotReference sr")
.endsWith("select count(m) from maxId m join SnapshotReference sr on sr.snapshot.id = m.snapshotId");
.endsWith("select count(*) from maxId m join SnapshotReference sr on sr.snapshot.id = m.snapshotId");
}

@Test // GH-3504
Expand Down Expand Up @@ -1039,8 +1055,7 @@ select max(id), col
""", """
delete MyEntity AS mes
where mes.col = 'test'
"""
}) // GH-2977, GH-3649
""" }) // GH-2977, GH-3649
void isSubqueryThrowsException(String query) {
assertThat(createQueryFor(query, Sort.unsorted())).isEqualToIgnoringWhitespace(query);
}
Expand Down Expand Up @@ -1101,7 +1116,8 @@ void createsCountQueryUsingAliasCorrectly() {
"select count(distinct a, b, sum(amount), d) from Employee AS __ GROUP BY n");
assertCountQuery("select distinct a, count(b) as c from Employee GROUP BY n",
"select count(distinct a, count(b)) from Employee AS __ GROUP BY n");
assertCountQuery("select distinct substring(e.firstname, 1, position('a' in e.lastname)) as x from from Employee", "select count(distinct substring(e.firstname, 1, position('a' in e.lastname))) from from Employee");
assertCountQuery("select distinct substring(e.firstname, 1, position('a' in e.lastname)) as x from from Employee",
"select count(distinct substring(e.firstname, 1, position('a' in e.lastname))) from from Employee");
}

@Test // GH-3427
Expand Down

0 comments on commit 643a3a9

Please sign in to comment.