diff --git a/src/ClaudeDo.Worker/Services/TaskResetService.cs b/src/ClaudeDo.Worker/Services/TaskResetService.cs new file mode 100644 index 0000000..5602ac9 --- /dev/null +++ b/src/ClaudeDo.Worker/Services/TaskResetService.cs @@ -0,0 +1,68 @@ +using ClaudeDo.Data; +using ClaudeDo.Data.Models; +using ClaudeDo.Data.Repositories; +using ClaudeDo.Worker.Hub; +using ClaudeDo.Worker.Runner; +using Microsoft.EntityFrameworkCore; +using TaskStatus = ClaudeDo.Data.Models.TaskStatus; + +namespace ClaudeDo.Worker.Services; + +public sealed class TaskResetService +{ + private readonly IDbContextFactory _dbFactory; + private readonly WorktreeManager _wtManager; + private readonly HubBroadcaster _broadcaster; + private readonly ILogger _logger; + + public TaskResetService( + IDbContextFactory dbFactory, + WorktreeManager wtManager, + HubBroadcaster broadcaster, + ILogger logger) + { + _dbFactory = dbFactory; + _wtManager = wtManager; + _broadcaster = broadcaster; + _logger = logger; + } + + public async Task ResetAsync(string taskId, CancellationToken ct) + { + TaskEntity task; + ListEntity list; + WorktreeEntity? wt; + + using (var ctx = _dbFactory.CreateDbContext()) + { + task = await new TaskRepository(ctx).GetByIdAsync(taskId, ct) + ?? throw new KeyNotFoundException($"Task '{taskId}' not found."); + + if (task.Status == TaskStatus.Running) + throw new InvalidOperationException("Cannot reset a running task. Cancel it first."); + + list = await new ListRepository(ctx).GetByIdAsync(task.ListId, ct) + ?? throw new InvalidOperationException("List not found."); + + wt = await new WorktreeRepository(ctx).GetByTaskIdAsync(taskId, ct); + } + + bool worktreeChanged = false; + if (wt is not null && wt.State == WorktreeState.Active && list.WorkingDir is not null) + { + await _wtManager.DiscardAsync(wt, list.WorkingDir, ct); + worktreeChanged = true; + } + + using (var ctx = _dbFactory.CreateDbContext()) + { + await new TaskRepository(ctx).ResetToManualAsync(taskId, ct); + } + + await _broadcaster.TaskUpdated(taskId); + if (worktreeChanged) + await _broadcaster.WorktreeUpdated(taskId); + + _logger.LogInformation("Reset task {TaskId} to Manual (worktree discarded: {Discarded})", taskId, worktreeChanged); + } +} diff --git a/tests/ClaudeDo.Worker.Tests/Services/TaskResetServiceTests.cs b/tests/ClaudeDo.Worker.Tests/Services/TaskResetServiceTests.cs new file mode 100644 index 0000000..3c147d0 --- /dev/null +++ b/tests/ClaudeDo.Worker.Tests/Services/TaskResetServiceTests.cs @@ -0,0 +1,231 @@ +using ClaudeDo.Data.Models; +using ClaudeDo.Data.Repositories; +using ClaudeDo.Worker.Config; +using ClaudeDo.Worker.Hub; +using ClaudeDo.Worker.Runner; +using ClaudeDo.Worker.Services; +using ClaudeDo.Worker.Tests.Infrastructure; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Logging.Abstractions; +using TaskStatus = ClaudeDo.Data.Models.TaskStatus; + +namespace ClaudeDo.Worker.Tests.Services; + +public class TaskResetServiceTests : IDisposable +{ + private readonly List _dbs = new(); + private readonly List _repos = new(); + private readonly List<(string repoDir, string wtPath)> _worktreeCleanups = new(); + + private DbFixture NewDb() { var d = new DbFixture(); _dbs.Add(d); return d; } + private GitRepoFixture NewRepo() { var r = new GitRepoFixture(); _repos.Add(r); return r; } + + public void Dispose() + { + foreach (var (repoDir, wtPath) in _worktreeCleanups) + { + try { GitRepoFixture.RunGit(repoDir, "worktree", "remove", "--force", wtPath); } catch { } + } + foreach (var d in _dbs) try { d.Dispose(); } catch { } + foreach (var r in _repos) try { r.Dispose(); } catch { } + } + + private static (TaskResetService svc, RecordingClientProxy proxy) BuildService(DbFixture db, WorktreeManager wtMgr) + { + var fakeHub = new RecordingHubContext(); + var broadcaster = new HubBroadcaster(fakeHub); + var svc = new TaskResetService( + db.CreateFactory(), + wtMgr, + broadcaster, + NullLogger.Instance); + return (svc, fakeHub.Proxy); + } + + private static WorktreeManager BuildWorktreeManager(DbFixture db) + { + var cfg = new WorkerConfig { WorktreeRootStrategy = "sibling" }; + return new WorktreeManager( + new ClaudeDo.Data.Git.GitService(), + db.CreateFactory(), + cfg, + NullLogger.Instance); + } + + [Fact] + public async Task ResetAsync_FailedTaskWithWorktree_ClearsEverything_AndPreservesRunHistory() + { + if (!GitRepoFixture.IsGitAvailable()) return; + + var repo = NewRepo(); + var db = NewDb(); + var wtMgr = BuildWorktreeManager(db); + + var list = new ListEntity + { + Id = Guid.NewGuid().ToString(), + Name = "reset-test", + WorkingDir = repo.RepoDir, + DefaultCommitType = "feat", + CreatedAt = DateTime.UtcNow, + }; + var task = new TaskEntity + { + Id = Guid.NewGuid().ToString(), + ListId = list.Id, + Title = "test task", + Status = TaskStatus.Failed, + StartedAt = DateTime.UtcNow.AddMinutes(-5), + FinishedAt = DateTime.UtcNow.AddMinutes(-1), + Result = "some error", + CreatedAt = DateTime.UtcNow, + }; + + using (var ctx = db.CreateContext()) + { + await new ListRepository(ctx).AddAsync(list); + await new TaskRepository(ctx).AddAsync(task); + } + + var wtCtx = await wtMgr.CreateAsync(task, list, CancellationToken.None); + _worktreeCleanups.Add((repo.RepoDir, wtCtx.WorktreePath)); + + using (var ctx = db.CreateContext()) + { + await new TaskRunRepository(ctx).AddAsync(new TaskRunEntity + { + Id = Guid.NewGuid().ToString(), + TaskId = task.Id, + RunNumber = 1, + IsRetry = false, + Prompt = "do the thing", + SessionId = "s1", + }); + } + + var (svc, proxy) = BuildService(db, wtMgr); + await svc.ResetAsync(task.Id, CancellationToken.None); + + // Task must be Manual with cleared timestamps/result + using (var ctx = db.CreateContext()) + { + var updated = await new TaskRepository(ctx).GetByIdAsync(task.Id); + Assert.NotNull(updated); + Assert.Equal(TaskStatus.Manual, updated!.Status); + Assert.Null(updated.Result); + Assert.Null(updated.StartedAt); + Assert.Null(updated.FinishedAt); + } + + // Worktree directory must be gone + Assert.False(Directory.Exists(wtCtx.WorktreePath)); + + // Worktree DB row must be Discarded + using (var ctx = db.CreateContext()) + { + var wt = await new WorktreeRepository(ctx).GetByTaskIdAsync(task.Id); + Assert.NotNull(wt); + Assert.Equal(WorktreeState.Discarded, wt!.State); + } + + // task_runs row must still be present + using (var ctx = db.CreateContext()) + { + var runs = await new TaskRunRepository(ctx).GetByTaskIdAsync(task.Id); + Assert.Single(runs); + Assert.Equal("s1", runs[0].SessionId); + } + + // Broadcaster must have fired TaskUpdated AND WorktreeUpdated + Assert.Contains(proxy.Calls, i => i.Method == "TaskUpdated" && i.Args[0] is string s && s == task.Id); + Assert.Contains(proxy.Calls, i => i.Method == "WorktreeUpdated" && i.Args[0] is string s && s == task.Id); + } + + [Fact] + public async Task ResetAsync_RunningTask_Throws_AndDoesNotMutate() + { + var db = NewDb(); + var wtMgr = BuildWorktreeManager(db); + + var list = new ListEntity + { + Id = Guid.NewGuid().ToString(), + Name = "no-op list", + WorkingDir = null, + DefaultCommitType = "feat", + CreatedAt = DateTime.UtcNow, + }; + var startedAt = DateTime.UtcNow.AddMinutes(-2); + var task = new TaskEntity + { + Id = Guid.NewGuid().ToString(), + ListId = list.Id, + Title = "running task", + Status = TaskStatus.Running, + StartedAt = startedAt, + CreatedAt = DateTime.UtcNow, + }; + + using (var ctx = db.CreateContext()) + { + await new ListRepository(ctx).AddAsync(list); + await new TaskRepository(ctx).AddAsync(task); + } + + var (svc, proxy) = BuildService(db, wtMgr); + + await Assert.ThrowsAsync( + () => svc.ResetAsync(task.Id, CancellationToken.None)); + + // Task must be unchanged + using (var ctx = db.CreateContext()) + { + var unchanged = await new TaskRepository(ctx).GetByIdAsync(task.Id); + Assert.NotNull(unchanged); + Assert.Equal(TaskStatus.Running, unchanged!.Status); + Assert.Equal(startedAt, unchanged.StartedAt); + } + + // No broadcaster invocations + Assert.Empty(proxy.Calls); + } +} + +#region Test doubles + +internal sealed record HubCall(string Method, object?[] Args); + +internal sealed class RecordingClientProxy : IClientProxy +{ + public readonly List Calls = new(); + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + Calls.Add(new HubCall(method, args)); + return Task.CompletedTask; + } +} + +internal sealed class RecordingHubClients : IHubClients +{ + public RecordingClientProxy AllProxy { get; } = new(); + public IClientProxy All => AllProxy; + public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) => AllProxy; + public IClientProxy Client(string connectionId) => AllProxy; + public IClientProxy Clients(IReadOnlyList connectionIds) => AllProxy; + public IClientProxy Group(string groupName) => AllProxy; + public IClientProxy GroupExcept(string groupName, IReadOnlyList excludedConnectionIds) => AllProxy; + public IClientProxy Groups(IReadOnlyList groupNames) => AllProxy; + public IClientProxy User(string userId) => AllProxy; + public IClientProxy Users(IReadOnlyList userIds) => AllProxy; +} + +internal sealed class RecordingHubContext : IHubContext +{ + private readonly RecordingHubClients _clients = new(); + public RecordingClientProxy Proxy => _clients.AllProxy; + public IHubClients Clients => _clients; + public IGroupManager Groups => throw new NotImplementedException(); +} + +#endregion