diff --git a/api/custom/tasks.go b/api/custom/tasks.go index 0a390957994e843cd405e3b50e7f7c5d3bbbafd3..3d25e709381996aee10f6f66aded8a42c4bb0190 100644 --- a/api/custom/tasks.go +++ b/api/custom/tasks.go @@ -5,22 +5,22 @@ import ( "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/database" "net/http" "strconv" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/rcontext" - "github.com/turt2live/matrix-media-repo/storage" ) type TaskStatus struct { - TaskID int `json:"task_id"` - Name string `json:"task_name"` - Params map[string]interface{} `json:"params"` - StartTs int64 `json:"start_ts"` - EndTs int64 `json:"end_ts"` - IsFinished bool `json:"is_finished"` + TaskID int `json:"task_id"` + Name string `json:"task_name"` + Params *database.AnonymousJson `json:"params"` + StartTs int64 `json:"start_ts"` + EndTs int64 `json:"end_ts"` + IsFinished bool `json:"is_finished"` } func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { @@ -35,17 +35,20 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn "taskId": taskId, }) - db := storage.GetDatabase().GetMetadataStore(rctx) + db := database.GetInstance().Tasks.Prepare(rctx) - task, err := db.GetBackgroundTask(taskId) + task, err := db.Get(taskId) if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) return _responses.InternalServerError("failed to get task information") } + if task == nil { + return _responses.NotFoundError() + } return &_responses.DoNotCacheResponse{Payload: &TaskStatus{ - TaskID: task.ID, + TaskID: task.TaskId, Name: task.Name, Params: task.Params, StartTs: task.StartTs, @@ -55,9 +58,9 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn } func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - db := storage.GetDatabase().GetMetadataStore(rctx) + db := database.GetInstance().Tasks.Prepare(rctx) - tasks, err := db.GetAllBackgroundTasks() + tasks, err := db.GetAll(true) if err != nil { logrus.Error(err) sentry.CaptureException(err) @@ -67,7 +70,7 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U statusObjs := make([]*TaskStatus, 0) for _, task := range tasks { statusObjs = append(statusObjs, &TaskStatus{ - TaskID: task.ID, + TaskID: task.TaskId, Name: task.Name, Params: task.Params, StartTs: task.StartTs, @@ -80,9 +83,9 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U } func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - db := storage.GetDatabase().GetMetadataStore(rctx) + db := database.GetInstance().Tasks.Prepare(rctx) - tasks, err := db.GetAllBackgroundTasks() + tasks, err := db.GetAll(false) if err != nil { logrus.Error(err) sentry.CaptureException(err) @@ -91,11 +94,8 @@ func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _ap statusObjs := make([]*TaskStatus, 0) for _, task := range tasks { - if task.EndTs > 0 { - continue - } statusObjs = append(statusObjs, &TaskStatus{ - TaskID: task.ID, + TaskID: task.TaskId, Name: task.Name, Params: task.Params, StartTs: task.StartTs, diff --git a/database/table_tasks.go b/database/table_tasks.go index 015006f41d6a2f8fcf58c6b5d4209f357bbb6d4d..84e34be3fc2290ce88865e2019846c40d954f5f4 100644 --- a/database/table_tasks.go +++ b/database/table_tasks.go @@ -17,10 +17,14 @@ type DbTask struct { const selectTask = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE id = $1;" const insertTask = "INSERT INTO background_tasks (task, params, start_ts, end_ts) VALUES ($1, $2, $3, 0) RETURNING id, task, params, start_ts, end_ts;" +const selectAllTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks;" +const selectIncompleteTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE end_ts <= 0;" type tasksTableStatements struct { - selectTask *sql.Stmt - insertTask *sql.Stmt + selectTask *sql.Stmt + insertTask *sql.Stmt + selectAllTasks *sql.Stmt + selectIncompleteTasks *sql.Stmt } type tasksTableWithContext struct { @@ -38,6 +42,12 @@ func prepareTasksTables(db *sql.DB) (*tasksTableStatements, error) { if stmts.insertTask, err = db.Prepare(insertTask); err != nil { return nil, errors.New("error preparing insertTask: " + err.Error()) } + if stmts.selectAllTasks, err = db.Prepare(selectAllTasks); err != nil { + return nil, errors.New("error preparing selectAllTasks: " + err.Error()) + } + if stmts.selectIncompleteTasks, err = db.Prepare(selectIncompleteTasks); err != nil { + return nil, errors.New("error preparing selectIncompleteTasks: " + err.Error()) + } return stmts, nil } @@ -69,3 +79,26 @@ func (s *tasksTableWithContext) Get(id int) (*DbTask, error) { } return val, err } + +func (s *tasksTableWithContext) GetAll(includingFinished bool) ([]*DbTask, error) { + results := make([]*DbTask, 0) + q := s.statements.selectAllTasks + if !includingFinished { + q = s.statements.selectIncompleteTasks + } + rows, err := q.QueryContext(s.ctx) + if err != nil { + if err == sql.ErrNoRows { + return results, nil + } + return nil, err + } + for rows.Next() { + val := &DbTask{} + if err = rows.Scan(&val.TaskId, &val.Name, &val.Params, &val.StartTs, &val.EndTs); err != nil { + return nil, err + } + results = append(results, val) + } + return results, nil +}