diff --git a/internal/manager/api_impl/interfaces.go b/internal/manager/api_impl/interfaces.go index d965519f..127c1a02 100644 --- a/internal/manager/api_impl/interfaces.go +++ b/internal/manager/api_impl/interfaces.go @@ -36,8 +36,9 @@ type PersistenceService interface { SaveJobPriority(ctx context.Context, job *persistence.Job) error // FetchTask fetches the given task and the accompanying job. FetchTask(ctx context.Context, taskID string) (*persistence.Task, error) + // FetchTaskJobUUID fetches the UUID of the job this task belongs to. + FetchTaskJobUUID(ctx context.Context, taskID string) (string, error) FetchTaskFailureList(context.Context, *persistence.Task) ([]*persistence.Worker, error) - SaveTask(ctx context.Context, task *persistence.Task) error SaveTaskActivity(ctx context.Context, t *persistence.Task) error // TaskTouchedByWorker marks the task as 'touched' by a worker. This is used for timeout detection. TaskTouchedByWorker(context.Context, *persistence.Task) error diff --git a/internal/manager/api_impl/jobs.go b/internal/manager/api_impl/jobs.go index 9e9918b2..92d2aad9 100644 --- a/internal/manager/api_impl/jobs.go +++ b/internal/manager/api_impl/jobs.go @@ -439,7 +439,7 @@ func (f *Flamenco) FetchTaskLogInfo(e echo.Context, taskID string) error { return sendAPIError(e, http.StatusBadRequest, "bad task ID") } - dbTask, err := f.persist.FetchTask(ctx, taskID) + jobUUID, err := f.persist.FetchTaskJobUUID(ctx, taskID) if err != nil { if errors.Is(err, persistence.ErrTaskNotFound) { return sendAPIError(e, http.StatusNotFound, "no such task") @@ -447,9 +447,9 @@ func (f *Flamenco) FetchTaskLogInfo(e echo.Context, taskID string) error { logger.Error().Err(err).Msg("error fetching task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task: %v", err) } - logger = logger.With().Str("job", dbTask.Job.UUID).Logger() + logger = logger.With().Str("job", jobUUID).Logger() - size, err := f.logStorage.TaskLogSize(dbTask.Job.UUID, taskID) + size, err := f.logStorage.TaskLogSize(jobUUID, taskID) if err != nil { if errors.Is(err, os.ErrNotExist) { logger.Debug().Msg("task log unavailable, task has no log on disk") @@ -475,11 +475,11 @@ func (f *Flamenco) FetchTaskLogInfo(e echo.Context, taskID string) error { taskLogInfo := api.TaskLogInfo{ TaskId: taskID, - JobId: dbTask.Job.UUID, + JobId: jobUUID, Size: int(size), } - fullLogPath := f.logStorage.Filepath(dbTask.Job.UUID, taskID) + fullLogPath := f.logStorage.Filepath(jobUUID, taskID) relPath, err := f.localStorage.RelPath(fullLogPath) if err != nil { logger.Error().Err(err).Msg("task log is outside the manager storage, cannot construct its URL for download") @@ -501,7 +501,7 @@ func (f *Flamenco) FetchTaskLogTail(e echo.Context, taskID string) error { return sendAPIError(e, http.StatusBadRequest, "bad task ID") } - dbTask, err := f.persist.FetchTask(ctx, taskID) + jobUUID, err := f.persist.FetchTaskJobUUID(ctx, taskID) if err != nil { if errors.Is(err, persistence.ErrTaskNotFound) { return sendAPIError(e, http.StatusNotFound, "no such task") @@ -509,9 +509,9 @@ func (f *Flamenco) FetchTaskLogTail(e echo.Context, taskID string) error { logger.Error().Err(err).Msg("error fetching task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task: %v", err) } - logger = logger.With().Str("job", dbTask.Job.UUID).Logger() + logger = logger.With().Str("job", jobUUID).Logger() - tail, err := f.logStorage.Tail(dbTask.Job.UUID, taskID) + tail, err := f.logStorage.Tail(jobUUID, taskID) if err != nil { if errors.Is(err, os.ErrNotExist) { logger.Debug().Msg("task tail unavailable, task has no log on disk") @@ -700,7 +700,11 @@ func taskDBtoAPI(dbTask *persistence.Task) api.Task { Status: dbTask.Status, Activity: dbTask.Activity, Commands: make([]api.Command, len(dbTask.Commands)), - Worker: workerToTaskWorker(dbTask.Worker), + + // TODO: convert this to just store dbTask.WorkerUUID. + Worker: workerToTaskWorker(dbTask.Worker), + + JobId: dbTask.JobUUID, } if dbTask.Job != nil { diff --git a/internal/manager/api_impl/jobs_test.go b/internal/manager/api_impl/jobs_test.go index 61b52327..b5134cab 100644 --- a/internal/manager/api_impl/jobs_test.go +++ b/internal/manager/api_impl/jobs_test.go @@ -753,22 +753,10 @@ func TestFetchTaskLogTail(t *testing.T) { jobID := "18a9b096-d77e-438c-9be2-74397038298b" taskID := "2e020eee-20f8-4e95-8dcf-65f7dfc3ebab" - dbJob := persistence.Job{ - UUID: jobID, - Name: "test job", - Status: api.JobStatusActive, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, - } - dbTask := persistence.Task{ - UUID: taskID, - Job: &dbJob, - Name: "test task", - } // The task can be found, but has no on-disk task log. // This should not cause any error, but instead be returned as "no content". - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&dbTask, nil) + mf.persistence.EXPECT().FetchTaskJobUUID(gomock.Any(), taskID).Return(jobID, nil) mf.logStorage.EXPECT().Tail(jobID, taskID). Return("", fmt.Errorf("wrapped error: %w", os.ErrNotExist)) @@ -778,7 +766,7 @@ func TestFetchTaskLogTail(t *testing.T) { assertResponseNoContent(t, echoCtx) // Check that a 204 No Content is also returned when the task log file on disk exists, but is empty. - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&dbTask, nil) + mf.persistence.EXPECT().FetchTaskJobUUID(gomock.Any(), taskID).Return(jobID, nil) mf.logStorage.EXPECT().Tail(jobID, taskID). Return("", fmt.Errorf("wrapped error: %w", os.ErrNotExist)) @@ -796,21 +784,9 @@ func TestFetchTaskLogInfo(t *testing.T) { jobID := "18a9b096-d77e-438c-9be2-74397038298b" taskID := "2e020eee-20f8-4e95-8dcf-65f7dfc3ebab" - dbJob := persistence.Job{ - UUID: jobID, - Name: "test job", - Status: api.JobStatusActive, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, - } - dbTask := persistence.Task{ - UUID: taskID, - Job: &dbJob, - Name: "test task", - } mf.persistence.EXPECT(). - FetchTask(gomock.Any(), taskID). - Return(&dbTask, nil). + FetchTaskJobUUID(gomock.Any(), taskID). + Return(jobID, nil). AnyTimes() // The task can be found, but has no on-disk task log. diff --git a/internal/manager/api_impl/mocks/api_impl_mock.gen.go b/internal/manager/api_impl/mocks/api_impl_mock.gen.go index a7360e98..9d98ab19 100644 --- a/internal/manager/api_impl/mocks/api_impl_mock.gen.go +++ b/internal/manager/api_impl/mocks/api_impl_mock.gen.go @@ -244,6 +244,21 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTaskFailureList(arg0, arg1 in return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTaskFailureList", reflect.TypeOf((*MockPersistenceService)(nil).FetchTaskFailureList), arg0, arg1) } +// FetchTaskJobUUID mocks base method. +func (m *MockPersistenceService) FetchTaskJobUUID(arg0 context.Context, arg1 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchTaskJobUUID", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchTaskJobUUID indicates an expected call of FetchTaskJobUUID. +func (mr *MockPersistenceServiceMockRecorder) FetchTaskJobUUID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTaskJobUUID", reflect.TypeOf((*MockPersistenceService)(nil).FetchTaskJobUUID), arg0, arg1) +} + // FetchWorker mocks base method. func (m *MockPersistenceService) FetchWorker(arg0 context.Context, arg1 string) (*persistence.Worker, error) { m.ctrl.T.Helper() diff --git a/internal/manager/persistence/db.go b/internal/manager/persistence/db.go index 1ffe2806..388a4013 100644 --- a/internal/manager/persistence/db.go +++ b/internal/manager/persistence/db.go @@ -184,7 +184,9 @@ func (db *DB) queries() (*sqlc.Queries, error) { if err != nil { return nil, fmt.Errorf("could not get low-level database driver: %w", err) } - return sqlc.New(sqldb), nil + + loggingWrapper := LoggingDBConn{sqldb} + return sqlc.New(&loggingWrapper), nil } // now returns the result of `nowFunc()` wrapped in a sql.NullTime. diff --git a/internal/manager/persistence/errors.go b/internal/manager/persistence/errors.go index 24eb3dae..816a1383 100644 --- a/internal/manager/persistence/errors.go +++ b/internal/manager/persistence/errors.go @@ -2,6 +2,7 @@ package persistence import ( + "database/sql" "errors" "fmt" @@ -9,6 +10,7 @@ import ( ) var ( + // TODO: let these errors wrap database/sql.ErrNoRows. ErrJobNotFound = PersistenceError{Message: "job not found", Err: gorm.ErrRecordNotFound} ErrTaskNotFound = PersistenceError{Message: "task not found", Err: gorm.ErrRecordNotFound} ErrWorkerNotFound = PersistenceError{Message: "worker not found", Err: gorm.ErrRecordNotFound} @@ -63,36 +65,48 @@ func wrapError(errorToWrap error, message string, format ...interface{}) error { // translateGormJobError translates a Gorm error to a persistence layer error. // This helps to keep Gorm as "implementation detail" of the persistence layer. -func translateGormJobError(gormError error) error { - if errors.Is(gormError, gorm.ErrRecordNotFound) { +func translateGormJobError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return ErrTaskNotFound + } + if errors.Is(err, gorm.ErrRecordNotFound) { return ErrJobNotFound } - return gormError + return err } // translateGormTaskError translates a Gorm error to a persistence layer error. // This helps to keep Gorm as "implementation detail" of the persistence layer. -func translateGormTaskError(gormError error) error { - if errors.Is(gormError, gorm.ErrRecordNotFound) { +func translateGormTaskError(err error) error { + if errors.Is(err, sql.ErrNoRows) { return ErrTaskNotFound } - return gormError + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrTaskNotFound + } + return err } // translateGormWorkerError translates a Gorm error to a persistence layer error. // This helps to keep Gorm as "implementation detail" of the persistence layer. -func translateGormWorkerError(gormError error) error { - if errors.Is(gormError, gorm.ErrRecordNotFound) { +func translateGormWorkerError(err error) error { + if errors.Is(err, sql.ErrNoRows) { return ErrWorkerNotFound } - return gormError + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrWorkerNotFound + } + return err } // translateGormWorkerTagError translates a Gorm error to a persistence layer error. // This helps to keep Gorm as "implementation detail" of the persistence layer. -func translateGormWorkerTagError(gormError error) error { - if errors.Is(gormError, gorm.ErrRecordNotFound) { +func translateGormWorkerTagError(err error) error { + if errors.Is(err, sql.ErrNoRows) { return ErrWorkerTagNotFound } - return gormError + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrWorkerTagNotFound + } + return err } diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 8a2f94d3..2c0ce7ce 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -14,7 +14,6 @@ import ( "github.com/rs/zerolog/log" "gorm.io/gorm" - "gorm.io/gorm/clause" "projects.blender.org/studio/flamenco/internal/manager/job_compilers" "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc" @@ -66,12 +65,14 @@ type Task struct { Type string `gorm:"type:varchar(32);default:''"` JobID uint `gorm:"default:0"` Job *Job `gorm:"foreignkey:JobID;references:ID;constraint:OnDelete:CASCADE"` + JobUUID string `gorm:"-"` // Fetched by SQLC, not GORM. Priority int `gorm:"type:smallint;default:50"` Status api.TaskStatus `gorm:"type:varchar(16);default:''"` // Which worker is/was working on this. WorkerID *uint Worker *Worker `gorm:"foreignkey:WorkerID;references:ID;constraint:OnDelete:SET NULL"` + WorkerUUID string `gorm:"-"` // Fetched by SQLC, not GORM. LastTouchedAt time.Time `gorm:"index"` // Should contain UTC timestamps. // Dependencies are tasks that need to be completed before this one can run. @@ -454,129 +455,279 @@ func (db *DB) SaveJobStorageInfo(ctx context.Context, j *Job) error { } func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { - dbTask := Task{} - tx := db.gormDB.WithContext(ctx). - // Allow finding the Worker, even after it was deleted. Jobs and Tasks - // don't have soft-deletion. - Unscoped(). - Joins("Job"). - Joins("Worker"). - First(&dbTask, "tasks.uuid = ?", taskUUID) - if tx.Error != nil { - return nil, taskError(tx.Error, "fetching task") + queries, err := db.queries() + if err != nil { + return nil, err } - return &dbTask, nil + + taskRow, err := queries.FetchTask(ctx, taskUUID) + if err != nil { + return nil, taskError(err, "fetching task %s", taskUUID) + } + + convertedTask, err := convertSqlcTask(taskRow.Task, taskRow.JobUUID.String, taskRow.WorkerUUID.String) + if err != nil { + return nil, err + } + + // TODO: remove this code, and let the caller fetch the job explicitly when needed. + if taskRow.Task.JobID > 0 { + dbJob, err := queries.FetchJobByID(ctx, taskRow.Task.JobID) + if err != nil { + return nil, jobError(err, "fetching job of task %s", taskUUID) + } + + convertedJob, err := convertSqlcJob(dbJob) + if err != nil { + return nil, jobError(err, "converting job of task %s", taskUUID) + } + convertedTask.Job = convertedJob + if convertedTask.JobUUID != convertedJob.UUID { + panic("Conversion to SQLC is incomplete") + } + } + + // TODO: remove this code, and let the caller fetch the Worker explicitly when needed. + if taskRow.WorkerUUID.Valid { + worker, err := queries.FetchWorkerUnconditional(ctx, taskRow.WorkerUUID.String) + if err != nil { + return nil, taskError(err, "fetching worker assigned to task %s", taskUUID) + } + convertedWorker := convertSqlcWorker(worker) + convertedTask.Worker = &convertedWorker + } + + return convertedTask, nil } +// FetchTaskJobUUID fetches the job UUID of the given task. +func (db *DB) FetchTaskJobUUID(ctx context.Context, taskUUID string) (string, error) { + queries, err := db.queries() + if err != nil { + return "", err + } + + jobUUID, err := queries.FetchTaskJobUUID(ctx, taskUUID) + if err != nil { + return "", taskError(err, "fetching job UUID of task %s", taskUUID) + } + if !jobUUID.Valid { + return "", PersistenceError{Message: fmt.Sprintf("unable to find job of task %s", taskUUID)} + } + return jobUUID.String, nil +} + +// SaveTask updates a task that already exists in the database. +// This function is not used by the Flamenco API, only by unit tests. func (db *DB) SaveTask(ctx context.Context, t *Task) error { - tx := db.gormDB.WithContext(ctx). - Omit("job"). - Omit("worker"). - Save(t) - if tx.Error != nil { - return taskError(tx.Error, "saving task") + if t.ID == 0 { + panic(fmt.Errorf("cannot use this function to insert a task")) + } + + queries, err := db.queries() + if err != nil { + return err + } + + commandsJSON, err := json.Marshal(t.Commands) + if err != nil { + return fmt.Errorf("cannot convert commands to JSON: %w", err) + } + + param := sqlc.UpdateTaskParams{ + UpdatedAt: db.now(), + Name: t.Name, + Type: t.Type, + Priority: int64(t.Priority), + Status: string(t.Status), + Commands: commandsJSON, + Activity: t.Activity, + ID: int64(t.ID), + } + if t.WorkerID != nil { + param.WorkerID = sql.NullInt64{ + Int64: int64(*t.WorkerID), + Valid: true, + } + } else if t.Worker != nil && t.Worker.ID > 0 { + param.WorkerID = sql.NullInt64{ + Int64: int64(t.Worker.ID), + Valid: true, + } + } + + if !t.LastTouchedAt.IsZero() { + param.LastTouchedAt = sql.NullTime{ + Time: t.LastTouchedAt, + Valid: true, + } + } + + err = queries.UpdateTask(ctx, param) + if err != nil { + return taskError(err, "updating task") } return nil } func (db *DB) SaveTaskStatus(ctx context.Context, t *Task) error { - tx := db.gormDB.WithContext(ctx). - Select("Status"). - Save(t) - if tx.Error != nil { - return taskError(tx.Error, "saving task") + queries, err := db.queries() + if err != nil { + return err + } + + err = queries.UpdateTaskStatus(ctx, sqlc.UpdateTaskStatusParams{ + UpdatedAt: db.now(), + Status: string(t.Status), + ID: int64(t.ID), + }) + if err != nil { + return taskError(err, "saving task status") } return nil } func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { - if err := db.gormDB.WithContext(ctx). - Model(t). - Select("Activity"). - Updates(Task{Activity: t.Activity}).Error; err != nil { + queries, err := db.queries() + if err != nil { + return err + } + + err = queries.UpdateTaskActivity(ctx, sqlc.UpdateTaskActivityParams{ + UpdatedAt: db.now(), + Activity: t.Activity, + ID: int64(t.ID), + }) + if err != nil { return taskError(err, "saving task activity") } return nil } +// TaskAssignToWorker assigns the given task to the given worker. +// This function is only used by unit tests. During normal operation, Flamenco +// uses the code in task_scheduler.go to assign tasks to workers. func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error { - tx := db.gormDB.WithContext(ctx). - Model(t). - Select("WorkerID"). - Updates(Task{WorkerID: &w.ID}) - if tx.Error != nil { - return taskError(tx.Error, "assigning task %s to worker %s", t.UUID, w.UUID) + queries, err := db.queries() + if err != nil { + return err } - // Gorm updates t.WorkerID itself, but not t.Worker (even when it's added to - // the Updates() call above). + err = queries.TaskAssignToWorker(ctx, sqlc.TaskAssignToWorkerParams{ + UpdatedAt: db.now(), + WorkerID: sql.NullInt64{ + Int64: int64(w.ID), + Valid: true, + }, + ID: int64(t.ID), + }) + if err != nil { + return taskError(err, "assigning task %s to worker %s", t.UUID, w.UUID) + } + + // Update the task itself. t.Worker = w + t.WorkerID = &w.ID return nil } func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]*Task, error) { - result := []*Task{} - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Joins("Job"). - Where("tasks.worker_id = ?", worker.ID). - Where("tasks.status = ?", taskStatus). - Scan(&result) - if tx.Error != nil { - return nil, taskError(tx.Error, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) + queries, err := db.queries() + if err != nil { + return nil, err + } + + rows, err := queries.FetchTasksOfWorkerInStatus(ctx, sqlc.FetchTasksOfWorkerInStatusParams{ + WorkerID: sql.NullInt64{ + Int64: int64(worker.ID), + Valid: true, + }, + TaskStatus: string(taskStatus), + }) + if err != nil { + return nil, taskError(err, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) + } + + result := make([]*Task, len(rows)) + for i := range rows { + gormTask, err := convertSqlcTask(rows[i].Task, rows[i].JobUUID.String, worker.UUID) + if err != nil { + return nil, err + } + gormTask.Worker = worker + gormTask.WorkerID = &worker.ID + result[i] = gormTask } return result, nil } func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worker, taskStatus api.TaskStatus, job *Job) ([]*Task, error) { - result := []*Task{} - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Joins("Job"). - Where("tasks.worker_id = ?", worker.ID). - Where("tasks.status = ?", taskStatus). - Where("job.id = ?", job.ID). - Scan(&result) - if tx.Error != nil { - return nil, taskError(tx.Error, "finding tasks of worker %s in status %q and job %s", worker.UUID, taskStatus, job.UUID) + queries, err := db.queries() + if err != nil { + return nil, err + } + + rows, err := queries.FetchTasksOfWorkerInStatusOfJob(ctx, sqlc.FetchTasksOfWorkerInStatusOfJobParams{ + WorkerID: sql.NullInt64{ + Int64: int64(worker.ID), + Valid: true, + }, + JobID: int64(job.ID), + TaskStatus: string(taskStatus), + }) + if err != nil { + return nil, taskError(err, "finding tasks of worker %s in status %q and job %s", worker.UUID, taskStatus, job.UUID) + } + + result := make([]*Task, len(rows)) + for i := range rows { + gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, worker.UUID) + if err != nil { + return nil, err + } + gormTask.Job = job + gormTask.JobID = job.ID + gormTask.Worker = worker + gormTask.WorkerID = &worker.ID + result[i] = gormTask } return result, nil } func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { - var numTasksInStatus int64 - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Where("job_id", job.ID). - Where("status", taskStatus). - Count(&numTasksInStatus) - if tx.Error != nil { - return false, taskError(tx.Error, "counting tasks of job %s in status %q", job.UUID, taskStatus) + queries, err := db.queries() + if err != nil { + return false, err } - return numTasksInStatus > 0, nil + + count, err := queries.JobCountTasksInStatus(ctx, sqlc.JobCountTasksInStatusParams{ + JobID: int64(job.ID), + TaskStatus: string(taskStatus), + }) + if err != nil { + return false, taskError(err, "counting tasks of job %s in status %q", job.UUID, taskStatus) + } + + return count > 0, nil } +// CountTasksOfJobInStatus counts the number of tasks in the job. +// It returns two counts, one is the number of tasks in the given statuses, the +// other is the total number of tasks of the job. func (db *DB) CountTasksOfJobInStatus( ctx context.Context, job *Job, taskStatuses ...api.TaskStatus, ) (numInStatus, numTotal int, err error) { - type Result struct { - Status api.TaskStatus - NumTasks int + queries, err := db.queries() + if err != nil { + return 0, 0, err } - var results []Result - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Select("status, count(*) as num_tasks"). - Where("job_id", job.ID). - Group("status"). - Scan(&results) - - if tx.Error != nil { - return 0, 0, jobError(tx.Error, "count tasks of job %s in status %q", job.UUID, taskStatuses) + results, err := queries.JobCountTaskStatuses(ctx, int64(job.ID)) + if err != nil { + return 0, 0, jobError(err, "count tasks of job %s in status %q", job.UUID, taskStatuses) } // Create lookup table for which statuses to count. @@ -587,10 +738,10 @@ func (db *DB) CountTasksOfJobInStatus( // Count the number of tasks per status. for _, result := range results { - if countStatus[result.Status] { - numInStatus += result.NumTasks + if countStatus[api.TaskStatus(result.Status)] { + numInStatus += int(result.NumTasks) } - numTotal += result.NumTasks + numTotal += int(result.NumTasks) } return @@ -598,39 +749,53 @@ func (db *DB) CountTasksOfJobInStatus( // FetchTaskIDsOfJob returns all tasks of the given job. func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { - var tasks []*Task - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Where("job_id", job.ID). - Scan(&tasks) - if tx.Error != nil { - return nil, taskError(tx.Error, "fetching tasks of job %s", job.UUID) + queries, err := db.queries() + if err != nil { + return nil, err } - for i := range tasks { - tasks[i].Job = job + rows, err := queries.FetchTasksOfJob(ctx, int64(job.ID)) + if err != nil { + return nil, taskError(err, "fetching tasks of job %s", job.UUID) } - return tasks, nil + result := make([]*Task, len(rows)) + for i := range rows { + gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, rows[i].WorkerUUID.String) + if err != nil { + return nil, err + } + gormTask.Job = job + result[i] = gormTask + } + return result, nil } // FetchTasksOfJobInStatus returns those tasks of the given job that have any of the given statuses. func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuses ...api.TaskStatus) ([]*Task, error) { - var tasks []*Task - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Where("job_id", job.ID). - Where("status in ?", taskStatuses). - Scan(&tasks) - if tx.Error != nil { - return nil, taskError(tx.Error, "fetching tasks of job %s in status %q", job.UUID, taskStatuses) + queries, err := db.queries() + if err != nil { + return nil, err } - for i := range tasks { - tasks[i].Job = job + rows, err := queries.FetchTasksOfJobInStatus(ctx, sqlc.FetchTasksOfJobInStatusParams{ + JobID: int64(job.ID), + TaskStatus: convertTaskStatuses(taskStatuses), + }) + if err != nil { + return nil, taskError(err, "fetching tasks of job %s in status %q", job.UUID, taskStatuses) } - return tasks, nil + result := make([]*Task, len(rows)) + for i := range rows { + gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, rows[i].WorkerUUID.String) + if err != nil { + return nil, err + } + gormTask.Job = job + result[i] = gormTask + } + return result, nil } // UpdateJobsTaskStatuses updates the status & activity of all tasks of `job`. @@ -641,13 +806,20 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Updates(Task{Status: taskStatus, Activity: activity}) + queries, err := db.queries() + if err != nil { + return err + } - if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks of job %s", job.UUID) + err = queries.UpdateJobsTaskStatuses(ctx, sqlc.UpdateJobsTaskStatusesParams{ + UpdatedAt: db.now(), + Status: string(taskStatus), + Activity: activity, + JobID: int64(job.ID), + }) + + if err != nil { + return taskError(err, "updating status of all tasks of job %s", job.UUID) } return nil } @@ -661,26 +833,45 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Where("status in ?", statusesToUpdate). - Updates(Task{Status: taskStatus, Activity: activity}) - if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) + queries, err := db.queries() + if err != nil { + return err + } + + err = queries.UpdateJobsTaskStatusesConditional(ctx, sqlc.UpdateJobsTaskStatusesConditionalParams{ + UpdatedAt: db.now(), + Status: string(taskStatus), + Activity: activity, + JobID: int64(job.ID), + StatusesToUpdate: convertTaskStatuses(statusesToUpdate), + }) + + if err != nil { + return taskError(err, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) } return nil } // TaskTouchedByWorker marks the task as 'touched' by a worker. This is used for timeout detection. func (db *DB) TaskTouchedByWorker(ctx context.Context, t *Task) error { - tx := db.gormDB.WithContext(ctx). - Model(t). - Select("LastTouchedAt"). - Updates(Task{LastTouchedAt: db.gormDB.NowFunc()}) - if err := tx.Error; err != nil { + queries, err := db.queries() + if err != nil { + return err + } + + now := db.now() + err = queries.TaskTouchedByWorker(ctx, sqlc.TaskTouchedByWorkerParams{ + UpdatedAt: now, + LastTouchedAt: now, + ID: int64(t.ID), + }) + if err != nil { return taskError(err, "saving task 'last touched at'") } + + // Also update the given task, so that it's consistent with the database. + t.LastTouchedAt = now.Time + return nil } @@ -693,64 +884,72 @@ func (db *DB) TaskTouchedByWorker(ctx context.Context, t *Task) error { // // Returns the new number of workers that failed this task. func (db *DB) AddWorkerToTaskFailedList(ctx context.Context, t *Task, w *Worker) (numFailed int, err error) { - entry := TaskFailure{ - Task: t, - Worker: w, - } - tx := db.gormDB.WithContext(ctx). - Clauses(clause.OnConflict{DoNothing: true}). - Create(&entry) - if tx.Error != nil { - return 0, tx.Error + queries, err := db.queries() + if err != nil { + return 0, err } - var numFailed64 int64 - tx = db.gormDB.WithContext(ctx).Model(&TaskFailure{}). - Where("task_id=?", t.ID). - Count(&numFailed64) + err = queries.AddWorkerToTaskFailedList(ctx, sqlc.AddWorkerToTaskFailedListParams{ + CreatedAt: db.now().Time, + TaskID: int64(t.ID), + WorkerID: int64(w.ID), + }) + if err != nil { + return 0, err + } + + numFailed64, err := queries.CountWorkersFailingTask(ctx, int64(t.ID)) + if err != nil { + return 0, err + } // Integer literals are of type `int`, so that's just a bit nicer to work with // than `int64`. if numFailed64 > math.MaxInt32 { log.Warn().Int64("numFailed", numFailed64).Msg("number of failed workers is crazy high, something is wrong here") - return math.MaxInt32, tx.Error + return math.MaxInt32, nil } - return int(numFailed64), tx.Error + return int(numFailed64), nil } // ClearFailureListOfTask clears the list of workers that failed this task. func (db *DB) ClearFailureListOfTask(ctx context.Context, t *Task) error { - tx := db.gormDB.WithContext(ctx). - Where("task_id = ?", t.ID). - Delete(&TaskFailure{}) - return tx.Error + queries, err := db.queries() + if err != nil { + return err + } + + return queries.ClearFailureListOfTask(ctx, int64(t.ID)) } // ClearFailureListOfJob en-mass, for all tasks of this job, clears the list of // workers that failed those tasks. func (db *DB) ClearFailureListOfJob(ctx context.Context, j *Job) error { + queries, err := db.queries() + if err != nil { + return err + } - // SQLite doesn't support JOIN in DELETE queries, so use a sub-query instead. - jobTasksQuery := db.gormDB.Model(&Task{}). - Select("id"). - Where("job_id = ?", j.ID) - - tx := db.gormDB.WithContext(ctx). - Where("task_id in (?)", jobTasksQuery). - Delete(&TaskFailure{}) - return tx.Error + return queries.ClearFailureListOfJob(ctx, int64(j.ID)) } func (db *DB) FetchTaskFailureList(ctx context.Context, t *Task) ([]*Worker, error) { - var workers []*Worker + queries, err := db.queries() + if err != nil { + return nil, err + } - tx := db.gormDB.WithContext(ctx). - Model(&Worker{}). - Joins("inner join task_failures TF on TF.worker_id = workers.id"). - Where("TF.task_id = ?", t.ID). - Scan(&workers) + failureList, err := queries.FetchTaskFailureList(ctx, int64(t.ID)) + if err != nil { + return nil, err + } - return workers, tx.Error + workers := make([]*Worker, len(failureList)) + for idx := range failureList { + worker := convertSqlcWorker(failureList[idx].Worker) + workers[idx] = &worker + } + return workers, nil } // convertSqlcJob converts a job from the SQLC-generated model to the model @@ -791,3 +990,52 @@ func convertSqlcJob(job sqlc.Job) (*Job, error) { return &dbJob, nil } + +// convertSqlcTask converts a FetchTaskRow from the SQLC-generated model to the +// model expected by the rest of the code. This is mostly in place to aid in the +// GORM to SQLC migration. It is intended that eventually the rest of the code +// will use the same SQLC-generated model. +func convertSqlcTask(task sqlc.Task, jobUUID string, workerUUID string) (*Task, error) { + dbTask := Task{ + Model: Model{ + ID: uint(task.ID), + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt.Time, + }, + + UUID: task.UUID, + Name: task.Name, + Type: task.Type, + Priority: int(task.Priority), + Status: api.TaskStatus(task.Status), + LastTouchedAt: task.LastTouchedAt.Time, + Activity: task.Activity, + + JobID: uint(task.JobID), + JobUUID: jobUUID, + WorkerUUID: workerUUID, + } + + // TODO: convert dependencies? + + if task.WorkerID.Valid { + workerID := uint(task.WorkerID.Int64) + dbTask.WorkerID = &workerID + } + + if err := json.Unmarshal(task.Commands, &dbTask.Commands); err != nil { + return nil, taskError(err, fmt.Sprintf("task %s of job %s has invalid commands: %v", + task.UUID, jobUUID, err)) + } + + return &dbTask, nil +} + +// convertTaskStatuses converts from []api.TaskStatus to []string for feeding to sqlc. +func convertTaskStatuses(taskStatuses []api.TaskStatus) []string { + statusesAsStrings := make([]string, len(taskStatuses)) + for index := range taskStatuses { + statusesAsStrings[index] = string(taskStatuses[index]) + } + return statusesAsStrings +} diff --git a/internal/manager/persistence/jobs_blocklist_test.go b/internal/manager/persistence/jobs_blocklist_test.go index 436c5f18..cd264f7f 100644 --- a/internal/manager/persistence/jobs_blocklist_test.go +++ b/internal/manager/persistence/jobs_blocklist_test.go @@ -238,9 +238,12 @@ func TestCountTaskFailuresOfWorker(t *testing.T) { ctx, close, db, dbJob, authoredJob := jobTasksTestFixtures(t) defer close() - task0, _ := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) - task1, _ := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) - task2, _ := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + task0, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + require.NoError(t, err) + task1, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + require.NoError(t, err) + task2, err := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + require.NoError(t, err) // Sanity check on the test data. assert.Equal(t, "blender", task0.Type) diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 000a6db4..af338f25 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -75,6 +75,19 @@ func TestStoreAuthoredJobWithShamanCheckoutID(t *testing.T) { assert.Equal(t, job.Storage.ShamanCheckoutID, fetchedJob.Storage.ShamanCheckoutID) } +func TestFetchTaskJobUUID(t *testing.T) { + ctx, cancel, db := persistenceTestFixtures(t, 1*time.Second) + defer cancel() + + job := createTestAuthoredJobWithTasks() + err := db.StoreAuthoredJob(ctx, job) + require.NoError(t, err) + + jobUUID, err := db.FetchTaskJobUUID(ctx, job.Tasks[0].UUID) + require.NoError(t, err) + assert.Equal(t, job.JobID, jobUUID) +} + func TestSaveJobStorageInfo(t *testing.T) { // Test that saving job storage info doesn't count as "update". // This is necessary for `cmd/shaman-checkout-id-setter` to do its work quietly. @@ -383,6 +396,12 @@ func TestCountTasksOfJobInStatus(t *testing.T) { require.NoError(t, err) assert.Equal(t, 0, numActive) assert.Equal(t, 3, numTotal) + + numCounted, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, + api.TaskStatusFailed, api.TaskStatusQueued) + require.NoError(t, err) + assert.Equal(t, 3, numCounted) + assert.Equal(t, 3, numTotal) } func TestCheckIfJobsHoldLargeNumOfTasks(t *testing.T) { diff --git a/internal/manager/persistence/logger.go b/internal/manager/persistence/logger.go index 2135a006..3e346ef8 100644 --- a/internal/manager/persistence/logger.go +++ b/internal/manager/persistence/logger.go @@ -4,13 +4,16 @@ package persistence import ( "context" + "database/sql" "errors" "fmt" "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" + "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc" ) // dbLogger implements the behaviour of Gorm's default logger on top of Zerolog. @@ -126,3 +129,28 @@ func (l dbLogger) logger(args ...interface{}) zerolog.Logger { } return logCtx.Logger() } + +// LoggingDBConn wraps a database/sql.DB connection, so that it can be used with +// sqlc and log all the queries. +type LoggingDBConn struct { + wrappedConn sqlc.DBTX +} + +var _ sqlc.DBTX = (*LoggingDBConn)(nil) + +func (ldbc *LoggingDBConn) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query Exec") + return ldbc.wrappedConn.ExecContext(ctx, sql, args...) +} +func (ldbc *LoggingDBConn) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + log.Trace().Str("sql", sql).Msg("database: query Prepare") + return ldbc.wrappedConn.PrepareContext(ctx, sql) +} +func (ldbc *LoggingDBConn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query Query") + return ldbc.wrappedConn.QueryContext(ctx, sql, args...) +} +func (ldbc *LoggingDBConn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query QueryRow") + return ldbc.wrappedConn.QueryRowContext(ctx, sql, args...) +} diff --git a/internal/manager/persistence/sqlc/db.go b/internal/manager/persistence/sqlc/db.go index 8ed64d13..c5852e06 100644 --- a/internal/manager/persistence/sqlc/db.go +++ b/internal/manager/persistence/sqlc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package sqlc diff --git a/internal/manager/persistence/sqlc/models.go b/internal/manager/persistence/sqlc/models.go index d57a5b64..5bee5d96 100644 --- a/internal/manager/persistence/sqlc/models.go +++ b/internal/manager/persistence/sqlc/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package sqlc diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 0f606454..b8d79770 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -18,9 +18,15 @@ INSERT INTO jobs ( VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ); -- name: FetchJob :one +-- Fetch a job by its UUID. SELECT * FROM jobs WHERE uuid = ? LIMIT 1; +-- name: FetchJobByID :one +-- Fetch a job by its numerical ID. +SELECT * FROM jobs +WHERE id = ? LIMIT 1; + -- name: DeleteJob :exec DELETE FROM jobs WHERE uuid = ?; @@ -55,3 +61,129 @@ UPDATE jobs SET updated_at=@now, priority=@priority WHERE id=@id; -- name: SaveJobStorageInfo :exec UPDATE jobs SET storage_shaman_checkout_id=@storage_shaman_checkout_id WHERE id=@id; + +-- name: FetchTask :one +SELECT sqlc.embed(tasks), jobs.UUID as jobUUID, workers.UUID as workerUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.uuid = @uuid; + +-- name: FetchTasksOfWorkerInStatus :many +SELECT sqlc.embed(tasks), jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.worker_id = @worker_id + AND tasks.status = @task_status; + +-- name: FetchTasksOfWorkerInStatusOfJob :many +SELECT sqlc.embed(tasks) +FROM tasks +WHERE tasks.worker_id = @worker_id + AND tasks.job_id = @job_id + AND tasks.status = @task_status; + +-- name: FetchTasksOfJob :many +SELECT sqlc.embed(tasks), workers.UUID as workerUUID +FROM tasks +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.job_id = @job_id; + +-- name: FetchTasksOfJobInStatus :many +SELECT sqlc.embed(tasks), workers.UUID as workerUUID +FROM tasks +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.job_id = @job_id + AND tasks.status in (sqlc.slice('task_status')); + +-- name: FetchTaskJobUUID :one +SELECT jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.uuid = @uuid; + +-- name: UpdateTask :exec +-- Update a Task, except its id, created_at, uuid, or job_id fields. +UPDATE tasks SET + updated_at = @updated_at, + name = @name, + type = @type, + priority = @priority, + status = @status, + worker_id = @worker_id, + last_touched_at = @last_touched_at, + commands = @commands, + activity = @activity +WHERE id=@id; + +-- name: UpdateTaskStatus :exec +UPDATE tasks SET + updated_at = @updated_at, + status = @status +WHERE id=@id; + +-- name: UpdateTaskActivity :exec +UPDATE tasks SET + updated_at = @updated_at, + activity = @activity +WHERE id=@id; + +-- name: UpdateJobsTaskStatusesConditional :exec +UPDATE tasks SET + updated_at = @updated_at, + status = @status, + activity = @activity +WHERE job_id = @job_id AND status in (sqlc.slice('statuses_to_update')); + +-- name: UpdateJobsTaskStatuses :exec +UPDATE tasks SET + updated_at = @updated_at, + status = @status, + activity = @activity +WHERE job_id = @job_id; + +-- name: TaskAssignToWorker :exec +UPDATE tasks SET + updated_at = @updated_at, + worker_id = @worker_id +WHERE id=@id; + +-- name: TaskTouchedByWorker :exec +UPDATE tasks SET + updated_at = @updated_at, + last_touched_at = @last_touched_at +WHERE id=@id; + +-- name: JobCountTasksInStatus :one +-- Fetch number of tasks in the given status, of the given job. +SELECT count(*) as num_tasks FROM tasks +WHERE job_id = @job_id AND status = @task_status; + +-- name: JobCountTaskStatuses :many +-- Fetch (status, num tasks in that status) rows for the given job. +SELECT status, count(*) as num_tasks FROM tasks +WHERE job_id = @job_id +GROUP BY status; + +-- name: AddWorkerToTaskFailedList :exec +INSERT INTO task_failures (created_at, task_id, worker_id) +VALUES (@created_at, @task_id, @worker_id) +ON CONFLICT DO NOTHING; + +-- name: CountWorkersFailingTask :one +-- Count how many workers have failed a given task. +SELECT count(*) as num_failed FROM task_failures +WHERE task_id=@task_id; + +-- name: ClearFailureListOfTask :exec +DELETE FROM task_failures WHERE task_id=@task_id; + +-- name: ClearFailureListOfJob :exec +-- SQLite doesn't support JOIN in DELETE queries, so use a sub-query instead. +DELETE FROM task_failures +WHERE task_id in (SELECT id FROM tasks WHERE job_id=@job_id); + +-- name: FetchTaskFailureList :many +SELECT sqlc.embed(workers) FROM workers +INNER JOIN task_failures TF on TF.worker_id=workers.id +WHERE TF.task_id=@task_id; diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 26032839..a5cef021 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: query_jobs.sql package sqlc @@ -13,6 +13,56 @@ import ( "time" ) +const addWorkerToTaskFailedList = `-- name: AddWorkerToTaskFailedList :exec +INSERT INTO task_failures (created_at, task_id, worker_id) +VALUES (?1, ?2, ?3) +ON CONFLICT DO NOTHING +` + +type AddWorkerToTaskFailedListParams struct { + CreatedAt time.Time + TaskID int64 + WorkerID int64 +} + +func (q *Queries) AddWorkerToTaskFailedList(ctx context.Context, arg AddWorkerToTaskFailedListParams) error { + _, err := q.db.ExecContext(ctx, addWorkerToTaskFailedList, arg.CreatedAt, arg.TaskID, arg.WorkerID) + return err +} + +const clearFailureListOfJob = `-- name: ClearFailureListOfJob :exec +DELETE FROM task_failures +WHERE task_id in (SELECT id FROM tasks WHERE job_id=?1) +` + +// SQLite doesn't support JOIN in DELETE queries, so use a sub-query instead. +func (q *Queries) ClearFailureListOfJob(ctx context.Context, jobID int64) error { + _, err := q.db.ExecContext(ctx, clearFailureListOfJob, jobID) + return err +} + +const clearFailureListOfTask = `-- name: ClearFailureListOfTask :exec +DELETE FROM task_failures WHERE task_id=?1 +` + +func (q *Queries) ClearFailureListOfTask(ctx context.Context, taskID int64) error { + _, err := q.db.ExecContext(ctx, clearFailureListOfTask, taskID) + return err +} + +const countWorkersFailingTask = `-- name: CountWorkersFailingTask :one +SELECT count(*) as num_failed FROM task_failures +WHERE task_id=?1 +` + +// Count how many workers have failed a given task. +func (q *Queries) CountWorkersFailingTask(ctx context.Context, taskID int64) (int64, error) { + row := q.db.QueryRowContext(ctx, countWorkersFailingTask, taskID) + var num_failed int64 + err := row.Scan(&num_failed) + return num_failed, err +} + const createJob = `-- name: CreateJob :exec INSERT INTO jobs ( @@ -44,6 +94,7 @@ type CreateJobParams struct { } // Jobs / Tasks queries +// func (q *Queries) CreateJob(ctx context.Context, arg CreateJobParams) error { _, err := q.db.ExecContext(ctx, createJob, arg.CreatedAt, @@ -74,6 +125,7 @@ SELECT id, created_at, updated_at, uuid, name, job_type, priority, status, activ WHERE uuid = ? LIMIT 1 ` +// Fetch a job by its UUID. func (q *Queries) FetchJob(ctx context.Context, uuid string) (Job, error) { row := q.db.QueryRowContext(ctx, fetchJob, uuid) var i Job @@ -96,6 +148,34 @@ func (q *Queries) FetchJob(ctx context.Context, uuid string) (Job, error) { return i, err } +const fetchJobByID = `-- name: FetchJobByID :one +SELECT id, created_at, updated_at, uuid, name, job_type, priority, status, activity, settings, metadata, delete_requested_at, storage_shaman_checkout_id, worker_tag_id FROM jobs +WHERE id = ? LIMIT 1 +` + +// Fetch a job by its numerical ID. +func (q *Queries) FetchJobByID(ctx context.Context, id int64) (Job, error) { + row := q.db.QueryRowContext(ctx, fetchJobByID, id) + var i Job + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.UUID, + &i.Name, + &i.JobType, + &i.Priority, + &i.Status, + &i.Activity, + &i.Settings, + &i.Metadata, + &i.DeleteRequestedAt, + &i.StorageShamanCheckoutID, + &i.WorkerTagID, + ) + return i, err +} + const fetchJobUUIDsUpdatedBefore = `-- name: FetchJobUUIDsUpdatedBefore :many SELECT uuid FROM jobs WHERE updated_at <= ?1 ` @@ -204,6 +284,388 @@ func (q *Queries) FetchJobsInStatus(ctx context.Context, statuses []string) ([]J return items, nil } +const fetchTask = `-- name: FetchTask :one +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.UUID as jobUUID, workers.UUID as workerUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.uuid = ?1 +` + +type FetchTaskRow struct { + Task Task + JobUUID sql.NullString + WorkerUUID sql.NullString +} + +func (q *Queries) FetchTask(ctx context.Context, uuid string) (FetchTaskRow, error) { + row := q.db.QueryRowContext(ctx, fetchTask, uuid) + var i FetchTaskRow + err := row.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.JobUUID, + &i.WorkerUUID, + ) + return i, err +} + +const fetchTaskFailureList = `-- name: FetchTaskFailureList :many +SELECT workers.id, workers.created_at, workers.updated_at, workers.uuid, workers.secret, workers.name, workers.address, workers.platform, workers.software, workers.status, workers.last_seen_at, workers.status_requested, workers.lazy_status_request, workers.supported_task_types, workers.deleted_at, workers.can_restart FROM workers +INNER JOIN task_failures TF on TF.worker_id = workers.id +WHERE TF.task_id=?1 +` + +type FetchTaskFailureListRow struct { + Worker Worker +} + +func (q *Queries) FetchTaskFailureList(ctx context.Context, taskID int64) ([]FetchTaskFailureListRow, error) { + rows, err := q.db.QueryContext(ctx, fetchTaskFailureList, taskID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTaskFailureListRow + for rows.Next() { + var i FetchTaskFailureListRow + if err := rows.Scan( + &i.Worker.ID, + &i.Worker.CreatedAt, + &i.Worker.UpdatedAt, + &i.Worker.UUID, + &i.Worker.Secret, + &i.Worker.Name, + &i.Worker.Address, + &i.Worker.Platform, + &i.Worker.Software, + &i.Worker.Status, + &i.Worker.LastSeenAt, + &i.Worker.StatusRequested, + &i.Worker.LazyStatusRequest, + &i.Worker.SupportedTaskTypes, + &i.Worker.DeletedAt, + &i.Worker.CanRestart, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchTaskJobUUID = `-- name: FetchTaskJobUUID :one +SELECT jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.uuid = ?1 +` + +func (q *Queries) FetchTaskJobUUID(ctx context.Context, uuid string) (sql.NullString, error) { + row := q.db.QueryRowContext(ctx, fetchTaskJobUUID, uuid) + var jobuuid sql.NullString + err := row.Scan(&jobuuid) + return jobuuid, err +} + +const fetchTasksOfJob = `-- name: FetchTasksOfJob :many +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, workers.UUID as workerUUID +FROM tasks +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.job_id = ?1 +` + +type FetchTasksOfJobRow struct { + Task Task + WorkerUUID sql.NullString +} + +func (q *Queries) FetchTasksOfJob(ctx context.Context, jobID int64) ([]FetchTasksOfJobRow, error) { + rows, err := q.db.QueryContext(ctx, fetchTasksOfJob, jobID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTasksOfJobRow + for rows.Next() { + var i FetchTasksOfJobRow + if err := rows.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.WorkerUUID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchTasksOfJobInStatus = `-- name: FetchTasksOfJobInStatus :many +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, workers.UUID as workerUUID +FROM tasks +LEFT JOIN workers ON (tasks.worker_id = workers.id) +WHERE tasks.job_id = ?1 + AND tasks.status in (/*SLICE:task_status*/?) +` + +type FetchTasksOfJobInStatusParams struct { + JobID int64 + TaskStatus []string +} + +type FetchTasksOfJobInStatusRow struct { + Task Task + WorkerUUID sql.NullString +} + +func (q *Queries) FetchTasksOfJobInStatus(ctx context.Context, arg FetchTasksOfJobInStatusParams) ([]FetchTasksOfJobInStatusRow, error) { + query := fetchTasksOfJobInStatus + var queryParams []interface{} + queryParams = append(queryParams, arg.JobID) + if len(arg.TaskStatus) > 0 { + for _, v := range arg.TaskStatus { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:task_status*/?", strings.Repeat(",?", len(arg.TaskStatus))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:task_status*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTasksOfJobInStatusRow + for rows.Next() { + var i FetchTasksOfJobInStatusRow + if err := rows.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.WorkerUUID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchTasksOfWorkerInStatus = `-- name: FetchTasksOfWorkerInStatus :many +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.worker_id = ?1 + AND tasks.status = ?2 +` + +type FetchTasksOfWorkerInStatusParams struct { + WorkerID sql.NullInt64 + TaskStatus string +} + +type FetchTasksOfWorkerInStatusRow struct { + Task Task + JobUUID sql.NullString +} + +func (q *Queries) FetchTasksOfWorkerInStatus(ctx context.Context, arg FetchTasksOfWorkerInStatusParams) ([]FetchTasksOfWorkerInStatusRow, error) { + rows, err := q.db.QueryContext(ctx, fetchTasksOfWorkerInStatus, arg.WorkerID, arg.TaskStatus) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTasksOfWorkerInStatusRow + for rows.Next() { + var i FetchTasksOfWorkerInStatusRow + if err := rows.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.JobUUID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchTasksOfWorkerInStatusOfJob = `-- name: FetchTasksOfWorkerInStatusOfJob :many +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity +FROM tasks +WHERE tasks.worker_id = ?1 + AND tasks.job_id = ?2 + AND tasks.status = ?3 +` + +type FetchTasksOfWorkerInStatusOfJobParams struct { + WorkerID sql.NullInt64 + JobID int64 + TaskStatus string +} + +type FetchTasksOfWorkerInStatusOfJobRow struct { + Task Task +} + +func (q *Queries) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, arg FetchTasksOfWorkerInStatusOfJobParams) ([]FetchTasksOfWorkerInStatusOfJobRow, error) { + rows, err := q.db.QueryContext(ctx, fetchTasksOfWorkerInStatusOfJob, arg.WorkerID, arg.JobID, arg.TaskStatus) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTasksOfWorkerInStatusOfJobRow + for rows.Next() { + var i FetchTasksOfWorkerInStatusOfJobRow + if err := rows.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountTaskStatuses = `-- name: JobCountTaskStatuses :many +SELECT status, count(*) as num_tasks FROM tasks +WHERE job_id = ?1 +GROUP BY status +` + +type JobCountTaskStatusesRow struct { + Status string + NumTasks int64 +} + +// Fetch (status, num tasks in that status) rows for the given job. +func (q *Queries) JobCountTaskStatuses(ctx context.Context, jobID int64) ([]JobCountTaskStatusesRow, error) { + rows, err := q.db.QueryContext(ctx, jobCountTaskStatuses, jobID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []JobCountTaskStatusesRow + for rows.Next() { + var i JobCountTaskStatusesRow + if err := rows.Scan(&i.Status, &i.NumTasks); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const jobCountTasksInStatus = `-- name: JobCountTasksInStatus :one +SELECT count(*) as num_tasks FROM tasks +WHERE job_id = ?1 AND status = ?2 +` + +type JobCountTasksInStatusParams struct { + JobID int64 + TaskStatus string +} + +// Fetch number of tasks in the given status, of the given job. +func (q *Queries) JobCountTasksInStatus(ctx context.Context, arg JobCountTasksInStatusParams) (int64, error) { + row := q.db.QueryRowContext(ctx, jobCountTasksInStatus, arg.JobID, arg.TaskStatus) + var num_tasks int64 + err := row.Scan(&num_tasks) + return num_tasks, err +} + const requestJobDeletion = `-- name: RequestJobDeletion :exec UPDATE jobs SET updated_at = ?1, @@ -298,3 +760,179 @@ func (q *Queries) SaveJobStorageInfo(ctx context.Context, arg SaveJobStorageInfo _, err := q.db.ExecContext(ctx, saveJobStorageInfo, arg.StorageShamanCheckoutID, arg.ID) return err } + +const taskAssignToWorker = `-- name: TaskAssignToWorker :exec +UPDATE tasks SET + updated_at = ?1, + worker_id = ?2 +WHERE id=?3 +` + +type TaskAssignToWorkerParams struct { + UpdatedAt sql.NullTime + WorkerID sql.NullInt64 + ID int64 +} + +func (q *Queries) TaskAssignToWorker(ctx context.Context, arg TaskAssignToWorkerParams) error { + _, err := q.db.ExecContext(ctx, taskAssignToWorker, arg.UpdatedAt, arg.WorkerID, arg.ID) + return err +} + +const taskTouchedByWorker = `-- name: TaskTouchedByWorker :exec +UPDATE tasks SET + updated_at = ?1, + last_touched_at = ?2 +WHERE id=?3 +` + +type TaskTouchedByWorkerParams struct { + UpdatedAt sql.NullTime + LastTouchedAt sql.NullTime + ID int64 +} + +func (q *Queries) TaskTouchedByWorker(ctx context.Context, arg TaskTouchedByWorkerParams) error { + _, err := q.db.ExecContext(ctx, taskTouchedByWorker, arg.UpdatedAt, arg.LastTouchedAt, arg.ID) + return err +} + +const updateJobsTaskStatuses = `-- name: UpdateJobsTaskStatuses :exec +UPDATE tasks SET + updated_at = ?1, + status = ?2, + activity = ?3 +WHERE job_id = ?4 +` + +type UpdateJobsTaskStatusesParams struct { + UpdatedAt sql.NullTime + Status string + Activity string + JobID int64 +} + +func (q *Queries) UpdateJobsTaskStatuses(ctx context.Context, arg UpdateJobsTaskStatusesParams) error { + _, err := q.db.ExecContext(ctx, updateJobsTaskStatuses, + arg.UpdatedAt, + arg.Status, + arg.Activity, + arg.JobID, + ) + return err +} + +const updateJobsTaskStatusesConditional = `-- name: UpdateJobsTaskStatusesConditional :exec +UPDATE tasks SET + updated_at = ?1, + status = ?2, + activity = ?3 +WHERE job_id = ?4 AND status in (/*SLICE:statuses_to_update*/?) +` + +type UpdateJobsTaskStatusesConditionalParams struct { + UpdatedAt sql.NullTime + Status string + Activity string + JobID int64 + StatusesToUpdate []string +} + +func (q *Queries) UpdateJobsTaskStatusesConditional(ctx context.Context, arg UpdateJobsTaskStatusesConditionalParams) error { + query := updateJobsTaskStatusesConditional + var queryParams []interface{} + queryParams = append(queryParams, arg.UpdatedAt) + queryParams = append(queryParams, arg.Status) + queryParams = append(queryParams, arg.Activity) + queryParams = append(queryParams, arg.JobID) + if len(arg.StatusesToUpdate) > 0 { + for _, v := range arg.StatusesToUpdate { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:statuses_to_update*/?", strings.Repeat(",?", len(arg.StatusesToUpdate))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:statuses_to_update*/?", "NULL", 1) + } + _, err := q.db.ExecContext(ctx, query, queryParams...) + return err +} + +const updateTask = `-- name: UpdateTask :exec +UPDATE tasks SET + updated_at = ?1, + name = ?2, + type = ?3, + priority = ?4, + status = ?5, + worker_id = ?6, + last_touched_at = ?7, + commands = ?8, + activity = ?9 +WHERE id=?10 +` + +type UpdateTaskParams struct { + UpdatedAt sql.NullTime + Name string + Type string + Priority int64 + Status string + WorkerID sql.NullInt64 + LastTouchedAt sql.NullTime + Commands json.RawMessage + Activity string + ID int64 +} + +// Update a Task, except its id, created_at, uuid, or job_id fields. +func (q *Queries) UpdateTask(ctx context.Context, arg UpdateTaskParams) error { + _, err := q.db.ExecContext(ctx, updateTask, + arg.UpdatedAt, + arg.Name, + arg.Type, + arg.Priority, + arg.Status, + arg.WorkerID, + arg.LastTouchedAt, + arg.Commands, + arg.Activity, + arg.ID, + ) + return err +} + +const updateTaskActivity = `-- name: UpdateTaskActivity :exec +UPDATE tasks SET + updated_at = ?1, + activity = ?2 +WHERE id=?3 +` + +type UpdateTaskActivityParams struct { + UpdatedAt sql.NullTime + Activity string + ID int64 +} + +func (q *Queries) UpdateTaskActivity(ctx context.Context, arg UpdateTaskActivityParams) error { + _, err := q.db.ExecContext(ctx, updateTaskActivity, arg.UpdatedAt, arg.Activity, arg.ID) + return err +} + +const updateTaskStatus = `-- name: UpdateTaskStatus :exec +UPDATE tasks SET + updated_at = ?1, + status = ?2 +WHERE id=?3 +` + +type UpdateTaskStatusParams struct { + UpdatedAt sql.NullTime + Status string + ID int64 +} + +func (q *Queries) UpdateTaskStatus(ctx context.Context, arg UpdateTaskStatusParams) error { + _, err := q.db.ExecContext(ctx, updateTaskStatus, arg.UpdatedAt, arg.Status, arg.ID) + return err +} diff --git a/internal/manager/persistence/sqlc/query_workers.sql b/internal/manager/persistence/sqlc/query_workers.sql new file mode 100644 index 00000000..a2d6713c --- /dev/null +++ b/internal/manager/persistence/sqlc/query_workers.sql @@ -0,0 +1,18 @@ + +-- Worker queries +-- + +-- name: FetchWorker :one +-- FetchWorker only returns the worker if it wasn't soft-deleted. +SELECT * FROM workers WHERE workers.uuid = @uuid and deleted_at is NULL; + +-- name: FetchWorkerUnconditional :one +-- FetchWorkerUnconditional ignores soft-deletion status and just returns the worker. +SELECT * FROM workers WHERE workers.uuid = @uuid; + +-- name: FetchWorkerTags :many +SELECT worker_tags.* +FROM worker_tags +LEFT JOIN worker_tag_membership m ON (m.worker_tag_id = worker_tags.id) +LEFT JOIN workers on (m.worker_id = workers.id) +WHERE workers.uuid = @uuid; diff --git a/internal/manager/persistence/sqlc/query_workers.sql.go b/internal/manager/persistence/sqlc/query_workers.sql.go new file mode 100644 index 00000000..8d7c8682 --- /dev/null +++ b/internal/manager/persistence/sqlc/query_workers.sql.go @@ -0,0 +1,109 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 +// source: query_workers.sql + +package sqlc + +import ( + "context" +) + +const fetchWorker = `-- name: FetchWorker :one + +SELECT id, created_at, updated_at, uuid, secret, name, address, platform, software, status, last_seen_at, status_requested, lazy_status_request, supported_task_types, deleted_at, can_restart FROM workers WHERE workers.uuid = ?1 and deleted_at is NULL +` + +// Worker queries +// +// FetchWorker only returns the worker if it wasn't soft-deleted. +func (q *Queries) FetchWorker(ctx context.Context, uuid string) (Worker, error) { + row := q.db.QueryRowContext(ctx, fetchWorker, uuid) + var i Worker + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.UUID, + &i.Secret, + &i.Name, + &i.Address, + &i.Platform, + &i.Software, + &i.Status, + &i.LastSeenAt, + &i.StatusRequested, + &i.LazyStatusRequest, + &i.SupportedTaskTypes, + &i.DeletedAt, + &i.CanRestart, + ) + return i, err +} + +const fetchWorkerTags = `-- name: FetchWorkerTags :many +SELECT worker_tags.id, worker_tags.created_at, worker_tags.updated_at, worker_tags.uuid, worker_tags.name, worker_tags.description +FROM worker_tags +LEFT JOIN worker_tag_membership m ON (m.worker_tag_id = worker_tags.id) +LEFT JOIN workers on (m.worker_id = workers.id) +WHERE workers.uuid = ?1 +` + +func (q *Queries) FetchWorkerTags(ctx context.Context, uuid string) ([]WorkerTag, error) { + rows, err := q.db.QueryContext(ctx, fetchWorkerTags, uuid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkerTag + for rows.Next() { + var i WorkerTag + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.UUID, + &i.Name, + &i.Description, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchWorkerUnconditional = `-- name: FetchWorkerUnconditional :one +SELECT id, created_at, updated_at, uuid, secret, name, address, platform, software, status, last_seen_at, status_requested, lazy_status_request, supported_task_types, deleted_at, can_restart FROM workers WHERE workers.uuid = ?1 +` + +// FetchWorkerUnconditional ignores soft-deletion status and just returns the worker. +func (q *Queries) FetchWorkerUnconditional(ctx context.Context, uuid string) (Worker, error) { + row := q.db.QueryRowContext(ctx, fetchWorkerUnconditional, uuid) + var i Worker + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.UUID, + &i.Secret, + &i.Name, + &i.Address, + &i.Platform, + &i.Software, + &i.Status, + &i.LastSeenAt, + &i.StatusRequested, + &i.LazyStatusRequest, + &i.SupportedTaskTypes, + &i.DeletedAt, + &i.CanRestart, + ) + return i, err +} diff --git a/internal/manager/persistence/test_support.go b/internal/manager/persistence/test_support.go index cba04ade..a0c2a92f 100644 --- a/internal/manager/persistence/test_support.go +++ b/internal/manager/persistence/test_support.go @@ -15,7 +15,6 @@ import ( "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "gorm.io/gorm" - "projects.blender.org/studio/flamenco/internal/uuid" "projects.blender.org/studio/flamenco/pkg/api" ) @@ -106,7 +105,7 @@ func workerTestFixtures(t *testing.T, testContextTimeout time.Duration) WorkerTe ctx, cancel, db := persistenceTestFixtures(t, testContextTimeout) w := Worker{ - UUID: uuid.New(), + UUID: "557930e7-5b55-469e-a6d7-fc800f3685be", Name: "дрон", Address: "fe80::5054:ff:fede:2ad7", Platform: "linux", @@ -116,7 +115,7 @@ func workerTestFixtures(t *testing.T, testContextTimeout time.Duration) WorkerTe } wc := WorkerTag{ - UUID: uuid.New(), + UUID: "e0e05417-9793-4829-b1d0-d446dd819f3d", Name: "arbejdsklynge", Description: "Worker tag in Danish", } diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go index a9637d54..c7d87f9e 100644 --- a/internal/manager/persistence/workers.go +++ b/internal/manager/persistence/workers.go @@ -10,6 +10,7 @@ import ( "github.com/rs/zerolog/log" "gorm.io/gorm" + "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc" "projects.blender.org/studio/flamenco/pkg/api" ) @@ -73,18 +74,30 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error { } func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) { - w := Worker{} - tx := db.gormDB.WithContext(ctx). - Preload("Tags"). - Find(&w, "uuid = ?", uuid). - Limit(1) - if tx.Error != nil { - return nil, workerError(tx.Error, "fetching worker") + queries, err := db.queries() + if err != nil { + return nil, err } - if w.ID == 0 { - return nil, ErrWorkerNotFound + + worker, err := queries.FetchWorker(ctx, uuid) + if err != nil { + return nil, workerError(err, "fetching worker %s", uuid) } - return &w, nil + + // TODO: remove this code, and let the caller fetch the tags when interested in them. + workerTags, err := queries.FetchWorkerTags(ctx, uuid) + if err != nil { + return nil, workerTagError(err, "fetching tags of worker %s", uuid) + } + + convertedWorker := convertSqlcWorker(worker) + convertedWorker.Tags = make([]*WorkerTag, len(workerTags)) + for index := range workerTags { + convertedTag := convertSqlcWorkerTag(workerTags[index]) + convertedWorker.Tags[index] = &convertedTag + } + + return &convertedWorker, nil } func (db *DB) DeleteWorker(ctx context.Context, uuid string) error { @@ -216,3 +229,48 @@ func (db *DB) SummarizeWorkerStatuses(ctx context.Context) (WorkerStatusCount, e return statusCounts, nil } + +// convertSqlcWorker converts a worker from the SQLC-generated model to the model +// expected by the rest of the code. This is mostly in place to aid in the GORM +// to SQLC migration. It is intended that eventually the rest of the code will +// use the same SQLC-generated model. +func convertSqlcWorker(worker sqlc.Worker) Worker { + return Worker{ + Model: Model{ + ID: uint(worker.ID), + CreatedAt: worker.CreatedAt, + UpdatedAt: worker.UpdatedAt.Time, + }, + DeletedAt: gorm.DeletedAt(worker.DeletedAt), + + UUID: worker.UUID, + Secret: worker.Secret, + Name: worker.Name, + Address: worker.Address, + Platform: worker.Platform, + Software: worker.Software, + Status: api.WorkerStatus(worker.Status), + LastSeenAt: worker.LastSeenAt.Time, + CanRestart: worker.CanRestart != 0, + StatusRequested: api.WorkerStatus(worker.StatusRequested), + LazyStatusRequest: worker.LazyStatusRequest != 0, + SupportedTaskTypes: worker.SupportedTaskTypes, + } +} + +// convertSqlcWorkerTag converts a worker tag from the SQLC-generated model to +// the model expected by the rest of the code. This is mostly in place to aid in +// the GORM to SQLC migration. It is intended that eventually the rest of the +// code will use the same SQLC-generated model. +func convertSqlcWorkerTag(tag sqlc.WorkerTag) WorkerTag { + return WorkerTag{ + Model: Model{ + ID: uint(tag.ID), + CreatedAt: tag.CreatedAt, + UpdatedAt: tag.UpdatedAt.Time, + }, + UUID: tag.UUID, + Name: tag.Name, + Description: tag.Description, + } +} diff --git a/internal/manager/persistence/workers_test.go b/internal/manager/persistence/workers_test.go index 8bd6ab94..27f28f5a 100644 --- a/internal/manager/persistence/workers_test.go +++ b/internal/manager/persistence/workers_test.go @@ -408,7 +408,7 @@ func TestSummarizeWorkerStatusesTimeout(t *testing.T) { // Force a timeout of the context. And yes, even when a nanosecond is quite // short, it is still necessary to wait. - time.Sleep(2 * time.Nanosecond) + time.Sleep(1 * time.Millisecond) // Test the summary. summary, err := f.db.SummarizeWorkerStatuses(subCtx) diff --git a/sqlc.yaml b/sqlc.yaml index 887ce133..5fd3fcbc 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -14,3 +14,23 @@ sql: rename: uuid: "UUID" uuids: "UUIDs" + jobuuid: "JobUUID" + taskUUID: "TaskUUID" + workeruuid: "WorkerUUID" + - engine: "sqlite" + schema: "internal/manager/persistence/sqlc/schema.sql" + queries: "internal/manager/persistence/sqlc/query_workers.sql" + gen: + go: + out: "internal/manager/persistence/sqlc" + overrides: + - db_type: "jsonb" + go_type: + import: "encoding/json" + type: "RawMessage" + rename: + uuid: "UUID" + uuids: "UUIDs" + jobuuid: "JobUUID" + taskUUID: "TaskUUID" + workeruuid: "WorkerUUID"