diff --git a/.gitignore b/.gitignore index 555f10f..7f72316 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ pip-log.txt # Mac crap .DS_Store packages +.vs/ diff --git a/EntityFramework.Utilities/EntityFramework.Utilities/EFBatchOperation.cs b/EntityFramework.Utilities/EntityFramework.Utilities/EFBatchOperation.cs index 72fa98c..89961ec 100644 --- a/EntityFramework.Utilities/EntityFramework.Utilities/EFBatchOperation.cs +++ b/EntityFramework.Utilities/EntityFramework.Utilities/EFBatchOperation.cs @@ -54,8 +54,8 @@ public UpdateSpecification ColumnsToUpdate(params Expression> public interface IEFBatchOperationFiltered { - int Delete(); - int Update(Expression> prop, Expression> modifier); + int Delete(DbConnection connection = null); + int Update(Expression> prop, Expression> modifier, DbConnection connection = null); } public static class EFBatchOperation { @@ -181,16 +181,17 @@ public IEFBatchOperationFiltered Where(Expression> pr return this; } - public int Delete() + public int Delete(DbConnection connection = null) { var con = context.Connection as EntityConnection; - if (con == null) + if (con == null && connection == null) { Configuration.Log("No provider could be found because the Connection didn't implement System.Data.EntityClient.EntityConnection"); return Fallbacks.DefaultDelete(context, this.predicate); } + var connectionToUse = connection ?? con.StoreConnection; - var provider = Configuration.Providers.FirstOrDefault(p => p.CanHandle(con.StoreConnection)); + var provider = Configuration.Providers.FirstOrDefault(p => p.CanHandle(connectionToUse)); if (provider != null && provider.CanDelete) { var set = context.CreateObjectSet(); @@ -203,21 +204,22 @@ public int Delete() } else { - Configuration.Log("Found provider: " + (provider == null ? "[]" : provider.GetType().Name ) + " for " + con.StoreConnection.GetType().Name); + Configuration.Log("Found provider: " + (provider == null ? "[]" : provider.GetType().Name) + " for " + connectionToUse.GetType().Name); return Fallbacks.DefaultDelete(context, this.predicate); } } - public int Update(Expression> prop, Expression> modifier) + public int Update(Expression> prop, Expression> modifier, DbConnection connection = null) { var con = context.Connection as EntityConnection; - if (con == null) + if (con == null && connection == null) { Configuration.Log("No provider could be found because the Connection didn't implement System.Data.EntityClient.EntityConnection"); return Fallbacks.DefaultUpdate(context, this.predicate, prop, modifier); } + var connectionToUse = connection ?? con.StoreConnection; - var provider = Configuration.Providers.FirstOrDefault(p => p.CanHandle(con.StoreConnection)); + var provider = Configuration.Providers.FirstOrDefault(p => p.CanHandle(connectionToUse)); if (provider != null && provider.CanUpdate) { var set = context.CreateObjectSet(); @@ -241,7 +243,7 @@ public int Update(Expression> prop, Expression> modi } else { - Configuration.Log("Found provider: " + (provider == null ? "[]" : provider.GetType().Name) + " for " + con.StoreConnection.GetType().Name); + Configuration.Log("Found provider: " + (provider == null ? "[]" : provider.GetType().Name) + " for " + connectionToUse.GetType().Name); return Fallbacks.DefaultUpdate(context, this.predicate, prop, modifier); } } diff --git a/EntityFramework.Utilities/EntityFramework.Utilities/Properties/AssemblyInfo.cs b/EntityFramework.Utilities/EntityFramework.Utilities/Properties/AssemblyInfo.cs index 08caae2..cb87fb0 100644 --- a/EntityFramework.Utilities/EntityFramework.Utilities/Properties/AssemblyInfo.cs +++ b/EntityFramework.Utilities/EntityFramework.Utilities/Properties/AssemblyInfo.cs @@ -6,7 +6,7 @@ // set of attributes. Change these attribute values to modify the information // associated with an assembly. [assembly: AssemblyTitle("EntityFramework.Utilities")] -[assembly: AssemblyDescription("")] +[assembly: AssemblyDescription("Fork of https://github.com/MikaelEliasson/EntityFramework.Utilities")] [assembly: AssemblyConfiguration("")] [assembly: AssemblyCompany("")] [assembly: AssemblyProduct("EntityFramework.Utilities")] diff --git a/EntityFramework.Utilities/Tests/DeleteByQueryTest.cs b/EntityFramework.Utilities/Tests/DeleteByQueryTest.cs index 46321f0..13911c7 100644 --- a/EntityFramework.Utilities/Tests/DeleteByQueryTest.cs +++ b/EntityFramework.Utilities/Tests/DeleteByQueryTest.cs @@ -192,5 +192,40 @@ public void DeleteAll_NoProvider_UsesDefaultDelete() Assert.IsNotNull(fallbackText); } + + [TestMethod] + public void DeleteAll_PropertyEquals_WithExplicitConnection_DeletesAllMatchesAndNothingElse() + { + using (var db = Context.Sql()) + { + if (db.Database.Exists()) + { + db.Database.Delete(); + } + db.Database.Create(); + + db.BlogPosts.Add(BlogPost.Create("T1")); + db.BlogPosts.Add(BlogPost.Create("T2")); + db.BlogPosts.Add(BlogPost.Create("T2")); + db.BlogPosts.Add(BlogPost.Create("T3")); + + db.SaveChanges(); + } + + int count; + using (var db = Context.Sql()) + { + count = EFBatchOperation.For(db, db.BlogPosts).Where(b => b.Title == "T2").Delete(db.Database.Connection); + Assert.AreEqual(2, count); + } + + using (var db = Context.Sql()) + { + var posts = db.BlogPosts.ToList(); + Assert.AreEqual(2, posts.Count); + Assert.AreEqual(0, posts.Count(p => p.Title == "T2")); + } + } + } } diff --git a/EntityFramework.Utilities/Tests/UpdateByQueryTest.cs b/EntityFramework.Utilities/Tests/UpdateByQueryTest.cs index f968261..f6f354f 100644 --- a/EntityFramework.Utilities/Tests/UpdateByQueryTest.cs +++ b/EntityFramework.Utilities/Tests/UpdateByQueryTest.cs @@ -244,6 +244,25 @@ public void UpdateAll_NoProvider_UsesDefaultDelete() Assert.IsNotNull(fallbackText); } + [TestMethod] + public void UpdateAll_Increment_WithExplicitConnection() + { + SetupBasePosts(); + + int count; + using (var db = Context.Sql()) + { + count = EFBatchOperation.For(db, db.BlogPosts).Where(b => b.Title == "T2").Update(b => b.Reads, b => b.Reads + 5, db.Database.Connection); + Assert.AreEqual(1, count); + } + + using (var db = Context.Sql()) + { + var post = db.BlogPosts.First(p => p.Title == "T2"); + Assert.AreEqual(5, post.Reads); + } + } + private static void SetupBasePosts() { using (var db = Context.Sql())