diff --git a/SqlMarshal.CompilationTests/DbContextManager.cs b/SqlMarshal.CompilationTests/DbContextManager.cs index da4d3e0..e296fc3 100644 --- a/SqlMarshal.CompilationTests/DbContextManager.cs +++ b/SqlMarshal.CompilationTests/DbContextManager.cs @@ -50,4 +50,7 @@ public DbContextManager(PersonDbContext context) [SqlMarshal("persons_by_id")] public partial PersonDbContext.Person GetPersonById(int personId); + + [SqlMarshal("users_list")] + public partial IList GetUsers(); } diff --git a/SqlMarshal.CompilationTests/PersonDbContext.cs b/SqlMarshal.CompilationTests/PersonDbContext.cs index af841f4..0225780 100644 --- a/SqlMarshal.CompilationTests/PersonDbContext.cs +++ b/SqlMarshal.CompilationTests/PersonDbContext.cs @@ -18,6 +18,8 @@ public PersonDbContext(DbContextOptions options) public DbSet Persons { get; set; } = null!; + public DbSet Users { get; set; } = null!; + internal class Person { [Column("person_id")] @@ -26,5 +28,14 @@ internal class Person [Column("person_name")] public string? PersonName { get; set; } } + + internal class User + { + [Column("user_id")] + public int UserId { get; set; } + + [Column("user_name")] + public string? UserName { get; set; } + } } } diff --git a/SqlMarshal.CompilationTests/sqlmarshal_sample.sql b/SqlMarshal.CompilationTests/sqlmarshal_sample.sql index a0689a4..c02257c 100644 --- a/SqlMarshal.CompilationTests/sqlmarshal_sample.sql +++ b/SqlMarshal.CompilationTests/sqlmarshal_sample.sql @@ -41,3 +41,15 @@ begin end GO +CREATE TABLE [user] ( +user_id int not null identity primary key, +user_name nvarchar(100) null +) +GO + + + +CREATE OR ALTER PROCEDURE users_list +AS +SELECT * from [user] +GO \ No newline at end of file diff --git a/SqlMarshal.Tests/StoredProcedureGenerationTests.cs b/SqlMarshal.Tests/StoredProcedureGenerationTests.cs index c9e3103..9db2857 100644 --- a/SqlMarshal.Tests/StoredProcedureGenerationTests.cs +++ b/SqlMarshal.Tests/StoredProcedureGenerationTests.cs @@ -473,6 +473,76 @@ public partial IList M(string clientId, string personId) Assert.AreEqual(expectedOutput, output); } + [TestMethod] + public void DbSetNameFoundFromClass2_WithNullable() + { + string source = @" +namespace Foo +{ + public partial class CustomDbContext : DbContext + { + public virtual DbSet? Items { get; set; } = null!; + public virtual DbSet? Persons { get; set; } = null!; + } + + class C + { + private readonly CustomDbContext context; + + [SqlMarshal(""sp_TestSP"")] + public partial IList M(string? clientId, string personId) + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Enable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Data.Common; + using System.Linq; + using Microsoft.EntityFrameworkCore; + using Microsoft.EntityFrameworkCore.Storage; + + partial class C + { + public partial IList M(string? clientId, string personId) + { + var connection = this.context.Database.GetDbConnection(); + using var command = connection.CreateCommand(); + + var clientIdParameter = command.CreateParameter(); + clientIdParameter.ParameterName = ""@client_id""; + clientIdParameter.Value = clientId == null ? (object)DBNull.Value : clientId; + + var personIdParameter = command.CreateParameter(); + personIdParameter.ParameterName = ""@person_id""; + personIdParameter.Value = personId; + + var parameters = new DbParameter[] + { + clientIdParameter, + personIdParameter, + }; + + var sqlQuery = @""sp_TestSP @client_id, @person_id""; + var result = this.context.Persons!.FromSqlRaw(sqlQuery, parameters).ToList(); + return result; + } + } +}"; + Assert.AreEqual(expectedOutput, output); + } + [TestMethod] public void NonReferenceParameterPassedDirectlyToStoredProcedure() { diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs index 25fc021..ce0668d 100644 --- a/SqlMarshal/Generator.cs +++ b/SqlMarshal/Generator.cs @@ -112,7 +112,7 @@ private static string GetAccessibility(Accessibility a) }; } - private static ISymbol? GetDbSetField(IFieldSymbol? dbContextSymbol, ITypeSymbol itemTypeSymbol) + private static IPropertySymbol? GetDbSetField(IFieldSymbol? dbContextSymbol, ITypeSymbol itemTypeSymbol) { if (dbContextSymbol == null) { @@ -695,17 +695,19 @@ private void MapResults( { var dbContextSymbol = methodGenerationContext.ClassGenerationContext.DbContextField; var contextName = methodGenerationContext.ClassGenerationContext.DbContextName; - var itemTypeProperty = GetDbSetField(dbContextSymbol, itemType)?.Name ?? itemType.Name + "s"; + var dbsetField = GetDbSetField(dbContextSymbol, itemType); + var itemTypeProperty = dbsetField?.Name ?? itemType.Name + "s"; + var nullableAnnotations = dbsetField?.NullableAnnotation == NullableAnnotation.Annotated && methodGenerationContext.ClassGenerationContext.NullableContextOptions.AnnotationsEnabled() ? "!" : string.Empty; if (isTask) { if (isList) { - source.AppendLine($"var result = await this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).ToListAsync({cancellationToken}).ConfigureAwait(false);"); + source.AppendLine($"var result = await this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).ToListAsync({cancellationToken}).ConfigureAwait(false);"); } else { source.AppendLine($"{itemType} result = null!;"); - source.AppendLine($"var asyncEnumerable = this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).AsAsyncEnumerable();"); + source.AppendLine($"var asyncEnumerable = this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).AsAsyncEnumerable();"); source.AppendLine($"await foreach (var current in asyncEnumerable)"); source.AppendLine("{"); source.PushIndent(); @@ -721,7 +723,7 @@ private void MapResults( string materializeResults = isList ? "ToList" : methodGenerationContext.ClassGenerationContext.NullableContextOptions == NullableContextOptions.Enable ? "AsEnumerable().First" : "AsEnumerable().FirstOrDefault"; - source.AppendLine($"var result = this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{parameterString}).{materializeResults}();"); + source.AppendLine($"var result = this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{parameterString}).{materializeResults}();"); } } }