diff --git a/SDMeta/Cache/SqliteDataSource.cs b/SDMeta/Cache/SqliteDataSource.cs index cee6beb..0c52edf 100644 --- a/SDMeta/Cache/SqliteDataSource.cs +++ b/SDMeta/Cache/SqliteDataSource.cs @@ -1,9 +1,10 @@ -using Dapper; +using Dapper; using Microsoft.Data.Sqlite; using Microsoft.Extensions.Logging; using System; using System.Collections.Generic; using System.Linq; +using System.Threading; namespace SDMeta.Cache { @@ -12,6 +13,7 @@ public partial class SqliteDataSource : IImageFileDataSource const string TableName = "PngFilesv2"; private string FTSTableName = $"FTS5{TableName}"; private SqliteTransaction? transaction; + private readonly Lock transactionLock = new(); private readonly string[] columns = [ @@ -71,17 +73,19 @@ private SqliteConnection GetConnection() return connection; } - private T ExecuteOnConnection(Func func) + private T ExecuteOnConnection(Func func) { - if (this.transaction?.Connection != null) + lock (transactionLock) { - return func(transaction.Connection); - } - else - { - using var connection = GetConnection(); - return func(connection); + var currentTransaction = this.transaction; + if (currentTransaction?.Connection != null) + { + return func(currentTransaction.Connection, currentTransaction); + } } + + using var connection = GetConnection(); + return func(connection, null); } private string GetInsertSql() @@ -102,11 +106,11 @@ public void Initialize() var tabledef = GetTableDefinition(); // Setup table if absent https://learn.microsoft.com/en-us/dotnet/standard/data/sqlite/types - ExecuteOnConnection(connection => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} ( + ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} ( {tabledef.Select(p => $"{p.Column} {p.DataType}{(p.IsPrimaryKey ? " PRIMARY KEY" : "")}").ToCommaSeparated()} );")); - ExecuteOnConnection(connection => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});")); + ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});")); logger.LogInformation("Initalization completed"); } @@ -135,7 +139,7 @@ public IEnumerable Query(QueryParams queryParams) modelHash = queryParams.ModelFilter?.ModelHash, }; - var reader = ExecuteOnConnection(connection => + var reader = ExecuteOnConnection((connection, _) => connection.Query(sql, param) ); return reader; @@ -212,7 +216,7 @@ private static string BuildOrderByClause(QuerySortBy querySort) public ImageFile? ReadImageFile(string realFileName) { - var reader = ExecuteOnConnection(connection => connection.QueryFirstOrDefault( + var reader = ExecuteOnConnection((connection, _) => connection.QueryFirstOrDefault( $@"SELECT * FROM {TableName} WHERE FileName = @FileName @@ -230,10 +234,10 @@ private static string BuildOrderByClause(QuerySortBy querySort) public void WriteImageFile(ImageFile info) { - ExecuteOnConnection(connection => connection.Execute( + ExecuteOnConnection((connection, tx) => connection.Execute( insertSql.Value, FromModel(info), - this.transaction + tx )); } @@ -257,25 +261,42 @@ private DataRow FromModel(ImageFile info) public void BeginTransaction() { - this.transaction ??= GetConnection().BeginTransaction(); + lock (transactionLock) + { + this.transaction ??= GetConnection().BeginTransaction(); + } } public void CommitTransaction() { - if (this.transaction != null) + SqliteTransaction? transactionToCommit; + + lock (transactionLock) { - var connection = this.transaction.Connection; - this.transaction.Commit(); - this.transaction.Dispose(); + transactionToCommit = this.transaction; this.transaction = null; - connection?.Close(); + } + + if (transactionToCommit == null) + { + return; + } + + try + { + transactionToCommit.Commit(); + } + finally + { + var connection = transactionToCommit.Connection; + transactionToCommit.Dispose(); connection?.Dispose(); } } public IEnumerable GetModelSummaryList() { - var reader = ExecuteOnConnection(connection => connection.Query( + var reader = ExecuteOnConnection((connection, _) => connection.Query( $@"SELECT Model, ModelHash, Count(*) as Count FROM {TableName} GROUP BY Model, ModelHash @@ -287,7 +308,7 @@ ORDER BY 3 DESC" public IEnumerable GetAllFilenames() { - var reader = ExecuteOnConnection(connection => connection.Query( + var reader = ExecuteOnConnection((connection, _) => connection.Query( $@"SELECT Filename FROM {TableName} WHERE [Exists] = 1" @@ -298,20 +319,20 @@ public IEnumerable GetAllFilenames() public void Truncate() { - ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {TableName}")); - ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {FTSTableName}")); + ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {TableName}")); + ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {FTSTableName}")); } public void PostUpdateProcessing() { - ExecuteOnConnection(connection => + ExecuteOnConnection((connection, tx) => connection.Execute( $@"INSERT INTO {FTSTableName} (FileName, Prompt, PromptFormat, Version) SELECT FileName, Prompt, PromptFormat, Version FROM {TableName} WHERE FileName NOT IN (SELECT FileName from {FTSTableName})", - this.transaction)); + tx)); - ExecuteOnConnection(connection => + ExecuteOnConnection((connection, tx) => connection.Execute( $@"UPDATE {FTSTableName} SET Prompt = p.Prompt, @@ -319,7 +340,7 @@ WHERE FileName NOT IN (SELECT FileName from {FTSTableName})", Version = p.Version FROM {TableName} p WHERE {FTSTableName}.FileName = p.FileName and {FTSTableName}.Version != p.Version", - this.transaction)); + tx)); } } @@ -333,3 +354,4 @@ public static string ToCommaSeparated(this IEnumerable list) internal record struct ColumnDefinition(string Column, string Parameter, string DataType, bool IsPrimaryKey); } +