feat(worker,ui): wire EF Core into DI and update all consumers to IDbContextFactory

Worker and App Program.cs: replace SqliteConnectionFactory+SchemaInitializer
with AddDbContextFactory<ClaudeDoDbContext> + Database.Migrate(). Repos
changed from AddSingleton to AddScoped.

All singleton services (QueueService, StaleTaskRecovery, WorktreeManager,
TaskRunner) and singleton ViewModels (MainWindowViewModel, TaskDetailViewModel,
TaskListViewModel, TaskEditorViewModel) now take IDbContextFactory<ClaudeDoDbContext>
and create short-lived contexts per operation.

Test infrastructure: DbFixture now uses EF migrations instead of SchemaInitializer;
all test classes create contexts via DbFixture.CreateContext().

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
mika kuns
2026-04-16 08:59:24 +02:00
parent b7be52a623
commit 36484ed45a
18 changed files with 479 additions and 232 deletions

View File

@@ -5,6 +5,7 @@ using ClaudeDo.Data.Repositories;
using ClaudeDo.Ui; using ClaudeDo.Ui;
using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Services;
using ClaudeDo.Ui.ViewModels; using ClaudeDo.Ui.ViewModels;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using System; using System;
@@ -18,9 +19,10 @@ sealed class Program
var services = BuildServices(); var services = BuildServices();
App.Services = services; App.Services = services;
// Ensure DB schema exists using (var scope = services.CreateScope())
var factory = services.GetRequiredService<SqliteConnectionFactory>(); {
SchemaInitializer.Apply(factory); scope.ServiceProvider.GetRequiredService<ClaudeDoDbContext>().Database.Migrate();
}
try try
{ {
@@ -55,14 +57,10 @@ sealed class Program
// Infrastructure // Infrastructure
sc.AddSingleton(settings); sc.AddSingleton(settings);
sc.AddSingleton(new SqliteConnectionFactory(dbPath)); sc.AddDbContextFactory<ClaudeDoDbContext>(opt =>
opt.UseSqlite($"Data Source={dbPath}"));
// Repositories sc.AddScoped<ClaudeDoDbContext>(sp =>
sc.AddSingleton<ListRepository>(); sp.GetRequiredService<IDbContextFactory<ClaudeDoDbContext>>().CreateDbContext());
sc.AddSingleton<TaskRepository>();
sc.AddSingleton<SubtaskRepository>();
sc.AddSingleton<TagRepository>();
sc.AddSingleton<WorktreeRepository>();
// Services // Services
sc.AddSingleton<GitService>(); sc.AddSingleton<GitService>();
@@ -72,30 +70,21 @@ sealed class Program
sc.AddTransient<ListEditorViewModel>(); sc.AddTransient<ListEditorViewModel>();
sc.AddTransient<TaskEditorViewModel>(); sc.AddTransient<TaskEditorViewModel>();
sc.AddSingleton<StatusBarViewModel>(); sc.AddSingleton<StatusBarViewModel>();
sc.AddSingleton<TaskDetailViewModel>(sp => new TaskDetailViewModel( sc.AddSingleton<TaskDetailViewModel>();
sp.GetRequiredService<TaskRepository>(),
sp.GetRequiredService<WorktreeRepository>(),
sp.GetRequiredService<ListRepository>(),
sp.GetRequiredService<GitService>(),
sp.GetRequiredService<WorkerClient>(),
sp.GetRequiredService<TagRepository>(),
sp.GetRequiredService<SubtaskRepository>()));
sc.AddSingleton<TaskListViewModel>(sp => sc.AddSingleton<TaskListViewModel>(sp =>
{ {
var taskRepo = sp.GetRequiredService<TaskRepository>(); var dbFactory = sp.GetRequiredService<IDbContextFactory<ClaudeDoDbContext>>();
var tagRepo = sp.GetRequiredService<TagRepository>();
var listRepo = sp.GetRequiredService<ListRepository>();
var worker = sp.GetRequiredService<WorkerClient>(); var worker = sp.GetRequiredService<WorkerClient>();
var statusBar = sp.GetRequiredService<StatusBarViewModel>(); var statusBar = sp.GetRequiredService<StatusBarViewModel>();
return new TaskListViewModel( return new TaskListViewModel(
taskRepo, tagRepo, listRepo, worker, dbFactory, worker,
() => sp.GetRequiredService<TaskEditorViewModel>(), () => sp.GetRequiredService<TaskEditorViewModel>(),
msg => statusBar.ShowMessage(msg)); msg => statusBar.ShowMessage(msg));
}); });
sc.AddSingleton<MainWindowViewModel>(sp => sc.AddSingleton<MainWindowViewModel>(sp =>
{ {
return new MainWindowViewModel( return new MainWindowViewModel(
sp.GetRequiredService<ListRepository>(), sp.GetRequiredService<IDbContextFactory<ClaudeDoDbContext>>(),
sp.GetRequiredService<WorkerClient>(), sp.GetRequiredService<WorkerClient>(),
sp.GetRequiredService<TaskListViewModel>(), sp.GetRequiredService<TaskListViewModel>(),
sp.GetRequiredService<TaskDetailViewModel>(), sp.GetRequiredService<TaskDetailViewModel>(),

View File

@@ -2,18 +2,20 @@ using System.Collections.ObjectModel;
using Avalonia; using Avalonia;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Controls.ApplicationLifetimes; using Avalonia.Controls.ApplicationLifetimes;
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Services;
using ClaudeDo.Ui.Views; using ClaudeDo.Ui.Views;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Ui.ViewModels; namespace ClaudeDo.Ui.ViewModels;
public partial class MainWindowViewModel : ViewModelBase public partial class MainWindowViewModel : ViewModelBase
{ {
private readonly ListRepository _listRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly WorkerClient _worker; private readonly WorkerClient _worker;
private readonly Func<ListEditorViewModel> _listEditorFactory; private readonly Func<ListEditorViewModel> _listEditorFactory;
@@ -26,14 +28,14 @@ public partial class MainWindowViewModel : ViewModelBase
public StatusBarViewModel StatusBar { get; } public StatusBarViewModel StatusBar { get; }
public MainWindowViewModel( public MainWindowViewModel(
ListRepository listRepo, IDbContextFactory<ClaudeDoDbContext> dbFactory,
WorkerClient worker, WorkerClient worker,
TaskListViewModel taskList, TaskListViewModel taskList,
TaskDetailViewModel taskDetail, TaskDetailViewModel taskDetail,
StatusBarViewModel statusBar, StatusBarViewModel statusBar,
Func<ListEditorViewModel> listEditorFactory) Func<ListEditorViewModel> listEditorFactory)
{ {
_listRepo = listRepo; _dbFactory = dbFactory;
_worker = worker; _worker = worker;
_listEditorFactory = listEditorFactory; _listEditorFactory = listEditorFactory;
TaskList = taskList; TaskList = taskList;
@@ -48,7 +50,9 @@ public partial class MainWindowViewModel : ViewModelBase
{ {
try try
{ {
var lists = await _listRepo.GetAllAsync(); using var context = _dbFactory.CreateDbContext();
var listRepo = new ListRepository(context);
var lists = await listRepo.GetAllAsync();
foreach (var l in lists) foreach (var l in lists)
Lists.Add(new ListItemViewModel(l)); Lists.Add(new ListItemViewModel(l));
} }
@@ -91,10 +95,12 @@ public partial class MainWindowViewModel : ViewModelBase
try try
{ {
await _listRepo.AddAsync(entity); using var context = _dbFactory.CreateDbContext();
var listRepo = new ListRepository(context);
await listRepo.AddAsync(entity);
var configEntity = editor.BuildConfig(entity.Id); var configEntity = editor.BuildConfig(entity.Id);
if (configEntity is not null) if (configEntity is not null)
await _listRepo.SetConfigAsync(configEntity); await listRepo.SetConfigAsync(configEntity);
Lists.Add(new ListItemViewModel(entity)); Lists.Add(new ListItemViewModel(entity));
} }
catch (Exception ex) catch (Exception ex)
@@ -107,10 +113,17 @@ public partial class MainWindowViewModel : ViewModelBase
private async Task EditList() private async Task EditList()
{ {
if (SelectedList is null) return; if (SelectedList is null) return;
var existing = await _listRepo.GetByIdAsync(SelectedList.Id);
if (existing is null) return;
var config = await _listRepo.GetConfigAsync(existing.Id); ListEntity? existing;
ListConfigEntity? config;
using (var context = _dbFactory.CreateDbContext())
{
var listRepo = new ListRepository(context);
existing = await listRepo.GetByIdAsync(SelectedList.Id);
if (existing is null) return;
config = await listRepo.GetConfigAsync(existing.Id);
}
var editor = _listEditorFactory(); var editor = _listEditorFactory();
await editor.LoadAgentsAsync(_worker); await editor.LoadAgentsAsync(_worker);
editor.InitForEdit(existing, config); editor.InitForEdit(existing, config);
@@ -125,10 +138,12 @@ public partial class MainWindowViewModel : ViewModelBase
try try
{ {
await _listRepo.UpdateAsync(entity); using var context = _dbFactory.CreateDbContext();
var listRepo = new ListRepository(context);
await listRepo.UpdateAsync(entity);
var configEntity = editor.BuildConfig(entity.Id); var configEntity = editor.BuildConfig(entity.Id);
if (configEntity is not null) if (configEntity is not null)
await _listRepo.SetConfigAsync(configEntity); await listRepo.SetConfigAsync(configEntity);
SelectedList.Name = entity.Name; SelectedList.Name = entity.Name;
SelectedList.WorkingDir = entity.WorkingDir; SelectedList.WorkingDir = entity.WorkingDir;
SelectedList.DefaultCommitType = entity.DefaultCommitType; SelectedList.DefaultCommitType = entity.DefaultCommitType;
@@ -146,7 +161,9 @@ public partial class MainWindowViewModel : ViewModelBase
// TODO: confirmation dialog // TODO: confirmation dialog
try try
{ {
await _listRepo.DeleteAsync(SelectedList.Id); using var context = _dbFactory.CreateDbContext();
var listRepo = new ListRepository(context);
await listRepo.DeleteAsync(SelectedList.Id);
Lists.Remove(SelectedList); Lists.Remove(SelectedList);
SelectedList = null; SelectedList = null;
} }

View File

@@ -2,6 +2,7 @@ using System.Collections.ObjectModel;
using System.ComponentModel; using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using ClaudeDo.Data;
using ClaudeDo.Data.Git; using ClaudeDo.Data.Git;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
@@ -9,18 +10,15 @@ using ClaudeDo.Ui.Helpers;
using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Services;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Ui.ViewModels; namespace ClaudeDo.Ui.ViewModels;
public partial class TaskDetailViewModel : ViewModelBase public partial class TaskDetailViewModel : ViewModelBase
{ {
private readonly TaskRepository _taskRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly WorktreeRepository _worktreeRepo;
private readonly ListRepository _listRepo;
private readonly GitService _git; private readonly GitService _git;
private readonly WorkerClient _worker; private readonly WorkerClient _worker;
private readonly TagRepository _tagRepo;
private readonly SubtaskRepository _subtaskRepo;
[ObservableProperty] private string _title = ""; [ObservableProperty] private string _title = "";
[ObservableProperty] private string? _description; [ObservableProperty] private string? _description;
@@ -62,17 +60,11 @@ public partial class TaskDetailViewModel : ViewModelBase
public event Action<string>? TaskChanged; public event Action<string>? TaskChanged;
public TaskDetailViewModel(TaskRepository taskRepo, WorktreeRepository worktreeRepo, public TaskDetailViewModel(IDbContextFactory<ClaudeDoDbContext> dbFactory, GitService git, WorkerClient worker)
ListRepository listRepo, GitService git, WorkerClient worker, TagRepository tagRepo,
SubtaskRepository subtaskRepo)
{ {
_taskRepo = taskRepo; _dbFactory = dbFactory;
_worktreeRepo = worktreeRepo;
_listRepo = listRepo;
_git = git; _git = git;
_worker = worker; _worker = worker;
_tagRepo = tagRepo;
_subtaskRepo = subtaskRepo;
worker.TaskMessageEvent += OnTaskMessage; worker.TaskMessageEvent += OnTaskMessage;
worker.WorktreeUpdatedEvent += OnWorktreeUpdated; worker.WorktreeUpdatedEvent += OnWorktreeUpdated;
@@ -98,8 +90,24 @@ public partial class TaskDetailViewModel : ViewModelBase
try try
{ {
var task = await _taskRepo.GetByIdAsync(taskId, ct); TaskEntity? task;
if (task is null) return; List<TagEntity> tags;
List<SubtaskEntity> subtasks;
using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
task = await taskRepo.GetByIdAsync(taskId, ct);
if (task is null) return;
ct.ThrowIfCancellationRequested();
tags = await taskRepo.GetTagsAsync(taskId, ct);
ct.ThrowIfCancellationRequested();
var subtaskRepo = new SubtaskRepository(context);
subtasks = await subtaskRepo.GetByTaskIdAsync(taskId, ct);
}
ct.ThrowIfCancellationRequested(); ct.ThrowIfCancellationRequested();
if (AvailableAgents.Count == 0) if (AvailableAgents.Count == 0)
@@ -149,14 +157,12 @@ public partial class TaskDetailViewModel : ViewModelBase
} }
Tags.Clear(); Tags.Clear();
var tags = await _taskRepo.GetTagsAsync(taskId, ct);
foreach (var tag in tags) foreach (var tag in tags)
Tags.Add(tag); Tags.Add(tag);
// Tear down old subtask subscriptions before replacing them. // Tear down old subtask subscriptions before replacing them.
foreach (var old in Subtasks) old.PropertyChanged -= OnSubtaskPropertyChanged; foreach (var old in Subtasks) old.PropertyChanged -= OnSubtaskPropertyChanged;
Subtasks.Clear(); Subtasks.Clear();
var subtasks = await _subtaskRepo.GetByTaskIdAsync(taskId, ct);
foreach (var s in subtasks) foreach (var s in subtasks)
{ {
var vm = SubtaskItemViewModel.From(s); var vm = SubtaskItemViewModel.From(s);
@@ -181,7 +187,9 @@ public partial class TaskDetailViewModel : ViewModelBase
{ {
if (_isLoading || _taskId is null) return; if (_isLoading || _taskId is null) return;
var entity = await _taskRepo.GetByIdAsync(_taskId); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var entity = await taskRepo.GetByIdAsync(_taskId);
if (entity is null) return; if (entity is null) return;
entity.Title = Title; entity.Title = Title;
@@ -196,7 +204,7 @@ public partial class TaskDetailViewModel : ViewModelBase
if (Enum.TryParse<Data.Models.TaskStatus>(StatusChoice, true, out var status)) if (Enum.TryParse<Data.Models.TaskStatus>(StatusChoice, true, out var status))
entity.Status = status; entity.Status = status;
await _taskRepo.UpdateAsync(entity); await taskRepo.UpdateAsync(entity);
StatusText = entity.Status.ToString().ToLowerInvariant(); StatusText = entity.Status.ToString().ToLowerInvariant();
TaskChanged?.Invoke(_taskId); TaskChanged?.Invoke(_taskId);
} }
@@ -207,11 +215,15 @@ public partial class TaskDetailViewModel : ViewModelBase
var name = NewTagInput.Trim(); var name = NewTagInput.Trim();
if (string.IsNullOrEmpty(name) || _taskId is null) return; if (string.IsNullOrEmpty(name) || _taskId is null) return;
var tagId = await _tagRepo.GetOrCreateAsync(name); using var context = _dbFactory.CreateDbContext();
await _taskRepo.AddTagAsync(_taskId, tagId); var tagRepo = new TagRepository(context);
var taskRepo = new TaskRepository(context);
var tagId = await tagRepo.GetOrCreateAsync(name);
await taskRepo.AddTagAsync(_taskId, tagId);
Tags.Clear(); Tags.Clear();
var tags = await _taskRepo.GetTagsAsync(_taskId); var tags = await taskRepo.GetTagsAsync(_taskId);
foreach (var tag in tags) foreach (var tag in tags)
Tags.Add(tag); Tags.Add(tag);
@@ -223,7 +235,9 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task RemoveTag(TagEntity tag) private async Task RemoveTag(TagEntity tag)
{ {
if (_taskId is null) return; if (_taskId is null) return;
await _taskRepo.RemoveTagAsync(_taskId, tag.Id); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
await taskRepo.RemoveTagAsync(_taskId, tag.Id);
Tags.Remove(tag); Tags.Remove(tag);
TaskChanged?.Invoke(_taskId); TaskChanged?.Invoke(_taskId);
} }
@@ -241,7 +255,9 @@ public partial class TaskDetailViewModel : ViewModelBase
OrderNum = Subtasks.Count, OrderNum = Subtasks.Count,
CreatedAt = DateTime.UtcNow, CreatedAt = DateTime.UtcNow,
}; };
await _subtaskRepo.AddAsync(entity); using var context = _dbFactory.CreateDbContext();
var subtaskRepo = new SubtaskRepository(context);
await subtaskRepo.AddAsync(entity);
var vm = SubtaskItemViewModel.From(entity); var vm = SubtaskItemViewModel.From(entity);
vm.PropertyChanged += OnSubtaskPropertyChanged; vm.PropertyChanged += OnSubtaskPropertyChanged;
Subtasks.Add(vm); Subtasks.Add(vm);
@@ -251,7 +267,11 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task RemoveSubtask(SubtaskItemViewModel item) private async Task RemoveSubtask(SubtaskItemViewModel item)
{ {
if (!string.IsNullOrEmpty(item.Id)) if (!string.IsNullOrEmpty(item.Id))
await _subtaskRepo.DeleteAsync(item.Id); {
using var context = _dbFactory.CreateDbContext();
var subtaskRepo = new SubtaskRepository(context);
await subtaskRepo.DeleteAsync(item.Id);
}
item.PropertyChanged -= OnSubtaskPropertyChanged; item.PropertyChanged -= OnSubtaskPropertyChanged;
Subtasks.Remove(item); Subtasks.Remove(item);
} }
@@ -262,7 +282,9 @@ public partial class TaskDetailViewModel : ViewModelBase
if (e.PropertyName is not (nameof(SubtaskItemViewModel.Title) or nameof(SubtaskItemViewModel.Completed))) return; if (e.PropertyName is not (nameof(SubtaskItemViewModel.Title) or nameof(SubtaskItemViewModel.Completed))) return;
try try
{ {
await _subtaskRepo.UpdateAsync(new SubtaskEntity using var context = _dbFactory.CreateDbContext();
var subtaskRepo = new SubtaskRepository(context);
await subtaskRepo.UpdateAsync(new SubtaskEntity
{ {
Id = vm.Id, Id = vm.Id,
TaskId = _taskId ?? "", TaskId = _taskId ?? "",
@@ -321,7 +343,9 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task LoadWorktreeAsync(string taskId) private async Task LoadWorktreeAsync(string taskId)
{ {
var wt = await _worktreeRepo.GetByTaskIdAsync(taskId); using var context = _dbFactory.CreateDbContext();
var wtRepo = new WorktreeRepository(context);
var wt = await wtRepo.GetByTaskIdAsync(taskId);
HasWorktree = wt is not null; HasWorktree = wt is not null;
if (wt is not null) if (wt is not null)
{ {
@@ -378,14 +402,27 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task MergeIntoMainAsync() private async Task MergeIntoMainAsync()
{ {
if (_taskId is null || _listId is null) return; if (_taskId is null || _listId is null) return;
var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId);
var list = await _listRepo.GetByIdAsync(_listId); WorktreeEntity? wt;
ListEntity? list;
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
wt = await wtRepo.GetByTaskIdAsync(_taskId);
var listRepo = new ListRepository(context);
list = await listRepo.GetByIdAsync(_listId);
}
if (wt is null || list?.WorkingDir is null) return; if (wt is null || list?.WorkingDir is null) return;
await _git.MergeFfOnlyAsync(list.WorkingDir, wt.BranchName); await _git.MergeFfOnlyAsync(list.WorkingDir, wt.BranchName);
await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true);
await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true); await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true);
await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Merged);
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Merged);
}
await LoadWorktreeAsync(_taskId); await LoadWorktreeAsync(_taskId);
} }
@@ -393,12 +430,25 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task KeepAsBranchAsync() private async Task KeepAsBranchAsync()
{ {
if (_taskId is null || _listId is null) return; if (_taskId is null || _listId is null) return;
var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId);
var list = await _listRepo.GetByIdAsync(_listId); WorktreeEntity? wt;
ListEntity? list;
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
wt = await wtRepo.GetByTaskIdAsync(_taskId);
var listRepo = new ListRepository(context);
list = await listRepo.GetByIdAsync(_listId);
}
if (wt is null || list?.WorkingDir is null) return; if (wt is null || list?.WorkingDir is null) return;
await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true);
await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Kept);
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Kept);
}
await LoadWorktreeAsync(_taskId); await LoadWorktreeAsync(_taskId);
} }
@@ -406,13 +456,26 @@ public partial class TaskDetailViewModel : ViewModelBase
private async Task DiscardAsync() private async Task DiscardAsync()
{ {
if (_taskId is null || _listId is null) return; if (_taskId is null || _listId is null) return;
var wt = await _worktreeRepo.GetByTaskIdAsync(_taskId);
var list = await _listRepo.GetByIdAsync(_listId); WorktreeEntity? wt;
ListEntity? list;
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
wt = await wtRepo.GetByTaskIdAsync(_taskId);
var listRepo = new ListRepository(context);
list = await listRepo.GetByIdAsync(_listId);
}
if (wt is null || list?.WorkingDir is null) return; if (wt is null || list?.WorkingDir is null) return;
await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true); await _git.WorktreeRemoveAsync(list.WorkingDir, wt.Path, force: true);
await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true); await _git.BranchDeleteAsync(list.WorkingDir, wt.BranchName, force: true);
await _worktreeRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Discarded);
using (var context = _dbFactory.CreateDbContext())
{
var wtRepo = new WorktreeRepository(context);
await wtRepo.SetStateAsync(_taskId, Data.Models.WorktreeState.Discarded);
}
await LoadWorktreeAsync(_taskId); await LoadWorktreeAsync(_taskId);
} }

View File

@@ -1,17 +1,19 @@
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.IO; using System.IO;
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Services;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using Microsoft.EntityFrameworkCore;
using TaskStatus = ClaudeDo.Data.Models.TaskStatus; using TaskStatus = ClaudeDo.Data.Models.TaskStatus;
namespace ClaudeDo.Ui.ViewModels; namespace ClaudeDo.Ui.ViewModels;
public partial class TaskEditorViewModel : ViewModelBase public partial class TaskEditorViewModel : ViewModelBase
{ {
private readonly SubtaskRepository _subtaskRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
[ObservableProperty] private string _title = ""; [ObservableProperty] private string _title = "";
[ObservableProperty] private string? _description; [ObservableProperty] private string? _description;
@@ -40,9 +42,9 @@ public partial class TaskEditorViewModel : ViewModelBase
public static string[] StatusChoices { get; } = public static string[] StatusChoices { get; } =
["manual", "queued"]; ["manual", "queued"];
public TaskEditorViewModel(SubtaskRepository subtaskRepo) public TaskEditorViewModel(IDbContextFactory<ClaudeDoDbContext> dbFactory)
{ {
_subtaskRepo = subtaskRepo; _dbFactory = dbFactory;
} }
public async Task LoadAgentsAsync(WorkerClient worker) public async Task LoadAgentsAsync(WorkerClient worker)
@@ -116,7 +118,9 @@ public partial class TaskEditorViewModel : ViewModelBase
WindowTitle = $"Edit Task: {entity.Title}"; WindowTitle = $"Edit Task: {entity.Title}";
Subtasks.Clear(); Subtasks.Clear();
var list = await _subtaskRepo.GetByTaskIdAsync(entity.Id, ct); using var context = _dbFactory.CreateDbContext();
var subtaskRepo = new SubtaskRepository(context);
var list = await subtaskRepo.GetByTaskIdAsync(entity.Id, ct);
foreach (var s in list) foreach (var s in list)
Subtasks.Add(SubtaskItemViewModel.From(s)); Subtasks.Add(SubtaskItemViewModel.From(s));
} }
@@ -196,36 +200,42 @@ public partial class TaskEditorViewModel : ViewModelBase
// Persist subtask changes // Persist subtask changes
if (_editId is not null) if (_editId is not null)
{ {
var existing = await _subtaskRepo.GetByTaskIdAsync(taskId); using var context = _dbFactory.CreateDbContext();
var subtaskRepo = new SubtaskRepository(context);
var existing = await subtaskRepo.GetByTaskIdAsync(taskId);
var existingIds = existing.Select(s => s.Id).ToHashSet(); var existingIds = existing.Select(s => s.Id).ToHashSet();
var currentIds = Subtasks.Where(s => s.Id != "").Select(s => s.Id).ToHashSet(); var currentIds = Subtasks.Where(s => s.Id != "").Select(s => s.Id).ToHashSet();
// Deleted // Deleted
foreach (var id in existingIds.Except(currentIds)) foreach (var id in existingIds.Except(currentIds))
await _subtaskRepo.DeleteAsync(id); await subtaskRepo.DeleteAsync(id);
// Updated // Updated
foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i))) foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i)))
{ {
if (vm.Id == "") continue; if (vm.Id == "") continue;
if (vm.Title != vm.OriginalTitle || vm.Completed != vm.OriginalCompleted) if (vm.Title != vm.OriginalTitle || vm.Completed != vm.OriginalCompleted)
await _subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); await subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow });
else else
{ {
// update order_num if position changed // update order_num if position changed
var orig = existing.FirstOrDefault(e => e.Id == vm.Id); var orig = existing.FirstOrDefault(e => e.Id == vm.Id);
if (orig is not null && orig.OrderNum != idx) if (orig is not null && orig.OrderNum != idx)
await _subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = orig.CreatedAt }); await subtaskRepo.UpdateAsync(new SubtaskEntity { Id = vm.Id, TaskId = taskId, Title = vm.Title, Completed = vm.Completed, OrderNum = idx, CreatedAt = orig.CreatedAt });
} }
} }
} }
// Added (id == "" means new) // Added (id == "" means new)
foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i)).Where(x => x.v.Id == ""))
{ {
if (string.IsNullOrWhiteSpace(vm.Title)) continue; using var context = _dbFactory.CreateDbContext();
var newId = Guid.NewGuid().ToString(); var subtaskRepo = new SubtaskRepository(context);
await _subtaskRepo.AddAsync(new SubtaskEntity { Id = newId, TaskId = taskId, Title = vm.Title.Trim(), Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow }); foreach (var (vm, idx) in Subtasks.Select((v, i) => (v, i)).Where(x => x.v.Id == ""))
{
if (string.IsNullOrWhiteSpace(vm.Title)) continue;
var newId = Guid.NewGuid().ToString();
await subtaskRepo.AddAsync(new SubtaskEntity { Id = newId, TaskId = taskId, Title = vm.Title.Trim(), Completed = vm.Completed, OrderNum = idx, CreatedAt = DateTime.UtcNow });
}
} }
_tcs.TrySetResult(entity); _tcs.TrySetResult(entity);

View File

@@ -2,21 +2,21 @@ using System.Collections.ObjectModel;
using Avalonia; using Avalonia;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Controls.ApplicationLifetimes; using Avalonia.Controls.ApplicationLifetimes;
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Ui.Services; using ClaudeDo.Ui.Services;
using ClaudeDo.Ui.Views; using ClaudeDo.Ui.Views;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input; using CommunityToolkit.Mvvm.Input;
using Microsoft.EntityFrameworkCore;
using TaskStatus = ClaudeDo.Data.Models.TaskStatus; using TaskStatus = ClaudeDo.Data.Models.TaskStatus;
namespace ClaudeDo.Ui.ViewModels; namespace ClaudeDo.Ui.ViewModels;
public partial class TaskListViewModel : ViewModelBase public partial class TaskListViewModel : ViewModelBase
{ {
private readonly TaskRepository _taskRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly TagRepository _tagRepo;
private readonly ListRepository _listRepo;
private readonly WorkerClient _worker; private readonly WorkerClient _worker;
private readonly Func<TaskEditorViewModel> _editorFactory; private readonly Func<TaskEditorViewModel> _editorFactory;
private readonly Action<string> _showMessage; private readonly Action<string> _showMessage;
@@ -33,13 +33,10 @@ public partial class TaskListViewModel : ViewModelBase
partial void OnSelectedTaskChanged(TaskItemViewModel? value) => partial void OnSelectedTaskChanged(TaskItemViewModel? value) =>
SelectedTaskChanged?.Invoke(value); SelectedTaskChanged?.Invoke(value);
public TaskListViewModel(TaskRepository taskRepo, TagRepository tagRepo, public TaskListViewModel(IDbContextFactory<ClaudeDoDbContext> dbFactory, WorkerClient worker,
ListRepository listRepo, WorkerClient worker,
Func<TaskEditorViewModel> editorFactory, Action<string> showMessage) Func<TaskEditorViewModel> editorFactory, Action<string> showMessage)
{ {
_taskRepo = taskRepo; _dbFactory = dbFactory;
_tagRepo = tagRepo;
_listRepo = listRepo;
_worker = worker; _worker = worker;
_editorFactory = editorFactory; _editorFactory = editorFactory;
_showMessage = showMessage; _showMessage = showMessage;
@@ -77,7 +74,9 @@ public partial class TaskListViewModel : ViewModelBase
if (listId is not null) if (listId is not null)
{ {
var list = await _listRepo.GetByIdAsync(listId); using var context = _dbFactory.CreateDbContext();
var listRepo = new ListRepository(context);
var list = await listRepo.GetByIdAsync(listId);
ListName = list?.Name ?? "Tasks"; ListName = list?.Name ?? "Tasks";
} }
else else
@@ -89,10 +88,12 @@ public partial class TaskListViewModel : ViewModelBase
try try
{ {
var entities = await _taskRepo.GetByListAsync(listId); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var entities = await taskRepo.GetByListIdAsync(listId);
foreach (var e in entities) foreach (var e in entities)
{ {
var tags = await _taskRepo.GetEffectiveTagsAsync(e.Id); var tags = await taskRepo.GetEffectiveTagsAsync(e.Id);
Tasks.Add(new TaskItemViewModel(e, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync)); Tasks.Add(new TaskItemViewModel(e, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync));
} }
} }
@@ -110,8 +111,13 @@ public partial class TaskListViewModel : ViewModelBase
var title = InlineAddTitle.Trim(); var title = InlineAddTitle.Trim();
if (string.IsNullOrEmpty(title) || CurrentListId is null) return; if (string.IsNullOrEmpty(title) || CurrentListId is null) return;
var list = await _listRepo.GetByIdAsync(CurrentListId); string defaultCommitType;
var defaultCommitType = list?.DefaultCommitType ?? "chore"; using (var context = _dbFactory.CreateDbContext())
{
var listRepo = new ListRepository(context);
var list = await listRepo.GetByIdAsync(CurrentListId);
defaultCommitType = list?.DefaultCommitType ?? "chore";
}
var entity = new TaskEntity var entity = new TaskEntity
{ {
@@ -125,8 +131,10 @@ public partial class TaskListViewModel : ViewModelBase
try try
{ {
await _taskRepo.AddAsync(entity); using var context = _dbFactory.CreateDbContext();
var tags = await _taskRepo.GetEffectiveTagsAsync(entity.Id); var taskRepo = new TaskRepository(context);
await taskRepo.AddAsync(entity);
var tags = await taskRepo.GetEffectiveTagsAsync(entity.Id);
var vm = new TaskItemViewModel(entity, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync); var vm = new TaskItemViewModel(entity, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync);
Tasks.Add(vm); Tasks.Add(vm);
SelectedTask = vm; SelectedTask = vm;
@@ -141,9 +149,13 @@ public partial class TaskListViewModel : ViewModelBase
[RelayCommand(CanExecute = nameof(CanAddTask))] [RelayCommand(CanExecute = nameof(CanAddTask))]
private async Task AddTask() private async Task AddTask()
{ {
// Get list default commit type string defaultCommitType;
var list = await _listRepo.GetByIdAsync(CurrentListId); using (var context = _dbFactory.CreateDbContext())
var defaultCommitType = list?.DefaultCommitType ?? "chore"; {
var listRepo = new ListRepository(context);
var list = await listRepo.GetByIdAsync(CurrentListId);
defaultCommitType = list?.DefaultCommitType ?? "chore";
}
var editor = _editorFactory(); var editor = _editorFactory();
await editor.LoadAgentsAsync(_worker); await editor.LoadAgentsAsync(_worker);
@@ -159,15 +171,18 @@ public partial class TaskListViewModel : ViewModelBase
try try
{ {
await _taskRepo.AddAsync(saved); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var tagRepo = new TagRepository(context);
await taskRepo.AddAsync(saved);
foreach (var tagName in editor.SelectedTagNames) foreach (var tagName in editor.SelectedTagNames)
{ {
var tagId = await _tagRepo.GetOrCreateAsync(tagName); var tagId = await tagRepo.GetOrCreateAsync(tagName);
await _taskRepo.AddTagAsync(saved.Id, tagId); await taskRepo.AddTagAsync(saved.Id, tagId);
} }
var tags = await _taskRepo.GetEffectiveTagsAsync(saved.Id); var tags = await taskRepo.GetEffectiveTagsAsync(saved.Id);
Tasks.Add(new TaskItemViewModel(saved, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync)); Tasks.Add(new TaskItemViewModel(saved, tags, RunNowAsync, () => _worker.IsConnected, ToggleDoneAsync));
// Auto wake-queue if agent+queued // Auto wake-queue if agent+queued
@@ -188,10 +203,17 @@ public partial class TaskListViewModel : ViewModelBase
private async Task EditTask() private async Task EditTask()
{ {
if (SelectedTask is null || CurrentListId is null) return; if (SelectedTask is null || CurrentListId is null) return;
var entity = await _taskRepo.GetByIdAsync(SelectedTask.Id);
if (entity is null) return;
var taskTags = await _taskRepo.GetTagsAsync(entity.Id); TaskEntity? entity;
List<TagEntity> taskTags;
using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
entity = await taskRepo.GetByIdAsync(SelectedTask.Id);
if (entity is null) return;
taskTags = await taskRepo.GetTagsAsync(entity.Id);
}
var editor = _editorFactory(); var editor = _editorFactory();
await editor.LoadAgentsAsync(_worker); await editor.LoadAgentsAsync(_worker);
await editor.InitForEditAsync(entity, taskTags); await editor.InitForEditAsync(entity, taskTags);
@@ -206,18 +228,21 @@ public partial class TaskListViewModel : ViewModelBase
try try
{ {
await _taskRepo.UpdateAsync(saved); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var tagRepo = new TagRepository(context);
await taskRepo.UpdateAsync(saved);
var existingTags = await _taskRepo.GetTagsAsync(saved.Id); var existingTags = await taskRepo.GetTagsAsync(saved.Id);
foreach (var old in existingTags) foreach (var old in existingTags)
await _taskRepo.RemoveTagAsync(saved.Id, old.Id); await taskRepo.RemoveTagAsync(saved.Id, old.Id);
foreach (var tagName in editor.SelectedTagNames) foreach (var tagName in editor.SelectedTagNames)
{ {
var tagId = await _tagRepo.GetOrCreateAsync(tagName); var tagId = await tagRepo.GetOrCreateAsync(tagName);
await _taskRepo.AddTagAsync(saved.Id, tagId); await taskRepo.AddTagAsync(saved.Id, tagId);
} }
var newTags = await _taskRepo.GetEffectiveTagsAsync(saved.Id); var newTags = await taskRepo.GetEffectiveTagsAsync(saved.Id);
SelectedTask.Refresh(saved, newTags); SelectedTask.Refresh(saved, newTags);
} }
catch (Exception ex) catch (Exception ex)
@@ -232,7 +257,9 @@ public partial class TaskListViewModel : ViewModelBase
if (SelectedTask is null) return; if (SelectedTask is null) return;
try try
{ {
await _taskRepo.DeleteAsync(SelectedTask.Id); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
await taskRepo.DeleteAsync(SelectedTask.Id);
Tasks.Remove(SelectedTask); Tasks.Remove(SelectedTask);
SelectedTask = null; SelectedTask = null;
} }
@@ -244,14 +271,16 @@ public partial class TaskListViewModel : ViewModelBase
public async Task RefreshSingleAsync(string taskId) public async Task RefreshSingleAsync(string taskId)
{ {
var entity = await _taskRepo.GetByIdAsync(taskId); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var entity = await taskRepo.GetByIdAsync(taskId);
var existing = Tasks.FirstOrDefault(t => t.Id == taskId); var existing = Tasks.FirstOrDefault(t => t.Id == taskId);
if (entity is null) if (entity is null)
{ {
if (existing is not null) Tasks.Remove(existing); if (existing is not null) Tasks.Remove(existing);
return; return;
} }
var tags = await _taskRepo.GetEffectiveTagsAsync(taskId); var tags = await taskRepo.GetEffectiveTagsAsync(taskId);
if (existing is not null) if (existing is not null)
existing.Refresh(entity, tags); existing.Refresh(entity, tags);
} }
@@ -270,14 +299,16 @@ public partial class TaskListViewModel : ViewModelBase
private async Task ToggleDoneAsync(string taskId) private async Task ToggleDoneAsync(string taskId)
{ {
var entity = await _taskRepo.GetByIdAsync(taskId); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var entity = await taskRepo.GetByIdAsync(taskId);
if (entity is null) return; if (entity is null) return;
entity.Status = entity.Status == TaskStatus.Done ? TaskStatus.Manual : TaskStatus.Done; entity.Status = entity.Status == TaskStatus.Done ? TaskStatus.Manual : TaskStatus.Done;
if (entity.Status == TaskStatus.Done) if (entity.Status == TaskStatus.Done)
entity.FinishedAt = DateTime.UtcNow; entity.FinishedAt = DateTime.UtcNow;
await _taskRepo.UpdateAsync(entity); await taskRepo.UpdateAsync(entity);
await RefreshSingleAsync(taskId); await RefreshSingleAsync(taskId);
} }

View File

@@ -5,6 +5,7 @@ using ClaudeDo.Worker.Config;
using ClaudeDo.Worker.Hub; using ClaudeDo.Worker.Hub;
using ClaudeDo.Worker.Runner; using ClaudeDo.Worker.Runner;
using ClaudeDo.Worker.Services; using ClaudeDo.Worker.Services;
using Microsoft.EntityFrameworkCore;
var cfg = WorkerConfig.Load(); var cfg = WorkerConfig.Load();
@@ -14,18 +15,18 @@ var builder = WebApplication.CreateBuilder(args);
// doesn't think we crashed (~30s timeout). No-op when running interactively. // doesn't think we crashed (~30s timeout). No-op when running interactively.
builder.Host.UseWindowsService(o => o.ServiceName = "ClaudeDoWorker"); builder.Host.UseWindowsService(o => o.ServiceName = "ClaudeDoWorker");
// Initialize DB schema before the host starts accepting connections. builder.Services.AddDbContextFactory<ClaudeDoDbContext>(opt =>
var dbFactory = new SqliteConnectionFactory(cfg.DbPath); opt.UseSqlite($"Data Source={cfg.DbPath}"));
SchemaInitializer.Apply(dbFactory); builder.Services.AddDbContext<ClaudeDoDbContext>(opt =>
opt.UseSqlite($"Data Source={cfg.DbPath}"));
builder.Services.AddSingleton(cfg); builder.Services.AddSingleton(cfg);
builder.Services.AddSingleton(dbFactory); builder.Services.AddScoped<TagRepository>();
builder.Services.AddSingleton<TagRepository>(); builder.Services.AddScoped<ListRepository>();
builder.Services.AddSingleton<ListRepository>(); builder.Services.AddScoped<TaskRepository>();
builder.Services.AddSingleton<TaskRepository>(); builder.Services.AddScoped<SubtaskRepository>();
builder.Services.AddSingleton<SubtaskRepository>(); builder.Services.AddScoped<WorktreeRepository>();
builder.Services.AddSingleton<WorktreeRepository>(); builder.Services.AddScoped<TaskRunRepository>();
builder.Services.AddSingleton<TaskRunRepository>();
builder.Services.AddHostedService<StaleTaskRecovery>(); builder.Services.AddHostedService<StaleTaskRecovery>();
builder.Services.AddSignalR(); builder.Services.AddSignalR();
@@ -51,6 +52,11 @@ builder.WebHost.UseUrls($"http://127.0.0.1:{cfg.SignalRPort}");
var app = builder.Build(); var app = builder.Build();
using (var scope = app.Services.CreateScope())
{
scope.ServiceProvider.GetRequiredService<ClaudeDoDbContext>().Database.Migrate();
}
app.MapHub<WorkerHub>("/hub"); app.MapHub<WorkerHub>("/hub");
app.Logger.LogInformation("ClaudeDo.Worker listening on http://127.0.0.1:{Port} (db: {Db})", app.Logger.LogInformation("ClaudeDo.Worker listening on http://127.0.0.1:{Port} (db: {Db})",

View File

@@ -1,18 +1,16 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Config;
using ClaudeDo.Worker.Hub; using ClaudeDo.Worker.Hub;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Runner; namespace ClaudeDo.Worker.Runner;
public sealed class TaskRunner public sealed class TaskRunner
{ {
private readonly IClaudeProcess _claude; private readonly IClaudeProcess _claude;
private readonly TaskRepository _taskRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly TaskRunRepository _runRepo;
private readonly ListRepository _listRepo;
private readonly WorktreeRepository _wtRepo;
private readonly SubtaskRepository _subtaskRepo;
private readonly HubBroadcaster _broadcaster; private readonly HubBroadcaster _broadcaster;
private readonly WorktreeManager _wtManager; private readonly WorktreeManager _wtManager;
private readonly ClaudeArgsBuilder _argsBuilder; private readonly ClaudeArgsBuilder _argsBuilder;
@@ -21,11 +19,7 @@ public sealed class TaskRunner
public TaskRunner( public TaskRunner(
IClaudeProcess claude, IClaudeProcess claude,
TaskRepository taskRepo, IDbContextFactory<ClaudeDoDbContext> dbFactory,
TaskRunRepository runRepo,
ListRepository listRepo,
WorktreeRepository wtRepo,
SubtaskRepository subtaskRepo,
HubBroadcaster broadcaster, HubBroadcaster broadcaster,
WorktreeManager wtManager, WorktreeManager wtManager,
ClaudeArgsBuilder argsBuilder, ClaudeArgsBuilder argsBuilder,
@@ -33,11 +27,7 @@ public sealed class TaskRunner
ILogger<TaskRunner> logger) ILogger<TaskRunner> logger)
{ {
_claude = claude; _claude = claude;
_taskRepo = taskRepo; _dbFactory = dbFactory;
_runRepo = runRepo;
_listRepo = listRepo;
_wtRepo = wtRepo;
_subtaskRepo = subtaskRepo;
_broadcaster = broadcaster; _broadcaster = broadcaster;
_wtManager = wtManager; _wtManager = wtManager;
_argsBuilder = argsBuilder; _argsBuilder = argsBuilder;
@@ -49,11 +39,23 @@ public sealed class TaskRunner
{ {
try try
{ {
var list = await _listRepo.GetByIdAsync(task.ListId, ct); ListEntity? list;
if (list is null) ListConfigEntity? listConfig;
List<SubtaskEntity> subtasks;
using (var context = _dbFactory.CreateDbContext())
{ {
await MarkFailed(task.Id, slot, "List not found."); var listRepo = new ListRepository(context);
return; list = await listRepo.GetByIdAsync(task.ListId, ct);
if (list is null)
{
await MarkFailed(task.Id, slot, "List not found.");
return;
}
listConfig = await listRepo.GetConfigAsync(task.ListId, ct);
var subtaskRepo = new SubtaskRepository(context);
subtasks = await subtaskRepo.GetByTaskIdAsync(task.Id, ct);
} }
// Determine working directory: worktree or sandbox. // Determine working directory: worktree or sandbox.
@@ -81,7 +83,6 @@ public sealed class TaskRunner
} }
// Resolve config: task overrides > list config > null. // Resolve config: task overrides > list config > null.
var listConfig = await _listRepo.GetConfigAsync(task.ListId, ct);
var resolvedConfig = new ClaudeRunConfig( var resolvedConfig = new ClaudeRunConfig(
Model: task.Model ?? listConfig?.Model ?? "claude-sonnet-4-6", Model: task.Model ?? listConfig?.Model ?? "claude-sonnet-4-6",
SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt, SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt,
@@ -90,11 +91,14 @@ public sealed class TaskRunner
); );
var now = DateTime.UtcNow; var now = DateTime.UtcNow;
await _taskRepo.MarkRunningAsync(task.Id, now, ct); using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
await taskRepo.MarkRunningAsync(task.Id, now, ct);
}
await _broadcaster.TaskStarted(slot, task.Id, now); await _broadcaster.TaskStarted(slot, task.Id, now);
// Build prompt. // Build prompt.
var subtasks = await _subtaskRepo.GetByTaskIdAsync(task.Id, ct);
var sb = new System.Text.StringBuilder(task.Title); var sb = new System.Text.StringBuilder(task.Title);
if (!string.IsNullOrWhiteSpace(task.Description)) sb.Append("\n\n").Append(task.Description.Trim()); if (!string.IsNullOrWhiteSpace(task.Description)) sb.Append("\n\n").Append(task.Description.Trim());
if (subtasks.Count > 0) if (subtasks.Count > 0)
@@ -155,19 +159,34 @@ public sealed class TaskRunner
public async Task ContinueAsync(string taskId, string followUpPrompt, string slot, CancellationToken ct) public async Task ContinueAsync(string taskId, string followUpPrompt, string slot, CancellationToken ct)
{ {
var task = await _taskRepo.GetByIdAsync(taskId, ct) TaskEntity task;
?? throw new KeyNotFoundException($"Task '{taskId}' not found."); TaskRunEntity lastRun;
ListEntity list;
ListConfigEntity? listConfig;
WorktreeEntity? worktree;
var lastRun = await _runRepo.GetLatestByTaskIdAsync(taskId, ct) using (var context = _dbFactory.CreateDbContext())
?? throw new InvalidOperationException("No previous run to continue."); {
var taskRepo = new TaskRepository(context);
task = await taskRepo.GetByIdAsync(taskId, ct)
?? throw new KeyNotFoundException($"Task '{taskId}' not found.");
if (lastRun.SessionId is null) var runRepo = new TaskRunRepository(context);
throw new InvalidOperationException("Previous run has no session ID — cannot resume."); lastRun = await runRepo.GetLatestByTaskIdAsync(taskId, ct)
?? throw new InvalidOperationException("No previous run to continue.");
var list = await _listRepo.GetByIdAsync(task.ListId, ct) if (lastRun.SessionId is null)
?? throw new InvalidOperationException("List not found."); throw new InvalidOperationException("Previous run has no session ID — cannot resume.");
var listRepo = new ListRepository(context);
list = await listRepo.GetByIdAsync(task.ListId, ct)
?? throw new InvalidOperationException("List not found.");
listConfig = await listRepo.GetConfigAsync(task.ListId, ct);
var wtRepo = new WorktreeRepository(context);
worktree = await wtRepo.GetByTaskIdAsync(taskId, ct);
}
var listConfig = await _listRepo.GetConfigAsync(task.ListId, ct);
var resolvedConfig = new ClaudeRunConfig( var resolvedConfig = new ClaudeRunConfig(
Model: task.Model ?? listConfig?.Model, Model: task.Model ?? listConfig?.Model,
SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt, SystemPrompt: task.SystemPrompt ?? listConfig?.SystemPrompt,
@@ -178,7 +197,6 @@ public sealed class TaskRunner
// Determine run directory from existing worktree or sandbox. // Determine run directory from existing worktree or sandbox.
string runDir; string runDir;
WorktreeContext? wtCtx = null; WorktreeContext? wtCtx = null;
var worktree = await _wtRepo.GetByTaskIdAsync(taskId, ct);
if (worktree is not null) if (worktree is not null)
{ {
runDir = worktree.Path; runDir = worktree.Path;
@@ -190,7 +208,11 @@ public sealed class TaskRunner
} }
var now = DateTime.UtcNow; var now = DateTime.UtcNow;
await _taskRepo.MarkRunningAsync(taskId, now, ct); using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
await taskRepo.MarkRunningAsync(taskId, now, ct);
}
await _broadcaster.TaskStarted(slot, taskId, now); await _broadcaster.TaskStarted(slot, taskId, now);
var nextRunNumber = lastRun.RunNumber + 1; var nextRunNumber = lastRun.RunNumber + 1;
@@ -226,7 +248,12 @@ public sealed class TaskRunner
LogPath = logPath, LogPath = logPath,
StartedAt = DateTime.UtcNow, StartedAt = DateTime.UtcNow,
}; };
await _runRepo.AddAsync(run, ct);
using (var context = _dbFactory.CreateDbContext())
{
var runRepo = new TaskRunRepository(context);
await runRepo.AddAsync(run, ct);
}
var arguments = _argsBuilder.Build(config); var arguments = _argsBuilder.Build(config);
@@ -257,10 +284,15 @@ public sealed class TaskRunner
run.TokensIn = result.TokensIn; run.TokensIn = result.TokensIn;
run.TokensOut = result.TokensOut; run.TokensOut = result.TokensOut;
run.FinishedAt = DateTime.UtcNow; run.FinishedAt = DateTime.UtcNow;
await _runRepo.UpdateAsync(run, CancellationToken.None);
// Update denormalized fields on the task. using (var context = _dbFactory.CreateDbContext())
await _taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); {
var runRepo = new TaskRunRepository(context);
await runRepo.UpdateAsync(run, CancellationToken.None);
var taskRepo = new TaskRepository(context);
await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None);
}
return result; return result;
} }
@@ -273,8 +305,12 @@ public sealed class TaskRunner
run.FinishedAt = DateTime.UtcNow; run.FinishedAt = DateTime.UtcNow;
try try
{ {
await _runRepo.UpdateAsync(run, CancellationToken.None); using var context = _dbFactory.CreateDbContext();
await _taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None); var runRepo = new TaskRunRepository(context);
await runRepo.UpdateAsync(run, CancellationToken.None);
var taskRepo = new TaskRepository(context);
await taskRepo.SetLogPathAsync(taskId, logPath, CancellationToken.None);
} }
catch (Exception updateEx) catch (Exception updateEx)
{ {
@@ -297,7 +333,11 @@ public sealed class TaskRunner
// is never left as 'running' because of a cancel that arrived // is never left as 'running' because of a cancel that arrived
// after the Claude run already succeeded. // after the Claude run already succeeded.
var finishedAt = DateTime.UtcNow; var finishedAt = DateTime.UtcNow;
await _taskRepo.MarkDoneAsync(task.Id, finishedAt, result.ResultMarkdown, CancellationToken.None); using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
await taskRepo.MarkDoneAsync(task.Id, finishedAt, result.ResultMarkdown, CancellationToken.None);
}
await _broadcaster.TaskFinished(slot, task.Id, "done", finishedAt); await _broadcaster.TaskFinished(slot, task.Id, "done", finishedAt);
_logger.LogInformation("Task {TaskId} completed (turns={Turns}, tokens_in={In}, tokens_out={Out})", _logger.LogInformation("Task {TaskId} completed (turns={Turns}, tokens_in={In}, tokens_out={Out})",
task.Id, result.TurnCount, result.TokensIn, result.TokensOut); task.Id, result.TurnCount, result.TokensIn, result.TokensOut);
@@ -308,7 +348,9 @@ public sealed class TaskRunner
// Intentionally does not accept a CancellationToken: this is the // Intentionally does not accept a CancellationToken: this is the
// terminal write for a failed task and must always be persisted. // terminal write for a failed task and must always be persisted.
var finishedAt = DateTime.UtcNow; var finishedAt = DateTime.UtcNow;
await _taskRepo.MarkFailedAsync(taskId, finishedAt, result.ErrorMarkdown, CancellationToken.None); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
await taskRepo.MarkFailedAsync(taskId, finishedAt, result.ErrorMarkdown, CancellationToken.None);
await _broadcaster.TaskFinished(slot, taskId, "failed", finishedAt); await _broadcaster.TaskFinished(slot, taskId, "failed", finishedAt);
_logger.LogWarning("Task {TaskId} failed (turns={Turns}): {Error}", taskId, result.TurnCount, result.ErrorMarkdown); _logger.LogWarning("Task {TaskId} failed (turns={Turns}): {Error}", taskId, result.TurnCount, result.ErrorMarkdown);
} }
@@ -319,7 +361,9 @@ public sealed class TaskRunner
{ {
var now = DateTime.UtcNow; var now = DateTime.UtcNow;
// Terminal write — never cancel. // Terminal write — never cancel.
await _taskRepo.MarkFailedAsync(taskId, now, error, CancellationToken.None); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
await taskRepo.MarkFailedAsync(taskId, now, error, CancellationToken.None);
await _broadcaster.TaskFinished(slot, taskId, "failed", now); await _broadcaster.TaskFinished(slot, taskId, "failed", now);
await _broadcaster.TaskUpdated(taskId); await _broadcaster.TaskUpdated(taskId);
} }

View File

@@ -1,7 +1,9 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Git; using ClaudeDo.Data.Git;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Config;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Runner; namespace ClaudeDo.Worker.Runner;
@@ -10,14 +12,14 @@ public sealed record WorktreeContext(string WorktreePath, string BranchName, str
public sealed class WorktreeManager public sealed class WorktreeManager
{ {
private readonly GitService _git; private readonly GitService _git;
private readonly WorktreeRepository _wtRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly WorkerConfig _cfg; private readonly WorkerConfig _cfg;
private readonly ILogger<WorktreeManager> _logger; private readonly ILogger<WorktreeManager> _logger;
public WorktreeManager(GitService git, WorktreeRepository wtRepo, WorkerConfig cfg, ILogger<WorktreeManager> logger) public WorktreeManager(GitService git, IDbContextFactory<ClaudeDoDbContext> dbFactory, WorkerConfig cfg, ILogger<WorktreeManager> logger)
{ {
_git = git; _git = git;
_wtRepo = wtRepo; _dbFactory = dbFactory;
_cfg = cfg; _cfg = cfg;
_logger = logger; _logger = logger;
} }
@@ -50,7 +52,9 @@ public sealed class WorktreeManager
await _git.WorktreeAddAsync(workingDir, branchName, worktreePath, baseCommit, ct); await _git.WorktreeAddAsync(workingDir, branchName, worktreePath, baseCommit, ct);
// Insert worktrees row AFTER git succeeds — if git throws, no row is created. // Insert worktrees row AFTER git succeeds — if git throws, no row is created.
await _wtRepo.AddAsync(new WorktreeEntity using var context = _dbFactory.CreateDbContext();
var wtRepo = new WorktreeRepository(context);
await wtRepo.AddAsync(new WorktreeEntity
{ {
TaskId = task.Id, TaskId = task.Id,
Path = worktreePath, Path = worktreePath,
@@ -87,7 +91,9 @@ public sealed class WorktreeManager
var head = await _git.RevParseHeadAsync(ctx.WorktreePath, ct); var head = await _git.RevParseHeadAsync(ctx.WorktreePath, ct);
var diffStat = await _git.DiffStatAsync(ctx.WorktreePath, ctx.BaseCommit, head, ct); var diffStat = await _git.DiffStatAsync(ctx.WorktreePath, ctx.BaseCommit, head, ct);
await _wtRepo.UpdateHeadAsync(task.Id, head, diffStat, ct); using var context = _dbFactory.CreateDbContext();
var wtRepo = new WorktreeRepository(context);
await wtRepo.UpdateHeadAsync(task.Id, head, diffStat, ct);
_logger.LogInformation("Committed changes for task {TaskId}: {Head}", task.Id, head); _logger.LogInformation("Committed changes for task {TaskId}: {Head}", task.Id, head);
return true; return true;

View File

@@ -1,7 +1,9 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Config; using ClaudeDo.Worker.Config;
using ClaudeDo.Worker.Runner; using ClaudeDo.Worker.Runner;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Services; namespace ClaudeDo.Worker.Services;
@@ -14,7 +16,7 @@ public sealed class QueueSlotState
public sealed class QueueService : BackgroundService public sealed class QueueService : BackgroundService
{ {
private readonly TaskRepository _taskRepo; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly TaskRunner _runner; private readonly TaskRunner _runner;
private readonly WorkerConfig _cfg; private readonly WorkerConfig _cfg;
private readonly ILogger<QueueService> _logger; private readonly ILogger<QueueService> _logger;
@@ -26,12 +28,12 @@ public sealed class QueueService : BackgroundService
private readonly SemaphoreSlim _wakeSignal = new(0, 1); private readonly SemaphoreSlim _wakeSignal = new(0, 1);
public QueueService( public QueueService(
TaskRepository taskRepo, IDbContextFactory<ClaudeDoDbContext> dbFactory,
TaskRunner runner, TaskRunner runner,
WorkerConfig cfg, WorkerConfig cfg,
ILogger<QueueService> logger) ILogger<QueueService> logger)
{ {
_taskRepo = taskRepo; _dbFactory = dbFactory;
_runner = runner; _runner = runner;
_cfg = cfg; _cfg = cfg;
_logger = logger; _logger = logger;
@@ -56,7 +58,9 @@ public sealed class QueueService : BackgroundService
public async Task RunNow(string taskId) public async Task RunNow(string taskId)
{ {
var task = await _taskRepo.GetByIdAsync(taskId); using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var task = await taskRepo.GetByIdAsync(taskId);
if (task is null) if (task is null)
throw new KeyNotFoundException($"Task '{taskId}' not found."); throw new KeyNotFoundException($"Task '{taskId}' not found.");
@@ -78,7 +82,9 @@ public sealed class QueueService : BackgroundService
public async Task<string> ContinueTask(string taskId, string followUpPrompt) public async Task<string> ContinueTask(string taskId, string followUpPrompt)
{ {
var task = await _taskRepo.GetByIdAsync(taskId) using var context = _dbFactory.CreateDbContext();
var taskRepo = new TaskRepository(context);
var task = await taskRepo.GetByIdAsync(taskId)
?? throw new KeyNotFoundException($"Task '{taskId}' not found."); ?? throw new KeyNotFoundException($"Task '{taskId}' not found.");
if (task.Status == Data.Models.TaskStatus.Running) if (task.Status == Data.Models.TaskStatus.Running)
@@ -144,7 +150,12 @@ public sealed class QueueService : BackgroundService
if (_queueSlot is not null) continue; if (_queueSlot is not null) continue;
var task = await _taskRepo.GetNextQueuedAgentTaskAsync(DateTime.UtcNow, stoppingToken); TaskEntity? task;
using (var context = _dbFactory.CreateDbContext())
{
var taskRepo = new TaskRepository(context);
task = await taskRepo.GetNextQueuedAgentTaskAsync(DateTime.UtcNow, stoppingToken);
}
if (task is null) continue; if (task is null) continue;
lock (_lock) lock (_lock)

View File

@@ -1,21 +1,25 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Services; namespace ClaudeDo.Worker.Services;
public sealed class StaleTaskRecovery : IHostedService public sealed class StaleTaskRecovery : IHostedService
{ {
private readonly TaskRepository _tasks; private readonly IDbContextFactory<ClaudeDoDbContext> _dbFactory;
private readonly ILogger<StaleTaskRecovery> _logger; private readonly ILogger<StaleTaskRecovery> _logger;
public StaleTaskRecovery(TaskRepository tasks, ILogger<StaleTaskRecovery> logger) public StaleTaskRecovery(IDbContextFactory<ClaudeDoDbContext> dbFactory, ILogger<StaleTaskRecovery> logger)
{ {
_tasks = tasks; _dbFactory = dbFactory;
_logger = logger; _logger = logger;
} }
public async Task StartAsync(CancellationToken cancellationToken) public async Task StartAsync(CancellationToken cancellationToken)
{ {
var flipped = await _tasks.FlipAllRunningToFailedAsync("worker restart", cancellationToken); using var context = _dbFactory.CreateDbContext();
var tasks = new TaskRepository(context);
var flipped = await tasks.FlipAllRunningToFailedAsync("worker restart", cancellationToken);
if (flipped > 0) if (flipped > 0)
_logger.LogWarning("Stale task recovery: flipped {Count} running task(s) to failed", flipped); _logger.LogWarning("Stale task recovery: flipped {Count} running task(s) to failed", flipped);
else else

View File

@@ -1,19 +1,30 @@
using ClaudeDo.Data; using ClaudeDo.Data;
using Microsoft.EntityFrameworkCore;
namespace ClaudeDo.Worker.Tests.Infrastructure; namespace ClaudeDo.Worker.Tests.Infrastructure;
public sealed class DbFixture : IDisposable public sealed class DbFixture : IDisposable
{ {
public string DbPath { get; } public string DbPath { get; }
public SqliteConnectionFactory Factory { get; }
public DbFixture() public DbFixture()
{ {
DbPath = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}.db"); DbPath = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}.db");
Factory = new SqliteConnectionFactory(DbPath); // Apply migrations so the schema is created.
SchemaInitializer.Apply(Factory); using var ctx = CreateContext();
ctx.Database.Migrate();
} }
public ClaudeDoDbContext CreateContext()
{
var options = new DbContextOptionsBuilder<ClaudeDoDbContext>()
.UseSqlite($"Data Source={DbPath}")
.Options;
return new ClaudeDoDbContext(options);
}
public TestDbContextFactory CreateFactory() => new(this);
public void Dispose() public void Dispose()
{ {
try { File.Delete(DbPath); } catch { /* best effort */ } try { File.Delete(DbPath); } catch { /* best effort */ }
@@ -21,3 +32,10 @@ public sealed class DbFixture : IDisposable
try { File.Delete(DbPath + "-shm"); } catch { } try { File.Delete(DbPath + "-shm"); } catch { }
} }
} }
public sealed class TestDbContextFactory : IDbContextFactory<ClaudeDoDbContext>
{
private readonly DbFixture _fixture;
public TestDbContextFactory(DbFixture fixture) => _fixture = fixture;
public ClaudeDoDbContext CreateDbContext() => _fixture.CreateContext();
}

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Tests.Infrastructure; using ClaudeDo.Worker.Tests.Infrastructure;
@@ -7,12 +8,14 @@ namespace ClaudeDo.Worker.Tests.Repositories;
public sealed class ListRepositoryConfigTests : IDisposable public sealed class ListRepositoryConfigTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly ListRepository _repo; private readonly ListRepository _repo;
private readonly string _listId; private readonly string _listId;
public ListRepositoryConfigTests() public ListRepositoryConfigTests()
{ {
_repo = new ListRepository(_db.Factory); _ctx = _db.CreateContext();
_repo = new ListRepository(_ctx);
_listId = Guid.NewGuid().ToString(); _listId = Guid.NewGuid().ToString();
_repo.AddAsync(new ListEntity _repo.AddAsync(new ListEntity
{ {
@@ -57,5 +60,9 @@ public sealed class ListRepositoryConfigTests : IDisposable
Assert.Equal("haiku-4-5", fetched.Model); Assert.Equal("haiku-4-5", fetched.Model);
} }
public void Dispose() => _db.Dispose(); public void Dispose()
{
_ctx.Dispose();
_db.Dispose();
}
} }

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Tests.Infrastructure; using ClaudeDo.Worker.Tests.Infrastructure;
@@ -7,16 +8,22 @@ namespace ClaudeDo.Worker.Tests.Repositories;
public sealed class ListRepositoryTests : IDisposable public sealed class ListRepositoryTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly ListRepository _lists; private readonly ListRepository _lists;
private readonly TagRepository _tags; private readonly TagRepository _tags;
public ListRepositoryTests() public ListRepositoryTests()
{ {
_lists = new ListRepository(_db.Factory); _ctx = _db.CreateContext();
_tags = new TagRepository(_db.Factory); _lists = new ListRepository(_ctx);
_tags = new TagRepository(_ctx);
} }
public void Dispose() => _db.Dispose(); public void Dispose()
{
_ctx.Dispose();
_db.Dispose();
}
[Fact] [Fact]
public async Task AddAsync_And_GetByIdAsync_Roundtrips() public async Task AddAsync_And_GetByIdAsync_Roundtrips()

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Tests.Infrastructure; using ClaudeDo.Worker.Tests.Infrastructure;
@@ -8,18 +9,24 @@ namespace ClaudeDo.Worker.Tests.Repositories;
public sealed class TaskRepositoryTests : IDisposable public sealed class TaskRepositoryTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly TaskRepository _tasks; private readonly TaskRepository _tasks;
private readonly ListRepository _lists; private readonly ListRepository _lists;
private readonly TagRepository _tags; private readonly TagRepository _tags;
public TaskRepositoryTests() public TaskRepositoryTests()
{ {
_tasks = new TaskRepository(_db.Factory); _ctx = _db.CreateContext();
_lists = new ListRepository(_db.Factory); _tasks = new TaskRepository(_ctx);
_tags = new TagRepository(_db.Factory); _lists = new ListRepository(_ctx);
_tags = new TagRepository(_ctx);
} }
public void Dispose() => _db.Dispose(); public void Dispose()
{
_ctx.Dispose();
_db.Dispose();
}
private async Task<string> CreateListAsync(string? id = null) private async Task<string> CreateListAsync(string? id = null)
{ {
@@ -197,7 +204,7 @@ public sealed class TaskRepositoryTests : IDisposable
var listId = await CreateListAsync(); var listId = await CreateListAsync();
var agentTagId = await _tags.GetOrCreateAsync("agent"); var agentTagId = await _tags.GetOrCreateAsync("agent");
var manualTagId = await _tags.GetOrCreateAsync("manual"); var manualTagId = await _tags.GetOrCreateAsync("manual");
var codeTagId = await TagRepository.GetOrCreateAsync(_db.Factory.Open(), "code"); var codeTagId = await _tags.GetOrCreateAsync("code");
await _lists.AddTagAsync(listId, agentTagId); await _lists.AddTagAsync(listId, agentTagId);

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Tests.Infrastructure; using ClaudeDo.Worker.Tests.Infrastructure;
@@ -7,16 +8,18 @@ namespace ClaudeDo.Worker.Tests.Repositories;
public sealed class TaskRunRepositoryTests : IDisposable public sealed class TaskRunRepositoryTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly TaskRunRepository _runs; private readonly TaskRunRepository _runs;
private readonly string _taskId; private readonly string _taskId;
public TaskRunRepositoryTests() public TaskRunRepositoryTests()
{ {
_runs = new TaskRunRepository(_db.Factory); _ctx = _db.CreateContext();
_runs = new TaskRunRepository(_ctx);
// Seed a list and task for all tests // Seed a list and task for all tests
var lists = new ListRepository(_db.Factory); var lists = new ListRepository(_ctx);
var tasks = new TaskRepository(_db.Factory); var tasks = new TaskRepository(_ctx);
var listId = Guid.NewGuid().ToString(); var listId = Guid.NewGuid().ToString();
lists.AddAsync(new ListEntity lists.AddAsync(new ListEntity
{ {
@@ -37,7 +40,11 @@ public sealed class TaskRunRepositoryTests : IDisposable
}).GetAwaiter().GetResult(); }).GetAwaiter().GetResult();
} }
public void Dispose() => _db.Dispose(); public void Dispose()
{
_ctx.Dispose();
_db.Dispose();
}
private TaskRunEntity MakeRun(int runNumber, bool isRetry = false) => new() private TaskRunEntity MakeRun(int runNumber, bool isRetry = false) => new()
{ {

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Git; using ClaudeDo.Data.Git;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
@@ -24,19 +25,19 @@ public class WorktreeManagerTests : IDisposable
return f; return f;
} }
private async Task<(WorktreeManager mgr, WorktreeRepository wtRepo)> CreateManagerAsync( private async Task<(WorktreeManager mgr, DbFixture db)> CreateManagerAsync(
TaskEntity task, ListEntity list, string strategy = "sibling", string? centralRoot = null) TaskEntity task, ListEntity list, string strategy = "sibling", string? centralRoot = null)
{ {
var db = new DbFixture(); var db = new DbFixture();
_dbFixtures.Add(db); _dbFixtures.Add(db);
// Seed the DB with list and task so FK constraints pass. // Seed the DB with list and task so FK constraints pass.
var listRepo = new ListRepository(db.Factory); using var seedCtx = db.CreateContext();
var taskRepo = new TaskRepository(db.Factory); var listRepo = new ListRepository(seedCtx);
var taskRepo = new TaskRepository(seedCtx);
await listRepo.AddAsync(list); await listRepo.AddAsync(list);
await taskRepo.AddAsync(task); await taskRepo.AddAsync(task);
var wtRepo = new WorktreeRepository(db.Factory);
var cfg = new WorkerConfig var cfg = new WorkerConfig
{ {
WorktreeRootStrategy = strategy, WorktreeRootStrategy = strategy,
@@ -45,8 +46,8 @@ public class WorktreeManagerTests : IDisposable
cfg.CentralWorktreeRoot = centralRoot; cfg.CentralWorktreeRoot = centralRoot;
var mgr = new WorktreeManager( var mgr = new WorktreeManager(
new GitService(), wtRepo, cfg, NullLogger<WorktreeManager>.Instance); new GitService(), db.CreateFactory(), cfg, NullLogger<WorktreeManager>.Instance);
return (mgr, wtRepo); return (mgr, db);
} }
[Fact] [Fact]
@@ -56,7 +57,7 @@ public class WorktreeManagerTests : IDisposable
var repo = CreateRepo(); var repo = CreateRepo();
var (task, list) = MakeEntities(repo.RepoDir); var (task, list) = MakeEntities(repo.RepoDir);
var (mgr, wtRepo) = await CreateManagerAsync(task, list); var (mgr, db) = await CreateManagerAsync(task, list);
var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None);
_worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath));
@@ -66,6 +67,8 @@ public class WorktreeManagerTests : IDisposable
Assert.Equal($"claudedo/{task.Id.Replace("-", "")}", ctx.BranchName); Assert.Equal($"claudedo/{task.Id.Replace("-", "")}", ctx.BranchName);
Assert.Equal(repo.BaseCommit, ctx.BaseCommit); Assert.Equal(repo.BaseCommit, ctx.BaseCommit);
using var readCtx = db.CreateContext();
var wtRepo = new WorktreeRepository(readCtx);
var row = await wtRepo.GetByTaskIdAsync(task.Id); var row = await wtRepo.GetByTaskIdAsync(task.Id);
Assert.NotNull(row); Assert.NotNull(row);
Assert.Equal(WorktreeState.Active, row!.State); Assert.Equal(WorktreeState.Active, row!.State);
@@ -80,7 +83,7 @@ public class WorktreeManagerTests : IDisposable
var repo = CreateRepo(); var repo = CreateRepo();
var (task, list) = MakeEntities(repo.RepoDir); var (task, list) = MakeEntities(repo.RepoDir);
var (mgr, wtRepo) = await CreateManagerAsync(task, list); var (mgr, db) = await CreateManagerAsync(task, list);
var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None);
_worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath));
@@ -88,6 +91,8 @@ public class WorktreeManagerTests : IDisposable
var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None); var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None);
Assert.False(committed); Assert.False(committed);
using var readCtx = db.CreateContext();
var wtRepo = new WorktreeRepository(readCtx);
var row = await wtRepo.GetByTaskIdAsync(task.Id); var row = await wtRepo.GetByTaskIdAsync(task.Id);
Assert.Null(row!.HeadCommit); Assert.Null(row!.HeadCommit);
} }
@@ -99,7 +104,7 @@ public class WorktreeManagerTests : IDisposable
var repo = CreateRepo(); var repo = CreateRepo();
var (task, list) = MakeEntities(repo.RepoDir); var (task, list) = MakeEntities(repo.RepoDir);
var (mgr, wtRepo) = await CreateManagerAsync(task, list); var (mgr, db) = await CreateManagerAsync(task, list);
var ctx = await mgr.CreateAsync(task, list, CancellationToken.None); var ctx = await mgr.CreateAsync(task, list, CancellationToken.None);
_worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath)); _worktreeCleanups.Add((repo.RepoDir, ctx.WorktreePath));
@@ -109,6 +114,8 @@ public class WorktreeManagerTests : IDisposable
var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None); var committed = await mgr.CommitIfChangedAsync(ctx, task, list, CancellationToken.None);
Assert.True(committed); Assert.True(committed);
using var readCtx = db.CreateContext();
var wtRepo = new WorktreeRepository(readCtx);
var row = await wtRepo.GetByTaskIdAsync(task.Id); var row = await wtRepo.GetByTaskIdAsync(task.Id);
Assert.NotNull(row!.HeadCommit); Assert.NotNull(row!.HeadCommit);
Assert.NotEqual(ctx.BaseCommit, row.HeadCommit); Assert.NotEqual(ctx.BaseCommit, row.HeadCommit);
@@ -129,20 +136,24 @@ public class WorktreeManagerTests : IDisposable
var db = new DbFixture(); var db = new DbFixture();
_dbFixtures.Add(db); _dbFixtures.Add(db);
var listRepo = new ListRepository(db.Factory); using (var seedCtx = db.CreateContext())
var taskRepo = new TaskRepository(db.Factory); {
await listRepo.AddAsync(list); var listRepo = new ListRepository(seedCtx);
await taskRepo.AddAsync(task); var taskRepo = new TaskRepository(seedCtx);
await listRepo.AddAsync(list);
await taskRepo.AddAsync(task);
}
var wtRepo = new WorktreeRepository(db.Factory);
var cfg = new WorkerConfig { WorktreeRootStrategy = "sibling" }; var cfg = new WorkerConfig { WorktreeRootStrategy = "sibling" };
var mgr = new WorktreeManager( var mgr = new WorktreeManager(
new GitService(), wtRepo, cfg, NullLogger<WorktreeManager>.Instance); new GitService(), db.CreateFactory(), cfg, NullLogger<WorktreeManager>.Instance);
var ex = await Assert.ThrowsAsync<InvalidOperationException>( var ex = await Assert.ThrowsAsync<InvalidOperationException>(
() => mgr.CreateAsync(task, list, CancellationToken.None)); () => mgr.CreateAsync(task, list, CancellationToken.None));
Assert.Contains("not a git repository", ex.Message); Assert.Contains("not a git repository", ex.Message);
using var readCtx = db.CreateContext();
var wtRepo = new WorktreeRepository(readCtx);
var row = await wtRepo.GetByTaskIdAsync(task.Id); var row = await wtRepo.GetByTaskIdAsync(task.Id);
Assert.Null(row); Assert.Null(row);
} }

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Git; using ClaudeDo.Data.Git;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
@@ -15,6 +16,7 @@ namespace ClaudeDo.Worker.Tests.Services;
public sealed class QueueServiceTests : IDisposable public sealed class QueueServiceTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly TaskRepository _taskRepo; private readonly TaskRepository _taskRepo;
private readonly ListRepository _listRepo; private readonly ListRepository _listRepo;
private readonly TagRepository _tagRepo; private readonly TagRepository _tagRepo;
@@ -23,9 +25,10 @@ public sealed class QueueServiceTests : IDisposable
public QueueServiceTests() public QueueServiceTests()
{ {
_taskRepo = new TaskRepository(_db.Factory); _ctx = _db.CreateContext();
_listRepo = new ListRepository(_db.Factory); _taskRepo = new TaskRepository(_ctx);
_tagRepo = new TagRepository(_db.Factory); _listRepo = new ListRepository(_ctx);
_tagRepo = new TagRepository(_ctx);
_tempDir = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}"); _tempDir = Path.Combine(Path.GetTempPath(), $"claudedo_test_{Guid.NewGuid():N}");
Directory.CreateDirectory(_tempDir); Directory.CreateDirectory(_tempDir);
_cfg = new WorkerConfig _cfg = new WorkerConfig
@@ -38,6 +41,7 @@ public sealed class QueueServiceTests : IDisposable
public void Dispose() public void Dispose()
{ {
_ctx.Dispose();
_db.Dispose(); _db.Dispose();
try { Directory.Delete(_tempDir, true); } catch { } try { Directory.Delete(_tempDir, true); } catch { }
} }
@@ -47,14 +51,12 @@ public sealed class QueueServiceTests : IDisposable
{ {
var fake = new FakeClaudeProcess(handler); var fake = new FakeClaudeProcess(handler);
var broadcaster = new HubBroadcaster(new FakeHubContext()); var broadcaster = new HubBroadcaster(new FakeHubContext());
var wtRepo = new WorktreeRepository(_db.Factory); var dbFactory = _db.CreateFactory();
var runRepo = new TaskRunRepository(_db.Factory); var wtManager = new WorktreeManager(new GitService(), dbFactory, _cfg, NullLogger<WorktreeManager>.Instance);
var wtManager = new WorktreeManager(new GitService(), wtRepo, _cfg, NullLogger<WorktreeManager>.Instance);
var argsBuilder = new ClaudeArgsBuilder(); var argsBuilder = new ClaudeArgsBuilder();
var subtaskRepo = new SubtaskRepository(_db.Factory); var runner = new TaskRunner(fake, dbFactory, broadcaster, wtManager, argsBuilder, _cfg,
var runner = new TaskRunner(fake, _taskRepo, runRepo, _listRepo, wtRepo, subtaskRepo, broadcaster, wtManager, argsBuilder, _cfg,
NullLogger<TaskRunner>.Instance); NullLogger<TaskRunner>.Instance);
var service = new QueueService(_taskRepo, runner, _cfg, NullLogger<QueueService>.Instance); var service = new QueueService(dbFactory, runner, _cfg, NullLogger<QueueService>.Instance);
return (service, fake); return (service, fake);
} }

View File

@@ -1,3 +1,4 @@
using ClaudeDo.Data;
using ClaudeDo.Data.Models; using ClaudeDo.Data.Models;
using ClaudeDo.Data.Repositories; using ClaudeDo.Data.Repositories;
using ClaudeDo.Worker.Services; using ClaudeDo.Worker.Services;
@@ -10,16 +11,22 @@ namespace ClaudeDo.Worker.Tests.Services;
public sealed class StaleTaskRecoveryTests : IDisposable public sealed class StaleTaskRecoveryTests : IDisposable
{ {
private readonly DbFixture _db = new(); private readonly DbFixture _db = new();
private readonly ClaudeDoDbContext _ctx;
private readonly TaskRepository _tasks; private readonly TaskRepository _tasks;
private readonly ListRepository _lists; private readonly ListRepository _lists;
public StaleTaskRecoveryTests() public StaleTaskRecoveryTests()
{ {
_tasks = new TaskRepository(_db.Factory); _ctx = _db.CreateContext();
_lists = new ListRepository(_db.Factory); _tasks = new TaskRepository(_ctx);
_lists = new ListRepository(_ctx);
} }
public void Dispose() => _db.Dispose(); public void Dispose()
{
_ctx.Dispose();
_db.Dispose();
}
[Fact] [Fact]
public async Task StartAsync_Flips_Running_Tasks_To_Failed() public async Task StartAsync_Flips_Running_Tasks_To_Failed()
@@ -47,7 +54,7 @@ public sealed class StaleTaskRecoveryTests : IDisposable
await _tasks.AddAsync(running); await _tasks.AddAsync(running);
await _tasks.AddAsync(queued); await _tasks.AddAsync(queued);
var recovery = new StaleTaskRecovery(_tasks, NullLogger<StaleTaskRecovery>.Instance); var recovery = new StaleTaskRecovery(_db.CreateFactory(), NullLogger<StaleTaskRecovery>.Instance);
await recovery.StartAsync(CancellationToken.None); await recovery.StartAsync(CancellationToken.None);
var r = await _tasks.GetByIdAsync(running.Id); var r = await _tasks.GetByIdAsync(running.Id);