From 36484ed45a7e32eb1f82f42d70ea42be4af559b7 Mon Sep 17 00:00:00 2001 From: mika kuns Date: Thu, 16 Apr 2026 08:59:24 +0200 Subject: [PATCH] feat(worker,ui): wire EF Core into DI and update all consumers to IDbContextFactory Worker and App Program.cs: replace SqliteConnectionFactory+SchemaInitializer with AddDbContextFactory + Database.Migrate(). Repos changed from AddSingleton to AddScoped. All singleton services (QueueService, StaleTaskRecovery, WorktreeManager, TaskRunner) and singleton ViewModels (MainWindowViewModel, TaskDetailViewModel, TaskListViewModel, TaskEditorViewModel) now take IDbContextFactory and create short-lived contexts per operation. Test infrastructure: DbFixture now uses EF migrations instead of SchemaInitializer; all test classes create contexts via DbFixture.CreateContext(). Co-Authored-By: Claude Sonnet 4.6 --- src/ClaudeDo.App/Program.cs | 37 ++--- .../ViewModels/MainWindowViewModel.cs | 41 ++++-- .../ViewModels/TaskDetailViewModel.cs | 135 +++++++++++++----- .../ViewModels/TaskEditorViewModel.cs | 34 +++-- .../ViewModels/TaskListViewModel.cs | 103 ++++++++----- src/ClaudeDo.Worker/Program.cs | 26 ++-- src/ClaudeDo.Worker/Runner/TaskRunner.cs | 128 +++++++++++------ src/ClaudeDo.Worker/Runner/WorktreeManager.cs | 16 ++- src/ClaudeDo.Worker/Services/QueueService.cs | 23 ++- .../Services/StaleTaskRecovery.cs | 12 +- .../Infrastructure/DbFixture.cs | 24 +++- .../Repositories/ListRepositoryConfigTests.cs | 11 +- .../Repositories/ListRepositoryTests.cs | 13 +- .../Repositories/TaskRepositoryTests.cs | 17 ++- .../Repositories/TaskRunRepositoryTests.cs | 15 +- .../Runner/WorktreeManagerTests.cs | 41 ++++-- .../Services/QueueServiceTests.cs | 20 +-- .../Services/StaleTaskRecoveryTests.cs | 15 +- 18 files changed, 479 insertions(+), 232 deletions(-) diff --git a/src/ClaudeDo.App/Program.cs b/src/ClaudeDo.App/Program.cs index bc67d52..c057463 100644 --- a/src/ClaudeDo.App/Program.cs +++ b/src/ClaudeDo.App/Program.cs @@ -5,6 +5,7 @@ using ClaudeDo.Data.Repositories; using ClaudeDo.Ui; using ClaudeDo.Ui.Services; using ClaudeDo.Ui.ViewModels; +using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using System; @@ -18,9 +19,10 @@ sealed class Program var services = BuildServices(); App.Services = services; - // Ensure DB schema exists - var factory = services.GetRequiredService(); - SchemaInitializer.Apply(factory); + using (var scope = services.CreateScope()) + { + scope.ServiceProvider.GetRequiredService().Database.Migrate(); + } try { @@ -55,14 +57,10 @@ sealed class Program // Infrastructure sc.AddSingleton(settings); - sc.AddSingleton(new SqliteConnectionFactory(dbPath)); - - // Repositories - sc.AddSingleton(); - sc.AddSingleton(); - sc.AddSingleton(); - sc.AddSingleton(); - sc.AddSingleton(); + sc.AddDbContextFactory(opt => + opt.UseSqlite($"Data Source={dbPath}")); + sc.AddScoped(sp => + sp.GetRequiredService>().CreateDbContext()); // Services sc.AddSingleton(); @@ -72,30 +70,21 @@ sealed class Program sc.AddTransient(); sc.AddTransient(); sc.AddSingleton(); - sc.AddSingleton(sp => new TaskDetailViewModel( - sp.GetRequiredService(), - sp.GetRequiredService(), - sp.GetRequiredService(), - sp.GetRequiredService(), - sp.GetRequiredService(), - sp.GetRequiredService(), - sp.GetRequiredService())); + sc.AddSingleton(); sc.AddSingleton(sp => { - var taskRepo = sp.GetRequiredService(); - var tagRepo = sp.GetRequiredService(); - var listRepo = sp.GetRequiredService(); + var dbFactory = sp.GetRequiredService>(); var worker = sp.GetRequiredService(); var statusBar = sp.GetRequiredService(); return new TaskListViewModel( - taskRepo, tagRepo, listRepo, worker, + dbFactory, worker, () => sp.GetRequiredService(), msg => statusBar.ShowMessage(msg)); }); sc.AddSingleton(sp => { return new MainWindowViewModel( - sp.GetRequiredService(), + sp.GetRequiredService>(), sp.GetRequiredService(), sp.GetRequiredService(), sp.GetRequiredService(), diff --git a/src/ClaudeDo.Ui/ViewModels/MainWindowViewModel.cs b/src/ClaudeDo.Ui/ViewModels/MainWindowViewModel.cs index 4309d43..4d2a7e8 100644 --- a/src/ClaudeDo.Ui/ViewModels/MainWindowViewModel.cs +++ b/src/ClaudeDo.Ui/ViewModels/MainWindowViewModel.cs @@ -2,18 +2,20 @@ using System.Collections.ObjectModel; using Avalonia; using Avalonia.Controls; using Avalonia.Controls.ApplicationLifetimes; +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Views; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Ui.ViewModels; public partial class MainWindowViewModel : ViewModelBase { - private readonly ListRepository _listRepo; + private readonly IDbContextFactory _dbFactory; private readonly WorkerClient _worker; private readonly Func _listEditorFactory; @@ -26,14 +28,14 @@ public partial class MainWindowViewModel : ViewModelBase public StatusBarViewModel StatusBar { get; } public MainWindowViewModel( - ListRepository listRepo, + IDbContextFactory dbFactory, WorkerClient worker, TaskListViewModel taskList, TaskDetailViewModel taskDetail, StatusBarViewModel statusBar, Func listEditorFactory) { - _listRepo = listRepo; + _dbFactory = dbFactory; _worker = worker; _listEditorFactory = listEditorFactory; TaskList = taskList; @@ -48,7 +50,9 @@ public partial class MainWindowViewModel : ViewModelBase { try { - var lists = await _listRepo.GetAllAsync(); + using var context = _dbFactory.CreateDbContext(); + var listRepo = new ListRepository(context); + var lists = await listRepo.GetAllAsync(); foreach (var l in lists) Lists.Add(new ListItemViewModel(l)); } @@ -91,10 +95,12 @@ public partial class MainWindowViewModel : ViewModelBase try { - await _listRepo.AddAsync(entity); + using var context = _dbFactory.CreateDbContext(); + var listRepo = new ListRepository(context); + await listRepo.AddAsync(entity); var configEntity = editor.BuildConfig(entity.Id); if (configEntity is not null) - await _listRepo.SetConfigAsync(configEntity); + await listRepo.SetConfigAsync(configEntity); Lists.Add(new ListItemViewModel(entity)); } catch (Exception ex) @@ -107,10 +113,17 @@ public partial class MainWindowViewModel : ViewModelBase private async Task EditList() { if (SelectedList is null) return; - var existing = await _listRepo.GetByIdAsync(SelectedList.Id); - if (existing is null) return; - var config = await _listRepo.GetConfigAsync(existing.Id); + ListEntity? existing; + ListConfigEntity? config; + using (var context = _dbFactory.CreateDbContext()) + { + var listRepo = new ListRepository(context); + existing = await listRepo.GetByIdAsync(SelectedList.Id); + if (existing is null) return; + config = await listRepo.GetConfigAsync(existing.Id); + } + var editor = _listEditorFactory(); await editor.LoadAgentsAsync(_worker); editor.InitForEdit(existing, config); @@ -125,10 +138,12 @@ public partial class MainWindowViewModel : ViewModelBase try { - await _listRepo.UpdateAsync(entity); + using var context = _dbFactory.CreateDbContext(); + var listRepo = new ListRepository(context); + await listRepo.UpdateAsync(entity); var configEntity = editor.BuildConfig(entity.Id); if (configEntity is not null) - await _listRepo.SetConfigAsync(configEntity); + await listRepo.SetConfigAsync(configEntity); SelectedList.Name = entity.Name; SelectedList.WorkingDir = entity.WorkingDir; SelectedList.DefaultCommitType = entity.DefaultCommitType; @@ -146,7 +161,9 @@ public partial class MainWindowViewModel : ViewModelBase // TODO: confirmation dialog try { - await _listRepo.DeleteAsync(SelectedList.Id); + using var context = _dbFactory.CreateDbContext(); + var listRepo = new ListRepository(context); + await listRepo.DeleteAsync(SelectedList.Id); Lists.Remove(SelectedList); SelectedList = null; } diff --git a/src/ClaudeDo.Ui/ViewModels/TaskDetailViewModel.cs b/src/ClaudeDo.Ui/ViewModels/TaskDetailViewModel.cs index 3fbbd50..4b4a144 100644 --- a/src/ClaudeDo.Ui/ViewModels/TaskDetailViewModel.cs +++ b/src/ClaudeDo.Ui/ViewModels/TaskDetailViewModel.cs @@ -2,6 +2,7 @@ using System.Collections.ObjectModel; using System.ComponentModel; using System.Diagnostics; using System.IO; +using ClaudeDo.Data; using ClaudeDo.Data.Git; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; @@ -9,18 +10,15 @@ using ClaudeDo.Ui.Helpers; using ClaudeDo.Ui.Services; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Ui.ViewModels; public partial class TaskDetailViewModel : ViewModelBase { - private readonly TaskRepository _taskRepo; - private readonly WorktreeRepository _worktreeRepo; - private readonly ListRepository _listRepo; + private readonly IDbContextFactory _dbFactory; private readonly GitService _git; private readonly WorkerClient _worker; - private readonly TagRepository _tagRepo; - private readonly SubtaskRepository _subtaskRepo; [ObservableProperty] private string _title = ""; [ObservableProperty] private string? _description; @@ -62,17 +60,11 @@ public partial class TaskDetailViewModel : ViewModelBase public event Action? TaskChanged; - public TaskDetailViewModel(TaskRepository taskRepo, WorktreeRepository worktreeRepo, - ListRepository listRepo, GitService git, WorkerClient worker, TagRepository tagRepo, - SubtaskRepository subtaskRepo) + public TaskDetailViewModel(IDbContextFactory dbFactory, GitService git, WorkerClient worker) { - _taskRepo = taskRepo; - _worktreeRepo = worktreeRepo; - _listRepo = listRepo; + _dbFactory = dbFactory; _git = git; _worker = worker; - _tagRepo = tagRepo; - _subtaskRepo = subtaskRepo; worker.TaskMessageEvent += OnTaskMessage; worker.WorktreeUpdatedEvent += OnWorktreeUpdated; @@ -98,8 +90,24 @@ public partial class TaskDetailViewModel : ViewModelBase try { - var task = await _taskRepo.GetByIdAsync(taskId, ct); - if (task is null) return; + TaskEntity? task; + List tags; + List subtasks; + + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + task = await taskRepo.GetByIdAsync(taskId, ct); + if (task is null) return; + ct.ThrowIfCancellationRequested(); + + tags = await taskRepo.GetTagsAsync(taskId, ct); + ct.ThrowIfCancellationRequested(); + + var subtaskRepo = new SubtaskRepository(context); + subtasks = await subtaskRepo.GetByTaskIdAsync(taskId, ct); + } + ct.ThrowIfCancellationRequested(); if (AvailableAgents.Count == 0) @@ -149,14 +157,12 @@ public partial class TaskDetailViewModel : ViewModelBase } Tags.Clear(); - var tags = await _taskRepo.GetTagsAsync(taskId, ct); foreach (var tag in tags) Tags.Add(tag); // Tear down old subtask subscriptions before replacing them. foreach (var old in Subtasks) old.PropertyChanged -= OnSubtaskPropertyChanged; Subtasks.Clear(); - var subtasks = await _subtaskRepo.GetByTaskIdAsync(taskId, ct); foreach (var s in subtasks) { var vm = SubtaskItemViewModel.From(s); @@ -181,7 +187,9 @@ public partial class TaskDetailViewModel : ViewModelBase { if (_isLoading || _taskId is null) return; - var entity = await _taskRepo.GetByIdAsync(_taskId); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var entity = await taskRepo.GetByIdAsync(_taskId); if (entity is null) return; entity.Title = Title; @@ -196,7 +204,7 @@ public partial class TaskDetailViewModel : ViewModelBase if (Enum.TryParse(StatusChoice, true, out var status)) entity.Status = status; - await _taskRepo.UpdateAsync(entity); + await taskRepo.UpdateAsync(entity); StatusText = entity.Status.ToString().ToLowerInvariant(); TaskChanged?.Invoke(_taskId); } @@ -207,11 +215,15 @@ public partial class TaskDetailViewModel : ViewModelBase var name = NewTagInput.Trim(); if (string.IsNullOrEmpty(name) || _taskId is null) return; - var tagId = await _tagRepo.GetOrCreateAsync(name); - await _taskRepo.AddTagAsync(_taskId, tagId); + using var context = _dbFactory.CreateDbContext(); + var tagRepo = new TagRepository(context); + var taskRepo = new TaskRepository(context); + + var tagId = await tagRepo.GetOrCreateAsync(name); + await taskRepo.AddTagAsync(_taskId, tagId); Tags.Clear(); - var tags = await _taskRepo.GetTagsAsync(_taskId); + var tags = await taskRepo.GetTagsAsync(_taskId); foreach (var tag in tags) Tags.Add(tag); @@ -223,7 +235,9 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task RemoveTag(TagEntity tag) { if (_taskId is null) return; - await _taskRepo.RemoveTagAsync(_taskId, tag.Id); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + await taskRepo.RemoveTagAsync(_taskId, tag.Id); Tags.Remove(tag); TaskChanged?.Invoke(_taskId); } @@ -241,7 +255,9 @@ public partial class TaskDetailViewModel : ViewModelBase OrderNum = Subtasks.Count, CreatedAt = DateTime.UtcNow, }; - await _subtaskRepo.AddAsync(entity); + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + await subtaskRepo.AddAsync(entity); var vm = SubtaskItemViewModel.From(entity); vm.PropertyChanged += OnSubtaskPropertyChanged; Subtasks.Add(vm); @@ -251,7 +267,11 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task RemoveSubtask(SubtaskItemViewModel item) { if (!string.IsNullOrEmpty(item.Id)) - await _subtaskRepo.DeleteAsync(item.Id); + { + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + await subtaskRepo.DeleteAsync(item.Id); + } item.PropertyChanged -= OnSubtaskPropertyChanged; Subtasks.Remove(item); } @@ -262,7 +282,9 @@ public partial class TaskDetailViewModel : ViewModelBase if (e.PropertyName is not (nameof(SubtaskItemViewModel.Title) or nameof(SubtaskItemViewModel.Completed))) return; try { - await _subtaskRepo.UpdateAsync(new SubtaskEntity + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + await subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = _taskId ?? "", @@ -321,7 +343,9 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task LoadWorktreeAsync(string taskId) { - var wt = await _worktreeRepo.GetByTaskIdAsync(taskId); + using var context = _dbFactory.CreateDbContext(); + var wtRepo = new WorktreeRepository(context); + var wt = await wtRepo.GetByTaskIdAsync(taskId); HasWorktree = wt is not null; if (wt is not null) { @@ -378,14 +402,27 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task MergeIntoMainAsync() { if (_taskId is null || _listId is null) return; - var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId); - var list = await _listRepo.GetByIdAsync(_listId); + + WorktreeEntity? wt; + ListEntity? list; + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + wt = await wtRepo.GetByTaskIdAsync(_taskId); + var listRepo = new ListRepository(context); + list = await listRepo.GetByIdAsync(_listId); + } if (wt is null || list?.WorkingDir is null) return; await _git.MergeFfOnlyAsync(list.WorkingDir, wt.BranchName); await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true); - await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Merged); + + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Merged); + } await LoadWorktreeAsync(_taskId); } @@ -393,12 +430,25 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task KeepAsBranchAsync() { if (_taskId is null || _listId is null) return; - var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId); - var list = await _listRepo.GetByIdAsync(_listId); + + WorktreeEntity? wt; + ListEntity? list; + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + wt = await wtRepo.GetByTaskIdAsync(_taskId); + var listRepo = new ListRepository(context); + list = await listRepo.GetByIdAsync(_listId); + } if (wt is null || list?.WorkingDir is null) return; await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); - await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Kept); + + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Kept); + } await LoadWorktreeAsync(_taskId); } @@ -406,13 +456,26 @@ public partial class TaskDetailViewModel : ViewModelBase private async Task DiscardAsync() { if (_taskId is null || _listId is null) return; - var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId); - var list = await _listRepo.GetByIdAsync(_listId); + + WorktreeEntity? wt; + ListEntity? list; + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + wt = await wtRepo.GetByTaskIdAsync(_taskId); + var listRepo = new ListRepository(context); + list = await listRepo.GetByIdAsync(_listId); + } if (wt is null || list?.WorkingDir is null) return; await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true); - await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Discarded); + + using (var context = _dbFactory.CreateDbContext()) + { + var wtRepo = new WorktreeRepository(context); + await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Discarded); + } await LoadWorktreeAsync(_taskId); } diff --git a/src/ClaudeDo.Ui/ViewModels/TaskEditorViewModel.cs b/src/ClaudeDo.Ui/ViewModels/TaskEditorViewModel.cs index 9a172b6..d929ceb 100644 --- a/src/ClaudeDo.Ui/ViewModels/TaskEditorViewModel.cs +++ b/src/ClaudeDo.Ui/ViewModels/TaskEditorViewModel.cs @@ -1,17 +1,19 @@ using System.Collections.ObjectModel; using System.IO; +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Ui.Services; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using Microsoft.EntityFrameworkCore; using TaskStatus = ClaudeDo.Data.Models.TaskStatus; namespace ClaudeDo.Ui.ViewModels; public partial class TaskEditorViewModel : ViewModelBase { - private readonly SubtaskRepository _subtaskRepo; + private readonly IDbContextFactory _dbFactory; [ObservableProperty] private string _title = ""; [ObservableProperty] private string? _description; @@ -40,9 +42,9 @@ public partial class TaskEditorViewModel : ViewModelBase public static string[] StatusChoices { get; } = ["manual", "queued"]; - public TaskEditorViewModel(SubtaskRepository subtaskRepo) + public TaskEditorViewModel(IDbContextFactory dbFactory) { - _subtaskRepo = subtaskRepo; + _dbFactory = dbFactory; } public async Task LoadAgentsAsync(WorkerClient worker) @@ -116,7 +118,9 @@ public partial class TaskEditorViewModel : ViewModelBase WindowTitle = $"Edit Task: {entity.Title}"; Subtasks.Clear(); - var list = await _subtaskRepo.GetByTaskIdAsync(entity.Id, ct); + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + var list = await subtaskRepo.GetByTaskIdAsync(entity.Id, ct); foreach (var s in list) Subtasks.Add(SubtaskItemViewModel.From(s)); } @@ -196,36 +200,42 @@ public partial class TaskEditorViewModel : ViewModelBase // Persist subtask changes if (_editId is not null) { - var existing = await _subtaskRepo.GetByTaskIdAsync(taskId); + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + var existing = await subtaskRepo.GetByTaskIdAsync(taskId); var existingIds = existing.Select(s => s.Id).ToHashSet(); var currentIds = Subtasks.Where(s => s.Id != "").Select(s => s.Id).ToHashSet(); // Deleted foreach (var id in existingIds.Except(currentIds)) - await _subtaskRepo.DeleteAsync(id); + await subtaskRepo.DeleteAsync(id); // Updated foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i))) { if (vm.Id == "") continue; if (vm.Title != vm.OriginalTitle || vm.Completed != vm.OriginalCompleted) - await _subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); + await subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); else { // update order_num if position changed var orig = existing.FirstOrDefault(e => e.Id == vm.Id); if (orig is not null && orig.OrderNum != idx) - await _subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = orig.CreatedAt }); + await subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = orig.CreatedAt }); } } } // Added (id == "" means new) - foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i)).Where(x => x.v.Id == "")) { - if (string.IsNullOrWhiteSpace(vm.Title)) continue; - var newId = Guid.NewGuid().ToString(); - await _subtaskRepo.AddAsync(new SubtaskEntity { Id = newId, TaskId = taskId, Title = vm.Title.Trim(), Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); + using var context = _dbFactory.CreateDbContext(); + var subtaskRepo = new SubtaskRepository(context); + foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i)).Where(x => x.v.Id == "")) + { + if (string.IsNullOrWhiteSpace(vm.Title)) continue; + var newId = Guid.NewGuid().ToString(); + await subtaskRepo.AddAsync(new SubtaskEntity { Id = newId, TaskId = taskId, Title = vm.Title.Trim(), Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); + } } _tcs.TrySetResult(entity); diff --git a/src/ClaudeDo.Ui/ViewModels/TaskListViewModel.cs b/src/ClaudeDo.Ui/ViewModels/TaskListViewModel.cs index 16afd51..1df6795 100644 --- a/src/ClaudeDo.Ui/ViewModels/TaskListViewModel.cs +++ b/src/ClaudeDo.Ui/ViewModels/TaskListViewModel.cs @@ -2,21 +2,21 @@ using System.Collections.ObjectModel; using Avalonia; using Avalonia.Controls; using Avalonia.Controls.ApplicationLifetimes; +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Views; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using Microsoft.EntityFrameworkCore; using TaskStatus = ClaudeDo.Data.Models.TaskStatus; namespace ClaudeDo.Ui.ViewModels; public partial class TaskListViewModel : ViewModelBase { - private readonly TaskRepository _taskRepo; - private readonly TagRepository _tagRepo; - private readonly ListRepository _listRepo; + private readonly IDbContextFactory _dbFactory; private readonly WorkerClient _worker; private readonly Func _editorFactory; private readonly Action _showMessage; @@ -33,13 +33,10 @@ public partial class TaskListViewModel : ViewModelBase partial void OnSelectedTaskChanged(TaskItemViewModel? value) => SelectedTaskChanged?.Invoke(value); - public TaskListViewModel(TaskRepository taskRepo, TagRepository tagRepo, - ListRepository listRepo, WorkerClient worker, + public TaskListViewModel(IDbContextFactory dbFactory, WorkerClient worker, Func editorFactory, Action showMessage) { - _taskRepo = taskRepo; - _tagRepo = tagRepo; - _listRepo = listRepo; + _dbFactory = dbFactory; _worker = worker; _editorFactory = editorFactory; _showMessage = showMessage; @@ -77,7 +74,9 @@ public partial class TaskListViewModel : ViewModelBase if (listId is not null) { - var list = await _listRepo.GetByIdAsync(listId); + using var context = _dbFactory.CreateDbContext(); + var listRepo = new ListRepository(context); + var list = await listRepo.GetByIdAsync(listId); ListName = list?.Name ?? "Tasks"; } else @@ -89,10 +88,12 @@ public partial class TaskListViewModel : ViewModelBase try { - var entities = await _taskRepo.GetByListAsync(listId); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var entities = await taskRepo.GetByListIdAsync(listId); foreach (var e in entities) { - var tags = await _taskRepo.GetEffectiveTagsAsync(e.Id); + var tags = await taskRepo.GetEffectiveTagsAsync(e.Id); Tasks.Add(new TaskItemViewModel(e, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync)); } } @@ -110,8 +111,13 @@ public partial class TaskListViewModel : ViewModelBase var title = InlineAddTitle.Trim(); if (string.IsNullOrEmpty(title) || CurrentListId is null) return; - var list = await _listRepo.GetByIdAsync(CurrentListId); - var defaultCommitType = list?.DefaultCommitType ?? "chore"; + string defaultCommitType; + using (var context = _dbFactory.CreateDbContext()) + { + var listRepo = new ListRepository(context); + var list = await listRepo.GetByIdAsync(CurrentListId); + defaultCommitType = list?.DefaultCommitType ?? "chore"; + } var entity = new TaskEntity { @@ -125,8 +131,10 @@ public partial class TaskListViewModel : ViewModelBase try { - await _taskRepo.AddAsync(entity); - var tags = await _taskRepo.GetEffectiveTagsAsync(entity.Id); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + await taskRepo.AddAsync(entity); + var tags = await taskRepo.GetEffectiveTagsAsync(entity.Id); var vm = new TaskItemViewModel(entity, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync); Tasks.Add(vm); SelectedTask = vm; @@ -141,9 +149,13 @@ public partial class TaskListViewModel : ViewModelBase [RelayCommand(CanExecute = nameof(CanAddTask))] private async Task AddTask() { - // Get list default commit type - var list = await _listRepo.GetByIdAsync(CurrentListId); - var defaultCommitType = list?.DefaultCommitType ?? "chore"; + string defaultCommitType; + using (var context = _dbFactory.CreateDbContext()) + { + var listRepo = new ListRepository(context); + var list = await listRepo.GetByIdAsync(CurrentListId); + defaultCommitType = list?.DefaultCommitType ?? "chore"; + } var editor = _editorFactory(); await editor.LoadAgentsAsync(_worker); @@ -159,15 +171,18 @@ public partial class TaskListViewModel : ViewModelBase try { - await _taskRepo.AddAsync(saved); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var tagRepo = new TagRepository(context); + await taskRepo.AddAsync(saved); foreach (var tagName in editor.SelectedTagNames) { - var tagId = await _tagRepo.GetOrCreateAsync(tagName); - await _taskRepo.AddTagAsync(saved.Id, tagId); + var tagId = await tagRepo.GetOrCreateAsync(tagName); + await taskRepo.AddTagAsync(saved.Id, tagId); } - var tags = await _taskRepo.GetEffectiveTagsAsync(saved.Id); + var tags = await taskRepo.GetEffectiveTagsAsync(saved.Id); Tasks.Add(new TaskItemViewModel(saved, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync)); // Auto wake-queue if agent+queued @@ -188,10 +203,17 @@ public partial class TaskListViewModel : ViewModelBase private async Task EditTask() { if (SelectedTask is null || CurrentListId is null) return; - var entity = await _taskRepo.GetByIdAsync(SelectedTask.Id); - if (entity is null) return; - var taskTags = await _taskRepo.GetTagsAsync(entity.Id); + TaskEntity? entity; + List taskTags; + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + entity = await taskRepo.GetByIdAsync(SelectedTask.Id); + if (entity is null) return; + taskTags = await taskRepo.GetTagsAsync(entity.Id); + } + var editor = _editorFactory(); await editor.LoadAgentsAsync(_worker); await editor.InitForEditAsync(entity, taskTags); @@ -206,18 +228,21 @@ public partial class TaskListViewModel : ViewModelBase try { - await _taskRepo.UpdateAsync(saved); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var tagRepo = new TagRepository(context); + await taskRepo.UpdateAsync(saved); - var existingTags = await _taskRepo.GetTagsAsync(saved.Id); + var existingTags = await taskRepo.GetTagsAsync(saved.Id); foreach (var old in existingTags) - await _taskRepo.RemoveTagAsync(saved.Id, old.Id); + await taskRepo.RemoveTagAsync(saved.Id, old.Id); foreach (var tagName in editor.SelectedTagNames) { - var tagId = await _tagRepo.GetOrCreateAsync(tagName); - await _taskRepo.AddTagAsync(saved.Id, tagId); + var tagId = await tagRepo.GetOrCreateAsync(tagName); + await taskRepo.AddTagAsync(saved.Id, tagId); } - var newTags = await _taskRepo.GetEffectiveTagsAsync(saved.Id); + var newTags = await taskRepo.GetEffectiveTagsAsync(saved.Id); SelectedTask.Refresh(saved, newTags); } catch (Exception ex) @@ -232,7 +257,9 @@ public partial class TaskListViewModel : ViewModelBase if (SelectedTask is null) return; try { - await _taskRepo.DeleteAsync(SelectedTask.Id); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + await taskRepo.DeleteAsync(SelectedTask.Id); Tasks.Remove(SelectedTask); SelectedTask = null; } @@ -244,14 +271,16 @@ public partial class TaskListViewModel : ViewModelBase public async Task RefreshSingleAsync(string taskId) { - var entity = await _taskRepo.GetByIdAsync(taskId); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var entity = await taskRepo.GetByIdAsync(taskId); var existing = Tasks.FirstOrDefault(t => t.Id == taskId); if (entity is null) { if (existing is not null) Tasks.Remove(existing); return; } - var tags = await _taskRepo.GetEffectiveTagsAsync(taskId); + var tags = await taskRepo.GetEffectiveTagsAsync(taskId); if (existing is not null) existing.Refresh(entity, tags); } @@ -270,14 +299,16 @@ public partial class TaskListViewModel : ViewModelBase private async Task ToggleDoneAsync(string taskId) { - var entity = await _taskRepo.GetByIdAsync(taskId); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var entity = await taskRepo.GetByIdAsync(taskId); if (entity is null) return; entity.Status = entity.Status == TaskStatus.Done ? TaskStatus.Manual : TaskStatus.Done; if (entity.Status == TaskStatus.Done) entity.FinishedAt = DateTime.UtcNow; - await _taskRepo.UpdateAsync(entity); + await taskRepo.UpdateAsync(entity); await RefreshSingleAsync(taskId); } diff --git a/src/ClaudeDo.Worker/Program.cs b/src/ClaudeDo.Worker/Program.cs index dbef489..68a80aa 100644 --- a/src/ClaudeDo.Worker/Program.cs +++ b/src/ClaudeDo.Worker/Program.cs @@ -5,6 +5,7 @@ using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Hub; using ClaudeDo.Worker.Runner; using ClaudeDo.Worker.Services; +using Microsoft.EntityFrameworkCore; var cfg = WorkerConfig.Load(); @@ -14,18 +15,18 @@ var builder = WebApplication.CreateBuilder(args); // doesn't think we crashed (~30s timeout). No-op when running interactively. builder.Host.UseWindowsService(o => o.ServiceName = "ClaudeDoWorker"); -// Initialize DB schema before the host starts accepting connections. -var dbFactory = new SqliteConnectionFactory(cfg.DbPath); -SchemaInitializer.Apply(dbFactory); +builder.Services.AddDbContextFactory(opt => + opt.UseSqlite($"Data Source={cfg.DbPath}")); +builder.Services.AddDbContext(opt => + opt.UseSqlite($"Data Source={cfg.DbPath}")); builder.Services.AddSingleton(cfg); -builder.Services.AddSingleton(dbFactory); -builder.Services.AddSingleton(); -builder.Services.AddSingleton(); -builder.Services.AddSingleton(); -builder.Services.AddSingleton(); -builder.Services.AddSingleton(); -builder.Services.AddSingleton(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); +builder.Services.AddScoped(); builder.Services.AddHostedService(); builder.Services.AddSignalR(); @@ -51,6 +52,11 @@ builder.WebHost.UseUrls($"http://127.0.0.1:{cfg.SignalRPort}"); var app = builder.Build(); +using (var scope = app.Services.CreateScope()) +{ + scope.ServiceProvider.GetRequiredService().Database.Migrate(); +} + app.MapHub("/hub"); app.Logger.LogInformation("ClaudeDo.Worker listening on http://127.0.0.1:{Port} (db: {Db})", diff --git a/src/ClaudeDo.Worker/Runner/TaskRunner.cs b/src/ClaudeDo.Worker/Runner/TaskRunner.cs index 2b13c51..8e03b6b 100644 --- a/src/ClaudeDo.Worker/Runner/TaskRunner.cs +++ b/src/ClaudeDo.Worker/Runner/TaskRunner.cs @@ -1,18 +1,16 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Hub; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Worker.Runner; public sealed class TaskRunner { private readonly IClaudeProcess _claude; - private readonly TaskRepository _taskRepo; - private readonly TaskRunRepository _runRepo; - private readonly ListRepository _listRepo; - private readonly WorktreeRepository _wtRepo; - private readonly SubtaskRepository _subtaskRepo; + private readonly IDbContextFactory _dbFactory; private readonly HubBroadcaster _broadcaster; private readonly WorktreeManager _wtManager; private readonly ClaudeArgsBuilder _argsBuilder; @@ -21,11 +19,7 @@ public sealed class TaskRunner public TaskRunner( IClaudeProcess claude, - TaskRepository taskRepo, - TaskRunRepository runRepo, - ListRepository listRepo, - WorktreeRepository wtRepo, - SubtaskRepository subtaskRepo, + IDbContextFactory dbFactory, HubBroadcaster broadcaster, WorktreeManager wtManager, ClaudeArgsBuilder argsBuilder, @@ -33,11 +27,7 @@ public sealed class TaskRunner ILogger logger) { _claude = claude; - _taskRepo = taskRepo; - _runRepo = runRepo; - _listRepo = listRepo; - _wtRepo = wtRepo; - _subtaskRepo = subtaskRepo; + _dbFactory = dbFactory; _broadcaster = broadcaster; _wtManager = wtManager; _argsBuilder = argsBuilder; @@ -49,11 +39,23 @@ public sealed class TaskRunner { try { - var list = await _listRepo.GetByIdAsync(task.ListId, ct); - if (list is null) + ListEntity? list; + ListConfigEntity? listConfig; + List subtasks; + + using (var context = _dbFactory.CreateDbContext()) { - await MarkFailed(task.Id, slot, "List not found."); - return; + var listRepo = new ListRepository(context); + list = await listRepo.GetByIdAsync(task.ListId, ct); + if (list is null) + { + await MarkFailed(task.Id, slot, "List not found."); + return; + } + listConfig = await listRepo.GetConfigAsync(task.ListId, ct); + + var subtaskRepo = new SubtaskRepository(context); + subtasks = await subtaskRepo.GetByTaskIdAsync(task.Id, ct); } // Determine working directory: worktree or sandbox. @@ -81,7 +83,6 @@ public sealed class TaskRunner } // Resolve config: task overrides > list config > null. - var listConfig = await _listRepo.GetConfigAsync(task.ListId, ct); var resolvedConfig = new ClaudeRunConfig( Model: task.Model ?? listConfig?.Model ?? "claude-sonnet-4-6", SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt, @@ -90,11 +91,14 @@ public sealed class TaskRunner ); var now = DateTime.UtcNow; - await _taskRepo.MarkRunningAsync(task.Id, now, ct); + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + await taskRepo.MarkRunningAsync(task.Id, now, ct); + } await _broadcaster.TaskStarted(slot, task.Id, now); // Build prompt. - var subtasks = await _subtaskRepo.GetByTaskIdAsync(task.Id, ct); var sb = new System.Text.StringBuilder(task.Title); if (!string.IsNullOrWhiteSpace(task.Description)) sb.Append("\n\n").Append(task.Description.Trim()); if (subtasks.Count > 0) @@ -155,19 +159,34 @@ public sealed class TaskRunner public async Task ContinueAsync(string taskId, string followUpPrompt, string slot, CancellationToken ct) { - var task = await _taskRepo.GetByIdAsync(taskId, ct) - ?? throw new KeyNotFoundException($"Task '{taskId}' not found."); + TaskEntity task; + TaskRunEntity lastRun; + ListEntity list; + ListConfigEntity? listConfig; + WorktreeEntity? worktree; - var lastRun = await _runRepo.GetLatestByTaskIdAsync(taskId, ct) - ?? throw new InvalidOperationException("No previous run to continue."); + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + task = await taskRepo.GetByIdAsync(taskId, ct) + ?? throw new KeyNotFoundException($"Task '{taskId}' not found."); - if (lastRun.SessionId is null) - throw new InvalidOperationException("Previous run has no session ID — cannot resume."); + var runRepo = new TaskRunRepository(context); + lastRun = await runRepo.GetLatestByTaskIdAsync(taskId, ct) + ?? throw new InvalidOperationException("No previous run to continue."); - var list = await _listRepo.GetByIdAsync(task.ListId, ct) - ?? throw new InvalidOperationException("List not found."); + if (lastRun.SessionId is null) + throw new InvalidOperationException("Previous run has no session ID — cannot resume."); + + var listRepo = new ListRepository(context); + list = await listRepo.GetByIdAsync(task.ListId, ct) + ?? throw new InvalidOperationException("List not found."); + listConfig = await listRepo.GetConfigAsync(task.ListId, ct); + + var wtRepo = new WorktreeRepository(context); + worktree = await wtRepo.GetByTaskIdAsync(taskId, ct); + } - var listConfig = await _listRepo.GetConfigAsync(task.ListId, ct); var resolvedConfig = new ClaudeRunConfig( Model: task.Model ?? listConfig?.Model, SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt, @@ -178,7 +197,6 @@ public sealed class TaskRunner // Determine run directory from existing worktree or sandbox. string runDir; WorktreeContext? wtCtx = null; - var worktree = await _wtRepo.GetByTaskIdAsync(taskId, ct); if (worktree is not null) { runDir = worktree.Path; @@ -190,7 +208,11 @@ public sealed class TaskRunner } var now = DateTime.UtcNow; - await _taskRepo.MarkRunningAsync(taskId, now, ct); + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + await taskRepo.MarkRunningAsync(taskId, now, ct); + } await _broadcaster.TaskStarted(slot, taskId, now); var nextRunNumber = lastRun.RunNumber + 1; @@ -226,7 +248,12 @@ public sealed class TaskRunner LogPath = logPath, StartedAt = DateTime.UtcNow, }; - await _runRepo.AddAsync(run, ct); + + using (var context = _dbFactory.CreateDbContext()) + { + var runRepo = new TaskRunRepository(context); + await runRepo.AddAsync(run, ct); + } var arguments = _argsBuilder.Build(config); @@ -257,10 +284,15 @@ public sealed class TaskRunner run.TokensIn = result.TokensIn; run.TokensOut = result.TokensOut; run.FinishedAt = DateTime.UtcNow; - await _runRepo.UpdateAsync(run, CancellationToken.None); - // Update denormalized fields on the task. - await _taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); + using (var context = _dbFactory.CreateDbContext()) + { + var runRepo = new TaskRunRepository(context); + await runRepo.UpdateAsync(run, CancellationToken.None); + + var taskRepo = new TaskRepository(context); + await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); + } return result; } @@ -273,8 +305,12 @@ public sealed class TaskRunner run.FinishedAt = DateTime.UtcNow; try { - await _runRepo.UpdateAsync(run, CancellationToken.None); - await _taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); + using var context = _dbFactory.CreateDbContext(); + var runRepo = new TaskRunRepository(context); + await runRepo.UpdateAsync(run, CancellationToken.None); + + var taskRepo = new TaskRepository(context); + await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); } catch (Exception updateEx) { @@ -297,7 +333,11 @@ public sealed class TaskRunner // is never left as 'running' because of a cancel that arrived // after the Claude run already succeeded. var finishedAt = DateTime.UtcNow; - await _taskRepo.MarkDoneAsync(task.Id, finishedAt, result.ResultMarkdown, CancellationToken.None); + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + await taskRepo.MarkDoneAsync(task.Id, finishedAt, result.ResultMarkdown, CancellationToken.None); + } await _broadcaster.TaskFinished(slot, task.Id, "done", finishedAt); _logger.LogInformation("Task {TaskId} completed (turns={Turns}, tokens_in={In}, tokens_out={Out})", task.Id, result.TurnCount, result.TokensIn, result.TokensOut); @@ -308,7 +348,9 @@ public sealed class TaskRunner // Intentionally does not accept a CancellationToken: this is the // terminal write for a failed task and must always be persisted. var finishedAt = DateTime.UtcNow; - await _taskRepo.MarkFailedAsync(taskId, finishedAt, result.ErrorMarkdown, CancellationToken.None); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + await taskRepo.MarkFailedAsync(taskId, finishedAt, result.ErrorMarkdown, CancellationToken.None); await _broadcaster.TaskFinished(slot, taskId, "failed", finishedAt); _logger.LogWarning("Task {TaskId} failed (turns={Turns}): {Error}", taskId, result.TurnCount, result.ErrorMarkdown); } @@ -319,7 +361,9 @@ public sealed class TaskRunner { var now = DateTime.UtcNow; // Terminal write — never cancel. - await _taskRepo.MarkFailedAsync(taskId, now, error, CancellationToken.None); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + await taskRepo.MarkFailedAsync(taskId, now, error, CancellationToken.None); await _broadcaster.TaskFinished(slot, taskId, "failed", now); await _broadcaster.TaskUpdated(taskId); } diff --git a/src/ClaudeDo.Worker/Runner/WorktreeManager.cs b/src/ClaudeDo.Worker/Runner/WorktreeManager.cs index 51b76d6..a0869f9 100644 --- a/src/ClaudeDo.Worker/Runner/WorktreeManager.cs +++ b/src/ClaudeDo.Worker/Runner/WorktreeManager.cs @@ -1,7 +1,9 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Git; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Config; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Worker.Runner; @@ -10,14 +12,14 @@ public sealed record WorktreeContext(string WorktreePath, string BranchName, str public sealed class WorktreeManager { private readonly GitService _git; - private readonly WorktreeRepository _wtRepo; + private readonly IDbContextFactory _dbFactory; private readonly WorkerConfig _cfg; private readonly ILogger _logger; - public WorktreeManager(GitService git, WorktreeRepository wtRepo, WorkerConfig cfg, ILogger logger) + public WorktreeManager(GitService git, IDbContextFactory dbFactory, WorkerConfig cfg, ILogger logger) { _git = git; - _wtRepo = wtRepo; + _dbFactory = dbFactory; _cfg = cfg; _logger = logger; } @@ -50,7 +52,9 @@ public sealed class WorktreeManager await _git.WorktreeAddAsync(workingDir, branchName, worktreePath, baseCommit, ct); // Insert worktrees row AFTER git succeeds — if git throws, no row is created. - await _wtRepo.AddAsync(new WorktreeEntity + using var context = _dbFactory.CreateDbContext(); + var wtRepo = new WorktreeRepository(context); + await wtRepo.AddAsync(new WorktreeEntity { TaskId = task.Id, Path = worktreePath, @@ -87,7 +91,9 @@ public sealed class WorktreeManager var head = await _git.RevParseHeadAsync(ctx.WorktreePath, ct); var diffStat = await _git.DiffStatAsync(ctx.WorktreePath, ctx.BaseCommit, head, ct); - await _wtRepo.UpdateHeadAsync(task.Id, head, diffStat, ct); + using var context = _dbFactory.CreateDbContext(); + var wtRepo = new WorktreeRepository(context); + await wtRepo.UpdateHeadAsync(task.Id, head, diffStat, ct); _logger.LogInformation("Committed changes for task {TaskId}: {Head}", task.Id, head); return true; diff --git a/src/ClaudeDo.Worker/Services/QueueService.cs b/src/ClaudeDo.Worker/Services/QueueService.cs index 27b256d..8fbc456 100644 --- a/src/ClaudeDo.Worker/Services/QueueService.cs +++ b/src/ClaudeDo.Worker/Services/QueueService.cs @@ -1,7 +1,9 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Runner; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Worker.Services; @@ -14,7 +16,7 @@ public sealed class QueueSlotState public sealed class QueueService : BackgroundService { - private readonly TaskRepository _taskRepo; + private readonly IDbContextFactory _dbFactory; private readonly TaskRunner _runner; private readonly WorkerConfig _cfg; private readonly ILogger _logger; @@ -26,12 +28,12 @@ public sealed class QueueService : BackgroundService private readonly SemaphoreSlim _wakeSignal = new(0, 1); public QueueService( - TaskRepository taskRepo, + IDbContextFactory dbFactory, TaskRunner runner, WorkerConfig cfg, ILogger logger) { - _taskRepo = taskRepo; + _dbFactory = dbFactory; _runner = runner; _cfg = cfg; _logger = logger; @@ -56,7 +58,9 @@ public sealed class QueueService : BackgroundService public async Task RunNow(string taskId) { - var task = await _taskRepo.GetByIdAsync(taskId); + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var task = await taskRepo.GetByIdAsync(taskId); if (task is null) throw new KeyNotFoundException($"Task '{taskId}' not found."); @@ -78,7 +82,9 @@ public sealed class QueueService : BackgroundService public async Task ContinueTask(string taskId, string followUpPrompt) { - var task = await _taskRepo.GetByIdAsync(taskId) + using var context = _dbFactory.CreateDbContext(); + var taskRepo = new TaskRepository(context); + var task = await taskRepo.GetByIdAsync(taskId) ?? throw new KeyNotFoundException($"Task '{taskId}' not found."); if (task.Status == Data.Models.TaskStatus.Running) @@ -144,7 +150,12 @@ public sealed class QueueService : BackgroundService if (_queueSlot is not null) continue; - var task = await _taskRepo.GetNextQueuedAgentTaskAsync(DateTime.UtcNow, stoppingToken); + TaskEntity? task; + using (var context = _dbFactory.CreateDbContext()) + { + var taskRepo = new TaskRepository(context); + task = await taskRepo.GetNextQueuedAgentTaskAsync(DateTime.UtcNow, stoppingToken); + } if (task is null) continue; lock (_lock) diff --git a/src/ClaudeDo.Worker/Services/StaleTaskRecovery.cs b/src/ClaudeDo.Worker/Services/StaleTaskRecovery.cs index aa43d66..7ad3df7 100644 --- a/src/ClaudeDo.Worker/Services/StaleTaskRecovery.cs +++ b/src/ClaudeDo.Worker/Services/StaleTaskRecovery.cs @@ -1,21 +1,25 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Repositories; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Worker.Services; public sealed class StaleTaskRecovery : IHostedService { - private readonly TaskRepository _tasks; + private readonly IDbContextFactory _dbFactory; private readonly ILogger _logger; - public StaleTaskRecovery(TaskRepository tasks, ILogger logger) + public StaleTaskRecovery(IDbContextFactory dbFactory, ILogger logger) { - _tasks = tasks; + _dbFactory = dbFactory; _logger = logger; } public async Task StartAsync(CancellationToken cancellationToken) { - var flipped = await _tasks.FlipAllRunningToFailedAsync("worker restart", cancellationToken); + using var context = _dbFactory.CreateDbContext(); + var tasks = new TaskRepository(context); + var flipped = await tasks.FlipAllRunningToFailedAsync("worker restart", cancellationToken); if (flipped > 0) _logger.LogWarning("Stale task recovery: flipped {Count} running task(s) to failed", flipped); else diff --git a/tests/ClaudeDo.Worker.Tests/Infrastructure/DbFixture.cs b/tests/ClaudeDo.Worker.Tests/Infrastructure/DbFixture.cs index ebd8ff5..d15c574 100644 --- a/tests/ClaudeDo.Worker.Tests/Infrastructure/DbFixture.cs +++ b/tests/ClaudeDo.Worker.Tests/Infrastructure/DbFixture.cs @@ -1,19 +1,30 @@ using ClaudeDo.Data; +using Microsoft.EntityFrameworkCore; namespace ClaudeDo.Worker.Tests.Infrastructure; public sealed class DbFixture : IDisposable { public string DbPath { get; } - public SqliteConnectionFactory Factory { get; } public DbFixture() { DbPath = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}.db"); - Factory = new SqliteConnectionFactory(DbPath); - SchemaInitializer.Apply(Factory); + // Apply migrations so the schema is created. + using var ctx = CreateContext(); + ctx.Database.Migrate(); } + public ClaudeDoDbContext CreateContext() + { + var options = new DbContextOptionsBuilder() + .UseSqlite($"Data Source={DbPath}") + .Options; + return new ClaudeDoDbContext(options); + } + + public TestDbContextFactory CreateFactory() => new(this); + public void Dispose() { try { File.Delete(DbPath); } catch { /* best effort */ } @@ -21,3 +32,10 @@ public sealed class DbFixture : IDisposable try { File.Delete(DbPath + "-shm"); } catch { } } } + +public sealed class TestDbContextFactory : IDbContextFactory +{ + private readonly DbFixture _fixture; + public TestDbContextFactory(DbFixture fixture) => _fixture = fixture; + public ClaudeDoDbContext CreateDbContext() => _fixture.CreateContext(); +} diff --git a/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryConfigTests.cs b/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryConfigTests.cs index a9068f3..07e6b71 100644 --- a/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryConfigTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryConfigTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Tests.Infrastructure; @@ -7,12 +8,14 @@ namespace ClaudeDo.Worker.Tests.Repositories; public sealed class ListRepositoryConfigTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly ListRepository _repo; private readonly string _listId; public ListRepositoryConfigTests() { - _repo = new ListRepository(_db.Factory); + _ctx = _db.CreateContext(); + _repo = new ListRepository(_ctx); _listId = Guid.NewGuid().ToString(); _repo.AddAsync(new ListEntity { @@ -57,5 +60,9 @@ public sealed class ListRepositoryConfigTests : IDisposable Assert.Equal("haiku-4-5", fetched.Model); } - public void Dispose() => _db.Dispose(); + public void Dispose() + { + _ctx.Dispose(); + _db.Dispose(); + } } diff --git a/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryTests.cs b/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryTests.cs index f1e931d..20463a2 100644 --- a/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Repositories/ListRepositoryTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Tests.Infrastructure; @@ -7,16 +8,22 @@ namespace ClaudeDo.Worker.Tests.Repositories; public sealed class ListRepositoryTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly ListRepository _lists; private readonly TagRepository _tags; public ListRepositoryTests() { - _lists = new ListRepository(_db.Factory); - _tags = new TagRepository(_db.Factory); + _ctx = _db.CreateContext(); + _lists = new ListRepository(_ctx); + _tags = new TagRepository(_ctx); } - public void Dispose() => _db.Dispose(); + public void Dispose() + { + _ctx.Dispose(); + _db.Dispose(); + } [Fact] public async Task AddAsync_And_GetByIdAsync_Roundtrips() diff --git a/tests/ClaudeDo.Worker.Tests/Repositories/TaskRepositoryTests.cs b/tests/ClaudeDo.Worker.Tests/Repositories/TaskRepositoryTests.cs index e99c376..1495213 100644 --- a/tests/ClaudeDo.Worker.Tests/Repositories/TaskRepositoryTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Repositories/TaskRepositoryTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Tests.Infrastructure; @@ -8,18 +9,24 @@ namespace ClaudeDo.Worker.Tests.Repositories; public sealed class TaskRepositoryTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly TaskRepository _tasks; private readonly ListRepository _lists; private readonly TagRepository _tags; public TaskRepositoryTests() { - _tasks = new TaskRepository(_db.Factory); - _lists = new ListRepository(_db.Factory); - _tags = new TagRepository(_db.Factory); + _ctx = _db.CreateContext(); + _tasks = new TaskRepository(_ctx); + _lists = new ListRepository(_ctx); + _tags = new TagRepository(_ctx); } - public void Dispose() => _db.Dispose(); + public void Dispose() + { + _ctx.Dispose(); + _db.Dispose(); + } private async Task CreateListAsync(string? id = null) { @@ -197,7 +204,7 @@ public sealed class TaskRepositoryTests : IDisposable var listId = await CreateListAsync(); var agentTagId = await _tags.GetOrCreateAsync("agent"); var manualTagId = await _tags.GetOrCreateAsync("manual"); - var codeTagId = await TagRepository.GetOrCreateAsync(_db.Factory.Open(), "code"); + var codeTagId = await _tags.GetOrCreateAsync("code"); await _lists.AddTagAsync(listId, agentTagId); diff --git a/tests/ClaudeDo.Worker.Tests/Repositories/TaskRunRepositoryTests.cs b/tests/ClaudeDo.Worker.Tests/Repositories/TaskRunRepositoryTests.cs index 0d06bb1..0205d51 100644 --- a/tests/ClaudeDo.Worker.Tests/Repositories/TaskRunRepositoryTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Repositories/TaskRunRepositoryTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Tests.Infrastructure; @@ -7,16 +8,18 @@ namespace ClaudeDo.Worker.Tests.Repositories; public sealed class TaskRunRepositoryTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly TaskRunRepository _runs; private readonly string _taskId; public TaskRunRepositoryTests() { - _runs = new TaskRunRepository(_db.Factory); + _ctx = _db.CreateContext(); + _runs = new TaskRunRepository(_ctx); // Seed a list and task for all tests - var lists = new ListRepository(_db.Factory); - var tasks = new TaskRepository(_db.Factory); + var lists = new ListRepository(_ctx); + var tasks = new TaskRepository(_ctx); var listId = Guid.NewGuid().ToString(); lists.AddAsync(new ListEntity { @@ -37,7 +40,11 @@ public sealed class TaskRunRepositoryTests : IDisposable }).GetAwaiter().GetResult(); } - public void Dispose() => _db.Dispose(); + public void Dispose() + { + _ctx.Dispose(); + _db.Dispose(); + } private TaskRunEntity MakeRun(int runNumber, bool isRetry = false) => new() { diff --git a/tests/ClaudeDo.Worker.Tests/Runner/WorktreeManagerTests.cs b/tests/ClaudeDo.Worker.Tests/Runner/WorktreeManagerTests.cs index 7e5ed68..4e28ec9 100644 --- a/tests/ClaudeDo.Worker.Tests/Runner/WorktreeManagerTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Runner/WorktreeManagerTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Git; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; @@ -24,19 +25,19 @@ public class WorktreeManagerTests : IDisposable return f; } - private async Task<(WorktreeManager mgr, WorktreeRepository wtRepo)> CreateManagerAsync( + private async Task<(WorktreeManager mgr, DbFixture db)> CreateManagerAsync( TaskEntity task, ListEntity list, string strategy = "sibling", string? centralRoot = null) { var db = new DbFixture(); _dbFixtures.Add(db); // Seed the DB with list and task so FK constraints pass. - var listRepo = new ListRepository(db.Factory); - var taskRepo = new TaskRepository(db.Factory); + using var seedCtx = db.CreateContext(); + var listRepo = new ListRepository(seedCtx); + var taskRepo = new TaskRepository(seedCtx); await listRepo.AddAsync(list); await taskRepo.AddAsync(task); - var wtRepo = new WorktreeRepository(db.Factory); var cfg = new WorkerConfig { WorktreeRootStrategy = strategy, @@ -45,8 +46,8 @@ public class WorktreeManagerTests : IDisposable cfg.CentralWorktreeRoot = centralRoot; var mgr = new WorktreeManager( - new GitService(), wtRepo, cfg, NullLogger.Instance); - return (mgr, wtRepo); + new GitService(), db.CreateFactory(), cfg, NullLogger.Instance); + return (mgr, db); } [Fact] @@ -56,7 +57,7 @@ public class WorktreeManagerTests : IDisposable var repo = CreateRepo(); var (task, list) = MakeEntities(repo.RepoDir); - var (mgr, wtRepo) = await CreateManagerAsync(task, list); + var (mgr, db) = await CreateManagerAsync(task, list); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); @@ -66,6 +67,8 @@ public class WorktreeManagerTests : IDisposable Assert.Equal($"claudedo/{task.Id.Replace("-", "")}", ctx.BranchName); Assert.Equal(repo.BaseCommit, ctx.BaseCommit); + using var readCtx = db.CreateContext(); + var wtRepo = new WorktreeRepository(readCtx); var row = await wtRepo.GetByTaskIdAsync(task.Id); Assert.NotNull(row); Assert.Equal(WorktreeState.Active, row!.State); @@ -80,7 +83,7 @@ public class WorktreeManagerTests : IDisposable var repo = CreateRepo(); var (task, list) = MakeEntities(repo.RepoDir); - var (mgr, wtRepo) = await CreateManagerAsync(task, list); + var (mgr, db) = await CreateManagerAsync(task, list); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); @@ -88,6 +91,8 @@ public class WorktreeManagerTests : IDisposable var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None); Assert.False(committed); + using var readCtx = db.CreateContext(); + var wtRepo = new WorktreeRepository(readCtx); var row = await wtRepo.GetByTaskIdAsync(task.Id); Assert.Null(row!.HeadCommit); } @@ -99,7 +104,7 @@ public class WorktreeManagerTests : IDisposable var repo = CreateRepo(); var (task, list) = MakeEntities(repo.RepoDir); - var (mgr, wtRepo) = await CreateManagerAsync(task, list); + var (mgr, db) = await CreateManagerAsync(task, list); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); @@ -109,6 +114,8 @@ public class WorktreeManagerTests : IDisposable var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None); Assert.True(committed); + using var readCtx = db.CreateContext(); + var wtRepo = new WorktreeRepository(readCtx); var row = await wtRepo.GetByTaskIdAsync(task.Id); Assert.NotNull(row!.HeadCommit); Assert.NotEqual(ctx.BaseCommit, row.HeadCommit); @@ -129,20 +136,24 @@ public class WorktreeManagerTests : IDisposable var db = new DbFixture(); _dbFixtures.Add(db); - var listRepo = new ListRepository(db.Factory); - var taskRepo = new TaskRepository(db.Factory); - await listRepo.AddAsync(list); - await taskRepo.AddAsync(task); + using (var seedCtx = db.CreateContext()) + { + var listRepo = new ListRepository(seedCtx); + var taskRepo = new TaskRepository(seedCtx); + await listRepo.AddAsync(list); + await taskRepo.AddAsync(task); + } - var wtRepo = new WorktreeRepository(db.Factory); var cfg = new WorkerConfig { WorktreeRootStrategy = "sibling" }; var mgr = new WorktreeManager( - new GitService(), wtRepo, cfg, NullLogger.Instance); + new GitService(), db.CreateFactory(), cfg, NullLogger.Instance); var ex = await Assert.ThrowsAsync( () => mgr.CreateAsync(task, list, CancellationToken.None)); Assert.Contains("not a git repository", ex.Message); + using var readCtx = db.CreateContext(); + var wtRepo = new WorktreeRepository(readCtx); var row = await wtRepo.GetByTaskIdAsync(task.Id); Assert.Null(row); } diff --git a/tests/ClaudeDo.Worker.Tests/Services/QueueServiceTests.cs b/tests/ClaudeDo.Worker.Tests/Services/QueueServiceTests.cs index fc2fe1d..3b5a88a 100644 --- a/tests/ClaudeDo.Worker.Tests/Services/QueueServiceTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Services/QueueServiceTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Git; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; @@ -15,6 +16,7 @@ namespace ClaudeDo.Worker.Tests.Services; public sealed class QueueServiceTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly TaskRepository _taskRepo; private readonly ListRepository _listRepo; private readonly TagRepository _tagRepo; @@ -23,9 +25,10 @@ public sealed class QueueServiceTests : IDisposable public QueueServiceTests() { - _taskRepo = new TaskRepository(_db.Factory); - _listRepo = new ListRepository(_db.Factory); - _tagRepo = new TagRepository(_db.Factory); + _ctx = _db.CreateContext(); + _taskRepo = new TaskRepository(_ctx); + _listRepo = new ListRepository(_ctx); + _tagRepo = new TagRepository(_ctx); _tempDir = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}"); Directory.CreateDirectory(_tempDir); _cfg = new WorkerConfig @@ -38,6 +41,7 @@ public sealed class QueueServiceTests : IDisposable public void Dispose() { + _ctx.Dispose(); _db.Dispose(); try { Directory.Delete(_tempDir, true); } catch { } } @@ -47,14 +51,12 @@ public sealed class QueueServiceTests : IDisposable { var fake = new FakeClaudeProcess(handler); var broadcaster = new HubBroadcaster(new FakeHubContext()); - var wtRepo = new WorktreeRepository(_db.Factory); - var runRepo = new TaskRunRepository(_db.Factory); - var wtManager = new WorktreeManager(new GitService(), wtRepo, _cfg, NullLogger.Instance); + var dbFactory = _db.CreateFactory(); + var wtManager = new WorktreeManager(new GitService(), dbFactory, _cfg, NullLogger.Instance); var argsBuilder = new ClaudeArgsBuilder(); - var subtaskRepo = new SubtaskRepository(_db.Factory); - var runner = new TaskRunner(fake, _taskRepo, runRepo, _listRepo, wtRepo, subtaskRepo, broadcaster, wtManager, argsBuilder, _cfg, + var runner = new TaskRunner(fake, dbFactory, broadcaster, wtManager, argsBuilder, _cfg, NullLogger.Instance); - var service = new QueueService(_taskRepo, runner, _cfg, NullLogger.Instance); + var service = new QueueService(dbFactory, runner, _cfg, NullLogger.Instance); return (service, fake); } diff --git a/tests/ClaudeDo.Worker.Tests/Services/StaleTaskRecoveryTests.cs b/tests/ClaudeDo.Worker.Tests/Services/StaleTaskRecoveryTests.cs index 20d54ca..f6452d6 100644 --- a/tests/ClaudeDo.Worker.Tests/Services/StaleTaskRecoveryTests.cs +++ b/tests/ClaudeDo.Worker.Tests/Services/StaleTaskRecoveryTests.cs @@ -1,3 +1,4 @@ +using ClaudeDo.Data; using ClaudeDo.Data.Models; using ClaudeDo.Data.Repositories; using ClaudeDo.Worker.Services; @@ -10,16 +11,22 @@ namespace ClaudeDo.Worker.Tests.Services; public sealed class StaleTaskRecoveryTests : IDisposable { private readonly DbFixture _db = new(); + private readonly ClaudeDoDbContext _ctx; private readonly TaskRepository _tasks; private readonly ListRepository _lists; public StaleTaskRecoveryTests() { - _tasks = new TaskRepository(_db.Factory); - _lists = new ListRepository(_db.Factory); + _ctx = _db.CreateContext(); + _tasks = new TaskRepository(_ctx); + _lists = new ListRepository(_ctx); } - public void Dispose() => _db.Dispose(); + public void Dispose() + { + _ctx.Dispose(); + _db.Dispose(); + } [Fact] public async Task StartAsync_Flips_Running_Tasks_To_Failed() @@ -47,7 +54,7 @@ public sealed class StaleTaskRecoveryTests : IDisposable await _tasks.AddAsync(running); await _tasks.AddAsync(queued); - var recovery = new StaleTaskRecovery(_tasks, NullLogger.Instance); + var recovery = new StaleTaskRecovery(_db.CreateFactory(), NullLogger.Instance); await recovery.StartAsync(CancellationToken.None); var r = await _tasks.GetByIdAsync(running.Id);