diff --git a/slick_reporting/generator.py b/slick_reporting/generator.py index 1f14027..7cd2dcc 100644 --- a/slick_reporting/generator.py +++ b/slick_reporting/generator.py @@ -244,7 +244,11 @@ def __init__(self, report_model=None, main_queryset=None, start_date=None, end_d self.main_queryset = self._apply_queryset_options(main_queryset) if type(self.group_by_field) is ForeignKey and '__' not in self.group_by: ids = self.main_queryset.values_list(self.group_by_field_attname).distinct() - self.main_queryset = self.group_by_field.related_model.objects.filter(pk__in=ids).values() + # uses the same logic that is in Django's query.py when fields is empty in values() call + concrete_fields = [f.name for f in self.group_by_field.related_model._meta.concrete_fields] + # add database columns that are not already in concrete_fields + final_fields = concrete_fields + list(set(self.get_database_columns()) - set(concrete_fields)) + self.main_queryset = self.group_by_field.related_model.objects.filter(pk__in=ids).values(*final_fields) else: self.main_queryset = self.main_queryset.distinct().values(self.group_by_field_attname) else: @@ -478,7 +482,7 @@ def _parse(self): self._crosstab_parsed_columns = self.get_crosstab_parsed_columns() def get_database_columns(self): - return [col['name'] for col in self.parsed_columns if col['source'] == 'database'] + return [col['name'] for col in self.parsed_columns if 'source' in col and col['source'] == 'database'] # def get_method_columns(self): # return [col['name'] for col in self.parsed_columns if col['type'] == 'method'] diff --git a/tests/models.py b/tests/models.py index 218705d..471d0be 100644 --- a/tests/models.py +++ b/tests/models.py @@ -22,6 +22,8 @@ class Meta: verbose_name = _('Product') verbose_name_plural = _('Products') +class Contact(models.Model): + address = models.CharField(max_length=200, verbose_name=_('Name')) class Client(models.Model): slug = models.CharField(max_length=200, verbose_name=_('Client Slug')) @@ -29,6 +31,7 @@ class Client(models.Model): name = models.CharField(max_length=200, verbose_name=_('Name')) email = models.EmailField(blank=True) notes = models.TextField() + contact = models.ForeignKey(Contact, on_delete=models.CASCADE, null=True) class Meta: verbose_name = _('Client') diff --git a/tests/test_generator.py b/tests/test_generator.py index f56a882..2c8333d 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -13,7 +13,7 @@ TimeSeriesCustomDates from .tests import BaseTestData, year -from .models import SimpleSales +from .models import SimpleSales, Client class MatrixTests(BaseTestData, TestCase): @@ -208,6 +208,35 @@ def test_group_by_traverse(self): self.assertEqual(data[0]['product__category'], 'small') self.assertEqual(data[1]['product__category'], 'big') + def test_group_by_and_foreign_key_field(self): + report = ReportGenerator(report_model=SimpleSales, group_by='client', + columns=['name', 'contact_id', 'contact__address', SlickReportField.create(Sum, 'value'), '__total__'], + # time_series_pattern='monthly', + date_field='doc_date', + # time_series_columns=['__debit__', '__credit__', '__balance__', '__total__'] + ) + + self.assertTrue(report._report_fields_dependencies) + data = report.get_report_data() + # import pdb; + # pdb.set_trace() + self.assertNotEqual(data, []) + self.assertEqual(data[0]['name'], 'Client 1') + self.assertEqual(data[1]['name'], 'Client 2') + self.assertEqual(data[2]['name'], 'Client 3') + + self.assertEqual(data[0]['contact_id'], 1) + self.assertEqual(data[1]['contact_id'], 2) + self.assertEqual(data[2]['contact_id'], 3) + + self.assertEqual(data[0]['sum__value'], 300) + + self.assertEqual(Client.objects.get(pk=1).contact.address, 'Street 1') + self.assertEqual(data[0]['contact__address'], 'Street 1') + self.assertEqual(data[1]['contact__address'], 'Street 2') + self.assertEqual(data[2]['contact__address'], 'Street 3') + + def test_db_field_column_verbose_name(self): report = GenericGenerator() field_list = report.get_list_display_columns() diff --git a/tests/tests.py b/tests/tests.py index 6cc633c..5b53f6c 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -10,7 +10,7 @@ from slick_reporting.fields import SlickReportField, BalanceReportField from tests.report_generators import ClientTotalBalance, ProductClientSalesMatrix2, GroupByCharField, \ GroupByCharFieldPlusTimeSeries, TimeSeriesWithOutGroupBy -from .models import Client, Product, SimpleSales, OrderLine, UserJoined, SalesWithFlag, ComplexSales, TaxCode +from .models import Client, Contact, Product, SimpleSales, OrderLine, UserJoined, SalesWithFlag, ComplexSales, TaxCode from . import report_generators from slick_reporting.registry import field_registry @@ -36,8 +36,14 @@ def setUpTestData(cls): cls.user = user cls.limited_user = limited_user cls.client1 = Client.objects.create(name='Client 1') + cls.client1.contact = Contact.objects.create(address='Street 1') + cls.client1.save() cls.client2 = Client.objects.create(name='Client 2') + cls.client2.contact = Contact.objects.create(address='Street 2') + cls.client2.save() cls.client3 = Client.objects.create(name='Client 3') + cls.client3.contact = Contact.objects.create(address='Street 3') + cls.client3.save() cls.clientIdle = Client.objects.create(name='Client Idle') cls.product1 = Product.objects.create(name='Product 1', category='small')