feat(data): rewrite all repositories to use EF Core ClaudeDoDbContext
Replace raw ADO.NET implementations with EF Core LINQ queries and ExecuteUpdate/ExecuteDelete for bulk operations. TaskRepository preserves FlipAllRunningToFailedAsync(reason) signature and keeps raw SQL for the atomic queue claim (UPDATE...RETURNING). GetByListAsync alias kept for backwards compat. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,171 +1,146 @@
|
||||
using ClaudeDo.Data.Models;
|
||||
using Microsoft.Data.Sqlite;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using TaskStatus = ClaudeDo.Data.Models.TaskStatus;
|
||||
|
||||
namespace ClaudeDo.Data.Repositories;
|
||||
|
||||
public sealed class TaskRepository
|
||||
{
|
||||
private readonly SqliteConnectionFactory _factory;
|
||||
private readonly ClaudeDoDbContext _context;
|
||||
|
||||
public TaskRepository(SqliteConnectionFactory factory) => _factory = factory;
|
||||
|
||||
#region Status mapping
|
||||
|
||||
private static string ToDb(TaskStatus s) => s switch
|
||||
{
|
||||
TaskStatus.Manual => "manual",
|
||||
TaskStatus.Queued => "queued",
|
||||
TaskStatus.Running => "running",
|
||||
TaskStatus.Done => "done",
|
||||
TaskStatus.Failed => "failed",
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(s)),
|
||||
};
|
||||
|
||||
private static TaskStatus FromDb(string s) => s switch
|
||||
{
|
||||
"manual" => TaskStatus.Manual,
|
||||
"queued" => TaskStatus.Queued,
|
||||
"running" => TaskStatus.Running,
|
||||
"done" => TaskStatus.Done,
|
||||
"failed" => TaskStatus.Failed,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(s)),
|
||||
};
|
||||
|
||||
#endregion
|
||||
public TaskRepository(ClaudeDoDbContext context) => _context = context;
|
||||
|
||||
#region CRUD
|
||||
|
||||
public async Task AddAsync(TaskEntity entity, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
INSERT INTO tasks (id, list_id, title, description, status, scheduled_for,
|
||||
result, log_path, created_at, started_at, finished_at, commit_type,
|
||||
model, system_prompt, agent_path)
|
||||
VALUES (@id, @list_id, @title, @description, @status, @scheduled_for,
|
||||
@result, @log_path, @created_at, @started_at, @finished_at, @commit_type,
|
||||
@model, @system_prompt, @agent_path)
|
||||
""";
|
||||
BindTask(cmd, entity);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
_context.Tasks.Add(entity);
|
||||
await _context.SaveChangesAsync(ct);
|
||||
}
|
||||
|
||||
public async Task UpdateAsync(TaskEntity entity, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
UPDATE tasks SET list_id = @list_id, title = @title, description = @description,
|
||||
status = @status, scheduled_for = @scheduled_for, result = @result,
|
||||
log_path = @log_path, started_at = @started_at,
|
||||
finished_at = @finished_at, commit_type = @commit_type,
|
||||
model = @model, system_prompt = @system_prompt, agent_path = @agent_path
|
||||
WHERE id = @id
|
||||
""";
|
||||
BindTask(cmd, entity);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
_context.Tasks.Update(entity);
|
||||
await _context.SaveChangesAsync(ct);
|
||||
}
|
||||
|
||||
public async Task DeleteAsync(string taskId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "DELETE FROM tasks WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
await _context.Tasks.Where(t => t.Id == taskId).ExecuteDeleteAsync(ct);
|
||||
}
|
||||
|
||||
public async Task<TaskEntity?> GetByIdAsync(string taskId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "SELECT id, list_id, title, description, status, scheduled_for, result, log_path, created_at, started_at, finished_at, commit_type, model, system_prompt, agent_path FROM tasks WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
|
||||
await using var reader = await cmd.ExecuteReaderAsync(ct);
|
||||
if (!await reader.ReadAsync(ct)) return null;
|
||||
return ReadTask(reader);
|
||||
return await _context.Tasks.AsNoTracking().FirstOrDefaultAsync(t => t.Id == taskId, ct);
|
||||
}
|
||||
|
||||
public async Task<List<TaskEntity>> GetByListAsync(string listId, CancellationToken ct = default)
|
||||
public async Task<List<TaskEntity>> GetByListIdAsync(string listId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "SELECT id, list_id, title, description, status, scheduled_for, result, log_path, created_at, started_at, finished_at, commit_type, model, system_prompt, agent_path FROM tasks WHERE list_id = @list_id ORDER BY created_at";
|
||||
cmd.Parameters.AddWithValue("@list_id", listId);
|
||||
return await _context.Tasks
|
||||
.Where(t => t.ListId == listId)
|
||||
.OrderByDescending(t => t.CreatedAt)
|
||||
.ToListAsync(ct);
|
||||
}
|
||||
|
||||
await using var reader = await cmd.ExecuteReaderAsync(ct);
|
||||
var result = new List<TaskEntity>();
|
||||
while (await reader.ReadAsync(ct))
|
||||
result.Add(ReadTask(reader));
|
||||
return result;
|
||||
// Kept for backwards-compatibility with callers using the old name.
|
||||
public Task<List<TaskEntity>> GetByListAsync(string listId, CancellationToken ct = default)
|
||||
=> GetByListIdAsync(listId, ct);
|
||||
|
||||
#endregion
|
||||
|
||||
#region Status transitions
|
||||
|
||||
public async Task MarkRunningAsync(string taskId, DateTime startedAt, CancellationToken ct = default)
|
||||
{
|
||||
await _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.ExecuteUpdateAsync(s => s
|
||||
.SetProperty(t => t.Status, TaskStatus.Running)
|
||||
.SetProperty(t => t.StartedAt, startedAt), ct);
|
||||
}
|
||||
|
||||
public async Task MarkDoneAsync(string taskId, DateTime finishedAt, string? result, CancellationToken ct = default)
|
||||
{
|
||||
await _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.ExecuteUpdateAsync(s => s
|
||||
.SetProperty(t => t.Status, TaskStatus.Done)
|
||||
.SetProperty(t => t.FinishedAt, finishedAt)
|
||||
.SetProperty(t => t.Result, result), ct);
|
||||
}
|
||||
|
||||
public async Task MarkFailedAsync(string taskId, DateTime finishedAt, string? result, CancellationToken ct = default)
|
||||
{
|
||||
await _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.ExecuteUpdateAsync(s => s
|
||||
.SetProperty(t => t.Status, TaskStatus.Failed)
|
||||
.SetProperty(t => t.FinishedAt, finishedAt)
|
||||
.SetProperty(t => t.Result, result), ct);
|
||||
}
|
||||
|
||||
public async Task SetLogPathAsync(string taskId, string logPath, CancellationToken ct = default)
|
||||
{
|
||||
await _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.ExecuteUpdateAsync(s => s.SetProperty(t => t.LogPath, logPath), ct);
|
||||
}
|
||||
|
||||
public async Task<int> FlipAllRunningToFailedAsync(string reason, CancellationToken ct = default)
|
||||
{
|
||||
var resultText = "[stale] " + reason;
|
||||
var now = DateTime.UtcNow;
|
||||
return await _context.Tasks
|
||||
.Where(t => t.Status == TaskStatus.Running)
|
||||
.ExecuteUpdateAsync(s => s
|
||||
.SetProperty(t => t.Status, TaskStatus.Failed)
|
||||
.SetProperty(t => t.FinishedAt, now)
|
||||
.SetProperty(t => t.Result, resultText), ct);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Tag junction
|
||||
|
||||
public async Task<List<TagEntity>> GetTagsAsync(string taskId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
SELECT t.id, t.name FROM tags t
|
||||
JOIN task_tags tt ON tt.tag_id = t.id
|
||||
WHERE tt.task_id = @task_id
|
||||
""";
|
||||
cmd.Parameters.AddWithValue("@task_id", taskId);
|
||||
|
||||
await using var reader = await cmd.ExecuteReaderAsync(ct);
|
||||
var result = new List<TagEntity>();
|
||||
while (await reader.ReadAsync(ct))
|
||||
result.Add(new TagEntity { Id = reader.GetInt64(0), Name = reader.GetString(1) });
|
||||
return result;
|
||||
}
|
||||
#region Tags
|
||||
|
||||
public async Task AddTagAsync(string taskId, long tagId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "INSERT OR IGNORE INTO task_tags (task_id, tag_id) VALUES (@task_id, @tag_id)";
|
||||
cmd.Parameters.AddWithValue("@task_id", taskId);
|
||||
cmd.Parameters.AddWithValue("@tag_id", tagId);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
var task = await _context.Tasks.Include(t => t.Tags).FirstAsync(t => t.Id == taskId, ct);
|
||||
var tag = await _context.Tags.FindAsync([tagId], ct);
|
||||
if (tag is not null && !task.Tags.Any(t => t.Id == tagId))
|
||||
{
|
||||
task.Tags.Add(tag);
|
||||
await _context.SaveChangesAsync(ct);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task RemoveTagAsync(string taskId, long tagId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "DELETE FROM task_tags WHERE task_id = @task_id AND tag_id = @tag_id";
|
||||
cmd.Parameters.AddWithValue("@task_id", taskId);
|
||||
cmd.Parameters.AddWithValue("@tag_id", tagId);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
var task = await _context.Tasks.Include(t => t.Tags).FirstAsync(t => t.Id == taskId, ct);
|
||||
var tag = task.Tags.FirstOrDefault(t => t.Id == tagId);
|
||||
if (tag is not null)
|
||||
{
|
||||
task.Tags.Remove(tag);
|
||||
await _context.SaveChangesAsync(ct);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<List<TagEntity>> GetTagsAsync(string taskId, CancellationToken ct = default)
|
||||
{
|
||||
return await _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.SelectMany(t => t.Tags)
|
||||
.ToListAsync(ct);
|
||||
}
|
||||
|
||||
public async Task<List<TagEntity>> GetEffectiveTagsAsync(string taskId, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
SELECT DISTINCT t.id, t.name FROM tags t
|
||||
WHERE t.id IN (
|
||||
SELECT tag_id FROM task_tags WHERE task_id = @task_id
|
||||
UNION
|
||||
SELECT lt.tag_id FROM list_tags lt
|
||||
JOIN tasks tk ON tk.list_id = lt.list_id
|
||||
WHERE tk.id = @task_id
|
||||
)
|
||||
""";
|
||||
cmd.Parameters.AddWithValue("@task_id", taskId);
|
||||
|
||||
await using var reader = await cmd.ExecuteReaderAsync(ct);
|
||||
var result = new List<TagEntity>();
|
||||
while (await reader.ReadAsync(ct))
|
||||
result.Add(new TagEntity { Id = reader.GetInt64(0), Name = reader.GetString(1) });
|
||||
return result;
|
||||
var taskTags = _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.SelectMany(t => t.Tags);
|
||||
var listTags = _context.Tasks
|
||||
.Where(t => t.Id == taskId)
|
||||
.SelectMany(t => t.List.Tags);
|
||||
return await taskTags.Union(listTags).Distinct().ToListAsync(ct);
|
||||
}
|
||||
|
||||
#endregion
|
||||
@@ -174,146 +149,38 @@ public sealed class TaskRepository
|
||||
|
||||
public async Task<TaskEntity?> GetNextQueuedAgentTaskAsync(DateTime now, CancellationToken ct = default)
|
||||
{
|
||||
// Atomically claim the next queued agent task: the UPDATE flips its
|
||||
// status to 'running' in the same statement that returns its row,
|
||||
// eliminating the TOCTOU gap where two queue-loop iterations could
|
||||
// both select the same queued task before either marked it running.
|
||||
// The caller is responsible for populating started_at shortly after.
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
UPDATE tasks
|
||||
SET status = 'running'
|
||||
// Atomic queue claim: UPDATE + RETURNING in one statement prevents TOCTOU races.
|
||||
// Uses raw SQL because EF cannot express UPDATE...RETURNING.
|
||||
// Includes both task-level and list-level "agent" tag so lists tagged "agent"
|
||||
// automatically enqueue all their tasks without per-task tagging.
|
||||
// EF SQLite stores DateTime as "yyyy-MM-dd HH:mm:ss.fffffff" — use the same format for comparison.
|
||||
var nowStr = now.ToUniversalTime().ToString("yyyy-MM-dd HH:mm:ss.fffffff");
|
||||
var result = await _context.Tasks.FromSqlRaw("""
|
||||
UPDATE tasks SET status = 'running'
|
||||
WHERE id = (
|
||||
SELECT t.id
|
||||
FROM tasks t
|
||||
SELECT t.id FROM tasks t
|
||||
WHERE t.status = 'queued'
|
||||
AND (t.scheduled_for IS NULL OR t.scheduled_for <= @now)
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM task_tags tt
|
||||
JOIN tags tg ON tg.id = tt.tag_id
|
||||
WHERE tt.task_id = t.id AND tg.name = 'agent'
|
||||
UNION
|
||||
SELECT 1 FROM list_tags lt
|
||||
JOIN tags tg ON tg.id = lt.tag_id
|
||||
WHERE lt.list_id = t.list_id AND tg.name = 'agent'
|
||||
AND (t.scheduled_for IS NULL OR t.scheduled_for <= {0})
|
||||
AND (
|
||||
EXISTS (
|
||||
SELECT 1 FROM task_tags tt
|
||||
JOIN tags tg ON tg.id = tt.tag_id
|
||||
WHERE tt.task_id = t.id AND tg.name = 'agent'
|
||||
)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM list_tags lt
|
||||
JOIN tags tg ON tg.id = lt.tag_id
|
||||
WHERE lt.list_id = t.list_id AND tg.name = 'agent'
|
||||
)
|
||||
)
|
||||
ORDER BY t.created_at ASC
|
||||
LIMIT 1
|
||||
)
|
||||
RETURNING id, list_id, title, description, status, scheduled_for,
|
||||
result, log_path, created_at, started_at, finished_at, commit_type,
|
||||
model, system_prompt, agent_path
|
||||
""";
|
||||
cmd.Parameters.AddWithValue("@now", now.ToString("o"));
|
||||
RETURNING *
|
||||
""", nowStr).ToListAsync(ct);
|
||||
|
||||
await using var reader = await cmd.ExecuteReaderAsync(ct);
|
||||
if (!await reader.ReadAsync(ct)) return null;
|
||||
return ReadTask(reader);
|
||||
return result.FirstOrDefault();
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Transitions
|
||||
|
||||
public async Task SetLogPathAsync(string taskId, string logPath, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "UPDATE tasks SET log_path = @log_path WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
cmd.Parameters.AddWithValue("@log_path", logPath);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
}
|
||||
|
||||
public async Task MarkRunningAsync(string taskId, DateTime startedAt, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "UPDATE tasks SET status = 'running', started_at = @started_at WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
cmd.Parameters.AddWithValue("@started_at", startedAt.ToString("o"));
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
}
|
||||
|
||||
public async Task MarkDoneAsync(string taskId, DateTime finishedAt, string? result, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "UPDATE tasks SET status = 'done', finished_at = @finished_at, result = @result WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
cmd.Parameters.AddWithValue("@finished_at", finishedAt.ToString("o"));
|
||||
cmd.Parameters.AddWithValue("@result", (object?)result ?? DBNull.Value);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
}
|
||||
|
||||
public async Task MarkFailedAsync(string taskId, DateTime finishedAt, string? errorMarkdown, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = "UPDATE tasks SET status = 'failed', finished_at = @finished_at, result = @result WHERE id = @id";
|
||||
cmd.Parameters.AddWithValue("@id", taskId);
|
||||
cmd.Parameters.AddWithValue("@finished_at", finishedAt.ToString("o"));
|
||||
cmd.Parameters.AddWithValue("@result", (object?)errorMarkdown ?? DBNull.Value);
|
||||
await cmd.ExecuteNonQueryAsync(ct);
|
||||
}
|
||||
|
||||
public async Task<int> FlipAllRunningToFailedAsync(string reason, CancellationToken ct = default)
|
||||
{
|
||||
await using var conn = _factory.Open();
|
||||
await using var cmd = conn.CreateCommand();
|
||||
cmd.CommandText = """
|
||||
UPDATE tasks SET status = 'failed',
|
||||
finished_at = @now,
|
||||
result = '[stale] ' || @reason
|
||||
WHERE status = 'running'
|
||||
""";
|
||||
cmd.Parameters.AddWithValue("@now", DateTime.UtcNow.ToString("o"));
|
||||
cmd.Parameters.AddWithValue("@reason", reason);
|
||||
return await cmd.ExecuteNonQueryAsync(ct);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helpers
|
||||
|
||||
private static void BindTask(SqliteCommand cmd, TaskEntity e)
|
||||
{
|
||||
cmd.Parameters.AddWithValue("@id", e.Id);
|
||||
cmd.Parameters.AddWithValue("@list_id", e.ListId);
|
||||
cmd.Parameters.AddWithValue("@title", e.Title);
|
||||
cmd.Parameters.AddWithValue("@description", (object?)e.Description ?? DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@status", ToDb(e.Status));
|
||||
cmd.Parameters.AddWithValue("@scheduled_for", e.ScheduledFor.HasValue ? e.ScheduledFor.Value.ToString("o") : DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@result", (object?)e.Result ?? DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@log_path", (object?)e.LogPath ?? DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@created_at", e.CreatedAt.ToString("o"));
|
||||
cmd.Parameters.AddWithValue("@started_at", e.StartedAt.HasValue ? e.StartedAt.Value.ToString("o") : DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@finished_at", e.FinishedAt.HasValue ? e.FinishedAt.Value.ToString("o") : DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@commit_type", e.CommitType);
|
||||
cmd.Parameters.AddWithValue("@model", (object?)e.Model ?? DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@system_prompt", (object?)e.SystemPrompt ?? DBNull.Value);
|
||||
cmd.Parameters.AddWithValue("@agent_path", (object?)e.AgentPath ?? DBNull.Value);
|
||||
}
|
||||
|
||||
private static TaskEntity ReadTask(SqliteDataReader r) => new()
|
||||
{
|
||||
Id = r.GetString(0),
|
||||
ListId = r.GetString(1),
|
||||
Title = r.GetString(2),
|
||||
Description = r.IsDBNull(3) ? null : r.GetString(3),
|
||||
Status = FromDb(r.GetString(4)),
|
||||
ScheduledFor = r.IsDBNull(5) ? null : DateTime.Parse(r.GetString(5)),
|
||||
Result = r.IsDBNull(6) ? null : r.GetString(6),
|
||||
LogPath = r.IsDBNull(7) ? null : r.GetString(7),
|
||||
CreatedAt = DateTime.Parse(r.GetString(8)),
|
||||
StartedAt = r.IsDBNull(9) ? null : DateTime.Parse(r.GetString(9)),
|
||||
FinishedAt = r.IsDBNull(10) ? null : DateTime.Parse(r.GetString(10)),
|
||||
CommitType = r.GetString(11),
|
||||
Model = r.IsDBNull(12) ? null : r.GetString(12),
|
||||
SystemPrompt = r.IsDBNull(13) ? null : r.GetString(13),
|
||||
AgentPath = r.IsDBNull(14) ? null : r.GetString(14),
|
||||
};
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user