Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 51 additions & 29 deletions SDMeta/Cache/SqliteDataSource.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using Dapper;
using Dapper;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;

namespace SDMeta.Cache
{
Expand All @@ -12,6 +13,7 @@ public partial class SqliteDataSource : IImageFileDataSource
const string TableName = "PngFilesv2";
private string FTSTableName = $"FTS5{TableName}";
private SqliteTransaction? transaction;
private readonly Lock transactionLock = new();

private readonly string[] columns =
[
Expand Down Expand Up @@ -71,17 +73,19 @@ private SqliteConnection GetConnection()
return connection;
}

private T ExecuteOnConnection<T>(Func<SqliteConnection, T> func)
private T ExecuteOnConnection<T>(Func<SqliteConnection, SqliteTransaction?, T> func)
{
if (this.transaction?.Connection != null)
lock (transactionLock)
{
return func(transaction.Connection);
}
else
{
using var connection = GetConnection();
return func(connection);
var currentTransaction = this.transaction;
if (currentTransaction?.Connection != null)
{
return func(currentTransaction.Connection, currentTransaction);
}
}

using var connection = GetConnection();
return func(connection, null);
}

private string GetInsertSql()
Expand All @@ -102,11 +106,11 @@ public void Initialize()
var tabledef = GetTableDefinition();

// Setup table if absent https://learn.microsoft.com/en-us/dotnet/standard/data/sqlite/types
ExecuteOnConnection(connection => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} (
ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} (
{tabledef.Select(p => $"{p.Column} {p.DataType}{(p.IsPrimaryKey ? " PRIMARY KEY" : "")}").ToCommaSeparated()}
);"));

ExecuteOnConnection(connection => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});"));
ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});"));
logger.LogInformation("Initalization completed");
}

Expand Down Expand Up @@ -135,7 +139,7 @@ public IEnumerable<ImageFileSummary> Query(QueryParams queryParams)
modelHash = queryParams.ModelFilter?.ModelHash,
};

var reader = ExecuteOnConnection(connection =>
var reader = ExecuteOnConnection((connection, _) =>
connection.Query<ImageFileSummary>(sql, param)
);
return reader;
Expand Down Expand Up @@ -212,7 +216,7 @@ private static string BuildOrderByClause(QuerySortBy querySort)

public ImageFile? ReadImageFile(string realFileName)
{
var reader = ExecuteOnConnection(connection => connection.QueryFirstOrDefault<DataRow>(
var reader = ExecuteOnConnection((connection, _) => connection.QueryFirstOrDefault<DataRow>(
$@"SELECT *
FROM {TableName}
WHERE FileName = @FileName
Expand All @@ -230,10 +234,10 @@ private static string BuildOrderByClause(QuerySortBy querySort)

public void WriteImageFile(ImageFile info)
{
ExecuteOnConnection(connection => connection.Execute(
ExecuteOnConnection((connection, tx) => connection.Execute(
insertSql.Value,
FromModel(info),
this.transaction
tx
));
}

Expand All @@ -257,25 +261,42 @@ private DataRow FromModel(ImageFile info)

public void BeginTransaction()
{
this.transaction ??= GetConnection().BeginTransaction();
lock (transactionLock)
{
this.transaction ??= GetConnection().BeginTransaction();
}
}

public void CommitTransaction()
{
if (this.transaction != null)
SqliteTransaction? transactionToCommit;

lock (transactionLock)
{
var connection = this.transaction.Connection;
this.transaction.Commit();
this.transaction.Dispose();
transactionToCommit = this.transaction;
this.transaction = null;
connection?.Close();
}

if (transactionToCommit == null)
{
return;
}

try
{
transactionToCommit.Commit();
}
finally
{
var connection = transactionToCommit.Connection;
transactionToCommit.Dispose();
connection?.Dispose();
}
}

public IEnumerable<ModelSummary> GetModelSummaryList()
{
var reader = ExecuteOnConnection(connection => connection.Query<ModelSummary>(
var reader = ExecuteOnConnection((connection, _) => connection.Query<ModelSummary>(
$@"SELECT Model, ModelHash, Count(*) as Count
FROM {TableName}
GROUP BY Model, ModelHash
Expand All @@ -287,7 +308,7 @@ ORDER BY 3 DESC"

public IEnumerable<string> GetAllFilenames()
{
var reader = ExecuteOnConnection(connection => connection.Query<string>(
var reader = ExecuteOnConnection((connection, _) => connection.Query<string>(
$@"SELECT Filename
FROM {TableName}
WHERE [Exists] = 1"
Expand All @@ -298,28 +319,28 @@ public IEnumerable<string> GetAllFilenames()

public void Truncate()
{
ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {TableName}"));
ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {FTSTableName}"));
ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {TableName}"));
ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {FTSTableName}"));
}

public void PostUpdateProcessing()
{
ExecuteOnConnection(connection =>
ExecuteOnConnection((connection, tx) =>
connection.Execute(
$@"INSERT INTO {FTSTableName} (FileName, Prompt, PromptFormat, Version)
SELECT FileName, Prompt, PromptFormat, Version FROM {TableName}
WHERE FileName NOT IN (SELECT FileName from {FTSTableName})",
this.transaction));
tx));

ExecuteOnConnection(connection =>
ExecuteOnConnection((connection, tx) =>
connection.Execute(
$@"UPDATE {FTSTableName} SET
Prompt = p.Prompt,
PromptFormat = p.PromptFormat,
Version = p.Version
FROM {TableName} p
WHERE {FTSTableName}.FileName = p.FileName and {FTSTableName}.Version != p.Version",
this.transaction));
tx));
}
}

Expand All @@ -333,3 +354,4 @@ public static string ToCommaSeparated(this IEnumerable<string> list)

internal record struct ColumnDefinition(string Column, string Parameter, string DataType, bool IsPrimaryKey);
}

Loading