Skip to content

Commit e296d82

Browse files
authored
AddData and RemoveData for DbContext (#482)
1 parent 54b35b1 commit e296d82

5 files changed

Lines changed: 77 additions & 48 deletions

File tree

src/Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<Project>
33
<PropertyGroup>
44
<NoWarn>CS1591;CA1416;CS8632</NoWarn>
5-
<Version>15.3.0</Version>
5+
<Version>15.4.0</Version>
66
<AssemblyVersion>1.0.0</AssemblyVersion>
77
<ContinuousIntegrationBuild>false</ContinuousIntegrationBuild>
88
<CheckEolTargetFramework>false</CheckEolTargetFramework>
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
namespace EfLocalDb;
2+
3+
public static partial class DbContextExtensions
4+
{
5+
internal static IEnumerable<object> ExpandEnumerable(IEnumerable<object> entities, IReadOnlyList<IEntityType> entityTypes)
6+
{
7+
foreach (var entity in entities)
8+
{
9+
if (entity is IEnumerable enumerable)
10+
{
11+
var entityType = entity.GetType();
12+
if (entityTypes.Any(_ => _.ClrType != entityType))
13+
{
14+
foreach (var nested in enumerable)
15+
{
16+
yield return nested;
17+
}
18+
19+
continue;
20+
}
21+
}
22+
23+
yield return entity;
24+
}
25+
}
26+
27+
public static Task AddData<TDbContext>(this TDbContext context, IEnumerable<object> entities)
28+
where TDbContext : DbContext =>
29+
context.AddData(entities, context.Model.GetEntityTypes().ToArray());
30+
31+
public static Task AddData<TDbContext>(this TDbContext context, params object[] entities)
32+
where TDbContext : DbContext =>
33+
context.AddData((IEnumerable<object>) entities);
34+
35+
internal static Task AddData<TDbContext>(this TDbContext context, IEnumerable<object> entities, IReadOnlyList<IEntityType> entityTypes)
36+
where TDbContext : DbContext
37+
{
38+
foreach (var entity in ExpandEnumerable(entities, entityTypes))
39+
{
40+
context.Add(entity);
41+
}
42+
43+
return context.SaveChangesAsync();
44+
}
45+
46+
public static Task RemoveData<TDbContext>(this TDbContext context, IEnumerable<object> entities)
47+
where TDbContext : DbContext =>
48+
context.RemoveData(entities, context.Model.GetEntityTypes().ToArray());
49+
50+
public static Task RemoveData<TDbContext>(this TDbContext context, params object[] entities)
51+
where TDbContext : DbContext =>
52+
context.RemoveData((IEnumerable<object>) entities);
53+
54+
internal static Task RemoveData<TDbContext>(this TDbContext context, IEnumerable<object> entities, IReadOnlyList<IEntityType> entityTypes)
55+
where TDbContext : DbContext
56+
{
57+
foreach (var entity in ExpandEnumerable(entities, entityTypes))
58+
{
59+
context.Remove(entity);
60+
}
61+
62+
return context.SaveChangesAsync();
63+
}
64+
}

src/EfLocalDb/SqlDatabase.cs

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -141,25 +141,6 @@ public async Task Delete()
141141
public DbSet<T> Set<T>()
142142
where T : class => NoTrackingContext.Set<T>();
143143

144-
IEnumerable<object> ExpandEnumerable(IEnumerable<object> entities)
145-
{
146-
foreach (var entity in entities)
147-
{
148-
if (entity is IEnumerable enumerable)
149-
{
150-
var entityType = entity.GetType();
151-
if (instance.EntityTypes.Any(_ => _.ClrType != entityType))
152-
{
153-
foreach (var nested in enumerable)
154-
{
155-
yield return nested;
156-
}
157-
158-
continue;
159-
}
160-
}
161-
162-
yield return entity;
163-
}
164-
}
144+
IEnumerable<object> ExpandEnumerable(IEnumerable<object> entities) =>
145+
DbContextExtensions.ExpandEnumerable(entities, instance.EntityTypes);
165146
}

src/EfLocalDb/SqlDatabase_Add.cs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,16 @@
22

33
public partial class SqlDatabase<TDbContext>
44
{
5-
public Task AddData(IEnumerable<object> entities) => Add(entities, Context);
5+
public Task AddData(IEnumerable<object> entities) =>
6+
Context.AddData(entities, instance.EntityTypes);
67

7-
Task Add(IEnumerable<object> entities, TDbContext context)
8-
{
9-
foreach (var entity in ExpandEnumerable(entities))
10-
{
11-
context.Add(entity);
12-
}
13-
14-
return context.SaveChangesAsync();
15-
}
16-
17-
public Task AddData(params object[] entities) => AddData((IEnumerable<object>) entities);
8+
public Task AddData(params object[] entities) =>
9+
AddData((IEnumerable<object>) entities);
1810

1911
public async Task AddDataUntracked(IEnumerable<object> entities)
2012
{
2113
await using var context = NewDbContext();
22-
await Add(entities, context);
14+
await context.AddData(entities, instance.EntityTypes);
2315
}
2416

2517
public Task AddDataUntracked(params object[] entities) => AddDataUntracked((IEnumerable<object>) entities);

src/EfLocalDb/SqlDatabase_Remove.cs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,16 @@
22

33
public partial class SqlDatabase<TDbContext>
44
{
5-
public Task RemoveData(IEnumerable<object> entities) => Remove(entities, Context);
5+
public Task RemoveData(IEnumerable<object> entities) =>
6+
Context.RemoveData(entities, instance.EntityTypes);
67

7-
Task Remove(IEnumerable<object> entities, TDbContext context)
8-
{
9-
foreach (var entity in ExpandEnumerable(entities))
10-
{
11-
context.Remove(entity);
12-
}
13-
14-
return context.SaveChangesAsync();
15-
}
16-
17-
public Task RemoveData(params object[] entities) => RemoveData((IEnumerable<object>) entities);
8+
public Task RemoveData(params object[] entities) =>
9+
RemoveData((IEnumerable<object>) entities);
1810

1911
public async Task RemoveDataUntracked(IEnumerable<object> entities)
2012
{
2113
await using var context = NewDbContext();
22-
await Remove(entities, context);
14+
await context.RemoveData(entities, instance.EntityTypes);
2315
}
2416

2517
public Task RemoveDataUntracked(params object[] entities) => RemoveDataUntracked((IEnumerable<object>) entities);

0 commit comments

Comments
 (0)