gitea source for verification 2026-05-22
This commit is contained in:
213
models/actions/artifact.go
Normal file
213
models/actions/artifact.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// This artifact server is inspired by https://github.com/nektos/act/blob/master/pkg/artifacts/server.go.
|
||||
// It updates url setting and uses ObjectStore to handle artifacts persistence.
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ArtifactStatus is the status of an artifact, uploading, expired or need-delete
|
||||
type ArtifactStatus int64
|
||||
|
||||
const (
|
||||
ArtifactStatusUploadPending ArtifactStatus = iota + 1 // 1, ArtifactStatusUploadPending is the status of an artifact upload that is pending
|
||||
ArtifactStatusUploadConfirmed // 2, ArtifactStatusUploadConfirmed is the status of an artifact upload that is confirmed
|
||||
ArtifactStatusUploadError // 3, ArtifactStatusUploadError is the status of an artifact upload that is errored
|
||||
ArtifactStatusExpired // 4, ArtifactStatusExpired is the status of an artifact that is expired
|
||||
ArtifactStatusPendingDeletion // 5, ArtifactStatusPendingDeletion is the status of an artifact that is pending deletion
|
||||
ArtifactStatusDeleted // 6, ArtifactStatusDeleted is the status of an artifact that is deleted
|
||||
)
|
||||
|
||||
func (status ArtifactStatus) ToString() string {
|
||||
switch status {
|
||||
case ArtifactStatusUploadPending:
|
||||
return "upload is not yet completed"
|
||||
case ArtifactStatusUploadConfirmed:
|
||||
return "upload is completed"
|
||||
case ArtifactStatusUploadError:
|
||||
return "upload failed"
|
||||
case ArtifactStatusExpired:
|
||||
return "expired"
|
||||
case ArtifactStatusPendingDeletion:
|
||||
return "pending deletion"
|
||||
case ArtifactStatusDeleted:
|
||||
return "deleted"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionArtifact))
|
||||
}
|
||||
|
||||
// ActionArtifact is a file that is stored in the artifact storage.
|
||||
type ActionArtifact struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
RunID int64 `xorm:"index unique(runid_name_path)"` // The run id of the artifact
|
||||
RunnerID int64
|
||||
RepoID int64 `xorm:"index"`
|
||||
OwnerID int64
|
||||
CommitSHA string
|
||||
StoragePath string // The path to the artifact in the storage
|
||||
FileSize int64 // The size of the artifact in bytes
|
||||
FileCompressedSize int64 // The size of the artifact in bytes after gzip compression
|
||||
ContentEncoding string // The content encoding of the artifact
|
||||
ArtifactPath string `xorm:"index unique(runid_name_path)"` // The path to the artifact when runner uploads it
|
||||
ArtifactName string `xorm:"index unique(runid_name_path)"` // The name of the artifact when runner uploads it
|
||||
Status ArtifactStatus `xorm:"index"` // The status of the artifact, uploading, expired or need-delete
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated index"`
|
||||
ExpiredUnix timeutil.TimeStamp `xorm:"index"` // The time when the artifact will be expired
|
||||
}
|
||||
|
||||
func CreateArtifact(ctx context.Context, t *ActionTask, artifactName, artifactPath string, expiredDays int64) (*ActionArtifact, error) {
|
||||
if err := t.LoadJob(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artifact, err := getArtifactByNameAndPath(ctx, t.Job.RunID, artifactName, artifactPath)
|
||||
if errors.Is(err, util.ErrNotExist) {
|
||||
artifact := &ActionArtifact{
|
||||
ArtifactName: artifactName,
|
||||
ArtifactPath: artifactPath,
|
||||
RunID: t.Job.RunID,
|
||||
RunnerID: t.RunnerID,
|
||||
RepoID: t.RepoID,
|
||||
OwnerID: t.OwnerID,
|
||||
CommitSHA: t.CommitSHA,
|
||||
Status: ArtifactStatusUploadPending,
|
||||
ExpiredUnix: timeutil.TimeStamp(time.Now().Unix() + timeutil.Day*expiredDays),
|
||||
}
|
||||
if _, err := db.GetEngine(ctx).Insert(artifact); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return artifact, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := db.GetEngine(ctx).ID(artifact.ID).Cols("expired_unix").Update(&ActionArtifact{
|
||||
ExpiredUnix: timeutil.TimeStamp(time.Now().Unix() + timeutil.Day*expiredDays),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func getArtifactByNameAndPath(ctx context.Context, runID int64, name, fpath string) (*ActionArtifact, error) {
|
||||
var art ActionArtifact
|
||||
has, err := db.GetEngine(ctx).Where("run_id = ? AND artifact_name = ? AND artifact_path = ?", runID, name, fpath).Get(&art)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, util.ErrNotExist
|
||||
}
|
||||
return &art, nil
|
||||
}
|
||||
|
||||
// UpdateArtifactByID updates an artifact by id
|
||||
func UpdateArtifactByID(ctx context.Context, id int64, art *ActionArtifact) error {
|
||||
art.ID = id
|
||||
_, err := db.GetEngine(ctx).ID(id).AllCols().Update(art)
|
||||
return err
|
||||
}
|
||||
|
||||
type FindArtifactsOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
RunID int64
|
||||
ArtifactName string
|
||||
Status int
|
||||
FinalizedArtifactsV4 bool
|
||||
}
|
||||
|
||||
func (opts FindArtifactsOptions) ToOrders() string {
|
||||
return "id"
|
||||
}
|
||||
|
||||
var _ db.FindOptionsOrder = (*FindArtifactsOptions)(nil)
|
||||
|
||||
func (opts FindArtifactsOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.RunID > 0 {
|
||||
cond = cond.And(builder.Eq{"run_id": opts.RunID})
|
||||
}
|
||||
if opts.ArtifactName != "" {
|
||||
cond = cond.And(builder.Eq{"artifact_name": opts.ArtifactName})
|
||||
}
|
||||
if opts.Status > 0 {
|
||||
cond = cond.And(builder.Eq{"status": opts.Status})
|
||||
}
|
||||
if opts.FinalizedArtifactsV4 {
|
||||
cond = cond.And(builder.Eq{"status": ArtifactStatusUploadConfirmed}.Or(builder.Eq{"status": ArtifactStatusExpired}))
|
||||
cond = cond.And(builder.Eq{"content_encoding": "application/zip"})
|
||||
}
|
||||
|
||||
return cond
|
||||
}
|
||||
|
||||
// ActionArtifactMeta is the meta-data of an artifact
|
||||
type ActionArtifactMeta struct {
|
||||
ArtifactName string
|
||||
FileSize int64
|
||||
Status ArtifactStatus
|
||||
}
|
||||
|
||||
// ListUploadedArtifactsMeta returns all uploaded artifacts meta of a run
|
||||
func ListUploadedArtifactsMeta(ctx context.Context, runID int64) ([]*ActionArtifactMeta, error) {
|
||||
arts := make([]*ActionArtifactMeta, 0, 10)
|
||||
return arts, db.GetEngine(ctx).Table("action_artifact").
|
||||
Where("run_id=? AND (status=? OR status=?)", runID, ArtifactStatusUploadConfirmed, ArtifactStatusExpired).
|
||||
GroupBy("artifact_name").
|
||||
Select("artifact_name, sum(file_size) as file_size, max(status) as status").
|
||||
Find(&arts)
|
||||
}
|
||||
|
||||
// ListNeedExpiredArtifacts returns all need expired artifacts but not deleted
|
||||
func ListNeedExpiredArtifacts(ctx context.Context) ([]*ActionArtifact, error) {
|
||||
arts := make([]*ActionArtifact, 0, 10)
|
||||
return arts, db.GetEngine(ctx).
|
||||
Where("expired_unix < ? AND status = ?", timeutil.TimeStamp(time.Now().Unix()), ArtifactStatusUploadConfirmed).Find(&arts)
|
||||
}
|
||||
|
||||
// ListPendingDeleteArtifacts returns all artifacts in pending-delete status.
|
||||
// limit is the max number of artifacts to return.
|
||||
func ListPendingDeleteArtifacts(ctx context.Context, limit int) ([]*ActionArtifact, error) {
|
||||
arts := make([]*ActionArtifact, 0, limit)
|
||||
return arts, db.GetEngine(ctx).
|
||||
Where("status = ?", ArtifactStatusPendingDeletion).Limit(limit).Find(&arts)
|
||||
}
|
||||
|
||||
// SetArtifactExpired sets an artifact to expired
|
||||
func SetArtifactExpired(ctx context.Context, artifactID int64) error {
|
||||
_, err := db.GetEngine(ctx).Where("id=? AND status = ?", artifactID, ArtifactStatusUploadConfirmed).Cols("status").Update(&ActionArtifact{Status: ArtifactStatusExpired})
|
||||
return err
|
||||
}
|
||||
|
||||
// SetArtifactNeedDelete sets an artifact to need-delete, cron job will delete it
|
||||
func SetArtifactNeedDelete(ctx context.Context, runID int64, name string) error {
|
||||
_, err := db.GetEngine(ctx).Where("run_id=? AND artifact_name=? AND status = ?", runID, name, ArtifactStatusUploadConfirmed).Cols("status").Update(&ActionArtifact{Status: ArtifactStatusPendingDeletion})
|
||||
return err
|
||||
}
|
||||
|
||||
// SetArtifactDeleted sets an artifact to deleted
|
||||
func SetArtifactDeleted(ctx context.Context, artifactID int64) error {
|
||||
_, err := db.GetEngine(ctx).ID(artifactID).Cols("status").Update(&ActionArtifact{Status: ArtifactStatusDeleted})
|
||||
return err
|
||||
}
|
||||
20
models/actions/main_test.go
Normal file
20
models/actions/main_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m, &unittest.TestOptions{
|
||||
FixtureFiles: []string{
|
||||
"action_runner_token.yml",
|
||||
"action_run.yml",
|
||||
"repository.yml",
|
||||
},
|
||||
})
|
||||
}
|
||||
444
models/actions/run.go
Normal file
444
models/actions/run.go
Normal file
@@ -0,0 +1,444 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/git"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
api "code.gitea.io/gitea/modules/structs"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
webhook_module "code.gitea.io/gitea/modules/webhook"
|
||||
|
||||
"github.com/nektos/act/pkg/jobparser"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionRun represents a run of a workflow file
|
||||
type ActionRun struct {
|
||||
ID int64
|
||||
Title string
|
||||
RepoID int64 `xorm:"index unique(repo_index)"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
WorkflowID string `xorm:"index"` // the name of workflow file
|
||||
Index int64 `xorm:"index unique(repo_index)"` // a unique number for each run of a repository
|
||||
TriggerUserID int64 `xorm:"index"`
|
||||
TriggerUser *user_model.User `xorm:"-"`
|
||||
ScheduleID int64
|
||||
Ref string `xorm:"index"` // the commit/tag/… that caused the run
|
||||
IsRefDeleted bool `xorm:"-"`
|
||||
CommitSHA string
|
||||
IsForkPullRequest bool // If this is triggered by a PR from a forked repository or an untrusted user, we need to check if it is approved and limit permissions when running the workflow.
|
||||
NeedApproval bool // may need approval if it's a fork pull request
|
||||
ApprovedBy int64 `xorm:"index"` // who approved
|
||||
Event webhook_module.HookEventType // the webhook event that causes the workflow to run
|
||||
EventPayload string `xorm:"LONGTEXT"`
|
||||
TriggerEvent string // the trigger event defined in the `on` configuration of the triggered workflow
|
||||
Status Status `xorm:"index"`
|
||||
Version int `xorm:"version default 0"` // Status could be updated concomitantly, so an optimistic lock is needed
|
||||
// Started and Stopped is used for recording last run time, if rerun happened, they will be reset to 0
|
||||
Started timeutil.TimeStamp
|
||||
Stopped timeutil.TimeStamp
|
||||
// PreviousDuration is used for recording previous duration
|
||||
PreviousDuration time.Duration
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionRun))
|
||||
db.RegisterModel(new(ActionRunIndex))
|
||||
}
|
||||
|
||||
func (run *ActionRun) HTMLURL() string {
|
||||
if run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/actions/runs/%d", run.Repo.HTMLURL(), run.Index)
|
||||
}
|
||||
|
||||
func (run *ActionRun) Link() string {
|
||||
if run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/actions/runs/%d", run.Repo.Link(), run.Index)
|
||||
}
|
||||
|
||||
func (run *ActionRun) WorkflowLink() string {
|
||||
if run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/actions/?workflow=%s", run.Repo.Link(), run.WorkflowID)
|
||||
}
|
||||
|
||||
// RefLink return the url of run's ref
|
||||
func (run *ActionRun) RefLink() string {
|
||||
refName := git.RefName(run.Ref)
|
||||
if refName.IsPull() {
|
||||
return run.Repo.Link() + "/pulls/" + refName.ShortName()
|
||||
}
|
||||
return run.Repo.Link() + "/src/" + refName.RefWebLinkPath()
|
||||
}
|
||||
|
||||
// PrettyRef return #id for pull ref or ShortName for others
|
||||
func (run *ActionRun) PrettyRef() string {
|
||||
refName := git.RefName(run.Ref)
|
||||
if refName.IsPull() {
|
||||
return "#" + strings.TrimSuffix(strings.TrimPrefix(run.Ref, git.PullPrefix), "/head")
|
||||
}
|
||||
return refName.ShortName()
|
||||
}
|
||||
|
||||
// LoadAttributes load Repo TriggerUser if not loaded
|
||||
func (run *ActionRun) LoadAttributes(ctx context.Context) error {
|
||||
if run == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := run.LoadRepo(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := run.Repo.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if run.TriggerUser == nil {
|
||||
u, err := user_model.GetPossibleUserByID(ctx, run.TriggerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
run.TriggerUser = u
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (run *ActionRun) LoadRepo(ctx context.Context) error {
|
||||
if run == nil || run.Repo != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
repo, err := repo_model.GetRepositoryByID(ctx, run.RepoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
run.Repo = repo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (run *ActionRun) Duration() time.Duration {
|
||||
return calculateDuration(run.Started, run.Stopped, run.Status) + run.PreviousDuration
|
||||
}
|
||||
|
||||
func (run *ActionRun) GetPushEventPayload() (*api.PushPayload, error) {
|
||||
if run.Event == webhook_module.HookEventPush {
|
||||
var payload api.PushPayload
|
||||
if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
return nil, fmt.Errorf("event %s is not a push event", run.Event)
|
||||
}
|
||||
|
||||
func (run *ActionRun) GetPullRequestEventPayload() (*api.PullRequestPayload, error) {
|
||||
if run.Event.IsPullRequest() {
|
||||
var payload api.PullRequestPayload
|
||||
if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
return nil, fmt.Errorf("event %s is not a pull request event", run.Event)
|
||||
}
|
||||
|
||||
func (run *ActionRun) GetWorkflowRunEventPayload() (*api.WorkflowRunPayload, error) {
|
||||
if run.Event == webhook_module.HookEventWorkflowRun {
|
||||
var payload api.WorkflowRunPayload
|
||||
if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &payload, nil
|
||||
}
|
||||
return nil, fmt.Errorf("event %s is not a workflow run event", run.Event)
|
||||
}
|
||||
|
||||
func (run *ActionRun) IsSchedule() bool {
|
||||
return run.ScheduleID > 0
|
||||
}
|
||||
|
||||
func updateRepoRunsNumbers(ctx context.Context, repo *repo_model.Repository) error {
|
||||
_, err := db.GetEngine(ctx).ID(repo.ID).
|
||||
NoAutoTime().
|
||||
Cols("num_action_runs", "num_closed_action_runs").
|
||||
SetExpr("num_action_runs",
|
||||
builder.Select("count(*)").From("action_run").
|
||||
Where(builder.Eq{"repo_id": repo.ID}),
|
||||
).
|
||||
SetExpr("num_closed_action_runs",
|
||||
builder.Select("count(*)").From("action_run").
|
||||
Where(builder.Eq{
|
||||
"repo_id": repo.ID,
|
||||
}.And(
|
||||
builder.In("status",
|
||||
StatusSuccess,
|
||||
StatusFailure,
|
||||
StatusCancelled,
|
||||
StatusSkipped,
|
||||
),
|
||||
),
|
||||
),
|
||||
).
|
||||
Update(repo)
|
||||
return err
|
||||
}
|
||||
|
||||
// CancelPreviousJobs cancels all previous jobs of the same repository, reference, workflow, and event.
|
||||
// It's useful when a new run is triggered, and all previous runs needn't be continued anymore.
|
||||
func CancelPreviousJobs(ctx context.Context, repoID int64, ref, workflowID string, event webhook_module.HookEventType) ([]*ActionRunJob, error) {
|
||||
// Find all runs in the specified repository, reference, and workflow with non-final status
|
||||
runs, total, err := db.FindAndCount[ActionRun](ctx, FindRunOptions{
|
||||
RepoID: repoID,
|
||||
Ref: ref,
|
||||
WorkflowID: workflowID,
|
||||
TriggerEvent: event,
|
||||
Status: []Status{StatusRunning, StatusWaiting, StatusBlocked},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there are no runs found, there's no need to proceed with cancellation, so return nil.
|
||||
if total == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cancelledJobs := make([]*ActionRunJob, 0, total)
|
||||
|
||||
// Iterate over each found run and cancel its associated jobs.
|
||||
for _, run := range runs {
|
||||
// Find all jobs associated with the current run.
|
||||
jobs, err := db.Find[ActionRunJob](ctx, FindRunJobOptions{
|
||||
RunID: run.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return cancelledJobs, err
|
||||
}
|
||||
|
||||
// Iterate over each job and attempt to cancel it.
|
||||
for _, job := range jobs {
|
||||
// Skip jobs that are already in a terminal state (completed, cancelled, etc.).
|
||||
status := job.Status
|
||||
if status.IsDone() {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the job has no associated task (probably an error), set its status to 'Cancelled' and stop it.
|
||||
if job.TaskID == 0 {
|
||||
job.Status = StatusCancelled
|
||||
job.Stopped = timeutil.TimeStampNow()
|
||||
|
||||
// Update the job's status and stopped time in the database.
|
||||
n, err := UpdateRunJob(ctx, job, builder.Eq{"task_id": 0}, "status", "stopped")
|
||||
if err != nil {
|
||||
return cancelledJobs, err
|
||||
}
|
||||
|
||||
// If the update affected 0 rows, it means the job has changed in the meantime, so we need to try again.
|
||||
if n == 0 {
|
||||
return cancelledJobs, errors.New("job has changed, try again")
|
||||
}
|
||||
|
||||
cancelledJobs = append(cancelledJobs, job)
|
||||
// Continue with the next job.
|
||||
continue
|
||||
}
|
||||
|
||||
// If the job has an associated task, try to stop the task, effectively cancelling the job.
|
||||
if err := StopTask(ctx, job.TaskID, StatusCancelled); err != nil {
|
||||
return cancelledJobs, err
|
||||
}
|
||||
cancelledJobs = append(cancelledJobs, job)
|
||||
}
|
||||
}
|
||||
|
||||
// Return nil to indicate successful cancellation of all running and waiting jobs.
|
||||
return cancelledJobs, nil
|
||||
}
|
||||
|
||||
// InsertRun inserts a run
|
||||
// The title will be cut off at 255 characters if it's longer than 255 characters.
|
||||
func InsertRun(ctx context.Context, run *ActionRun, jobs []*jobparser.SingleWorkflow) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
index, err := db.GetNextResourceIndex(ctx, "action_run_index", run.RepoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
run.Index = index
|
||||
run.Title = util.EllipsisDisplayString(run.Title, 255)
|
||||
|
||||
if err := db.Insert(ctx, run); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if run.Repo == nil {
|
||||
repo, err := repo_model.GetRepositoryByID(ctx, run.RepoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
run.Repo = repo
|
||||
}
|
||||
|
||||
if err := updateRepoRunsNumbers(ctx, run.Repo); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runJobs := make([]*ActionRunJob, 0, len(jobs))
|
||||
var hasWaiting bool
|
||||
for _, v := range jobs {
|
||||
id, job := v.Job()
|
||||
needs := job.Needs()
|
||||
if err := v.SetJob(id, job.EraseNeeds()); err != nil {
|
||||
return err
|
||||
}
|
||||
payload, _ := v.Marshal()
|
||||
status := StatusWaiting
|
||||
if len(needs) > 0 || run.NeedApproval {
|
||||
status = StatusBlocked
|
||||
} else {
|
||||
hasWaiting = true
|
||||
}
|
||||
job.Name = util.EllipsisDisplayString(job.Name, 255)
|
||||
runJobs = append(runJobs, &ActionRunJob{
|
||||
RunID: run.ID,
|
||||
RepoID: run.RepoID,
|
||||
OwnerID: run.OwnerID,
|
||||
CommitSHA: run.CommitSHA,
|
||||
IsForkPullRequest: run.IsForkPullRequest,
|
||||
Name: job.Name,
|
||||
WorkflowPayload: payload,
|
||||
JobID: id,
|
||||
Needs: needs,
|
||||
RunsOn: job.RunsOn(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
if err := db.Insert(ctx, runJobs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if there is a job in the waiting status, increase tasks version.
|
||||
if hasWaiting {
|
||||
if err := IncreaseTaskVersion(ctx, run.OwnerID, run.RepoID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func GetRunByRepoAndID(ctx context.Context, repoID, runID int64) (*ActionRun, error) {
|
||||
var run ActionRun
|
||||
has, err := db.GetEngine(ctx).Where("id=? AND repo_id=?", runID, repoID).Get(&run)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("run with id %d: %w", runID, util.ErrNotExist)
|
||||
}
|
||||
|
||||
return &run, nil
|
||||
}
|
||||
|
||||
func GetRunByIndex(ctx context.Context, repoID, index int64) (*ActionRun, error) {
|
||||
run := &ActionRun{
|
||||
RepoID: repoID,
|
||||
Index: index,
|
||||
}
|
||||
has, err := db.GetEngine(ctx).Get(run)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("run with index %d %d: %w", repoID, index, util.ErrNotExist)
|
||||
}
|
||||
|
||||
return run, nil
|
||||
}
|
||||
|
||||
func GetLatestRun(ctx context.Context, repoID int64) (*ActionRun, error) {
|
||||
run := &ActionRun{
|
||||
RepoID: repoID,
|
||||
}
|
||||
has, err := db.GetEngine(ctx).Where("repo_id=?", repoID).Desc("index").Get(run)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("latest run with repo_id %d: %w", repoID, util.ErrNotExist)
|
||||
}
|
||||
return run, nil
|
||||
}
|
||||
|
||||
func GetWorkflowLatestRun(ctx context.Context, repoID int64, workflowFile, branch, event string) (*ActionRun, error) {
|
||||
var run ActionRun
|
||||
q := db.GetEngine(ctx).Where("repo_id=?", repoID).
|
||||
And("ref = ?", branch).
|
||||
And("workflow_id = ?", workflowFile)
|
||||
if event != "" {
|
||||
q.And("event = ?", event)
|
||||
}
|
||||
has, err := q.Desc("id").Get(&run)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, util.NewNotExistErrorf("run with repo_id %d, ref %s, workflow_id %s", repoID, branch, workflowFile)
|
||||
}
|
||||
return &run, nil
|
||||
}
|
||||
|
||||
// UpdateRun updates a run.
|
||||
// It requires the inputted run has Version set.
|
||||
// It will return error if the version is not matched (it means the run has been changed after loaded).
|
||||
func UpdateRun(ctx context.Context, run *ActionRun, cols ...string) error {
|
||||
sess := db.GetEngine(ctx).ID(run.ID)
|
||||
if len(cols) > 0 {
|
||||
sess.Cols(cols...)
|
||||
}
|
||||
run.Title = util.EllipsisDisplayString(run.Title, 255)
|
||||
affected, err := sess.Update(run)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return errors.New("run has changed")
|
||||
// It's impossible that the run is not found, since Gitea never deletes runs.
|
||||
}
|
||||
|
||||
if run.Status != 0 || slices.Contains(cols, "status") {
|
||||
if run.RepoID == 0 {
|
||||
setting.PanicInDevOrTesting("RepoID should not be 0")
|
||||
}
|
||||
if err = run.LoadRepo(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := updateRepoRunsNumbers(ctx, run.Repo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ActionRunIndex db.ResourceIndex
|
||||
199
models/actions/run_job.go
Normal file
199
models/actions/run_job.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionRunJob represents a job of a run
|
||||
type ActionRunJob struct {
|
||||
ID int64
|
||||
RunID int64 `xorm:"index"`
|
||||
Run *ActionRun `xorm:"-"`
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
CommitSHA string `xorm:"index"`
|
||||
IsForkPullRequest bool
|
||||
Name string `xorm:"VARCHAR(255)"`
|
||||
Attempt int64
|
||||
WorkflowPayload []byte
|
||||
JobID string `xorm:"VARCHAR(255)"` // job id in workflow, not job's id
|
||||
Needs []string `xorm:"JSON TEXT"`
|
||||
RunsOn []string `xorm:"JSON TEXT"`
|
||||
TaskID int64 // the latest task of the job
|
||||
Status Status `xorm:"index"`
|
||||
Started timeutil.TimeStamp
|
||||
Stopped timeutil.TimeStamp
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated index"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionRunJob))
|
||||
}
|
||||
|
||||
func (job *ActionRunJob) Duration() time.Duration {
|
||||
return calculateDuration(job.Started, job.Stopped, job.Status)
|
||||
}
|
||||
|
||||
func (job *ActionRunJob) LoadRun(ctx context.Context) error {
|
||||
if job.Run == nil {
|
||||
run, err := GetRunByRepoAndID(ctx, job.RepoID, job.RunID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
job.Run = run
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (job *ActionRunJob) LoadRepo(ctx context.Context) error {
|
||||
if job.Repo == nil {
|
||||
repo, err := repo_model.GetRepositoryByID(ctx, job.RepoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
job.Repo = repo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAttributes load Run if not loaded
|
||||
func (job *ActionRunJob) LoadAttributes(ctx context.Context) error {
|
||||
if job == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := job.LoadRun(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return job.Run.LoadAttributes(ctx)
|
||||
}
|
||||
|
||||
func GetRunJobByID(ctx context.Context, id int64) (*ActionRunJob, error) {
|
||||
var job ActionRunJob
|
||||
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&job)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("run job with id %d: %w", id, util.ErrNotExist)
|
||||
}
|
||||
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
func GetRunJobsByRunID(ctx context.Context, runID int64) (ActionJobList, error) {
|
||||
var jobs []*ActionRunJob
|
||||
if err := db.GetEngine(ctx).Where("run_id=?", runID).OrderBy("id").Find(&jobs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
|
||||
func UpdateRunJob(ctx context.Context, job *ActionRunJob, cond builder.Cond, cols ...string) (int64, error) {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
sess := e.ID(job.ID)
|
||||
if len(cols) > 0 {
|
||||
sess.Cols(cols...)
|
||||
}
|
||||
|
||||
if cond != nil {
|
||||
sess.Where(cond)
|
||||
}
|
||||
|
||||
affected, err := sess.Update(job)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if affected == 0 || (!slices.Contains(cols, "status") && job.Status == 0) {
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
if affected != 0 && slices.Contains(cols, "status") && job.Status.IsWaiting() {
|
||||
// if the status of job changes to waiting again, increase tasks version.
|
||||
if err := IncreaseTaskVersion(ctx, job.OwnerID, job.RepoID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if job.RunID == 0 {
|
||||
var err error
|
||||
if job, err = GetRunJobByID(ctx, job.ID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Other goroutines may aggregate the status of the run and update it too.
|
||||
// So we need load the run and its jobs before updating the run.
|
||||
run, err := GetRunByRepoAndID(ctx, job.RepoID, job.RunID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
jobs, err := GetRunJobsByRunID(ctx, job.RunID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
run.Status = AggregateJobStatus(jobs)
|
||||
if run.Started.IsZero() && run.Status.IsRunning() {
|
||||
run.Started = timeutil.TimeStampNow()
|
||||
}
|
||||
if run.Stopped.IsZero() && run.Status.IsDone() {
|
||||
run.Stopped = timeutil.TimeStampNow()
|
||||
}
|
||||
if err := UpdateRun(ctx, run, "status", "started", "stopped"); err != nil {
|
||||
return 0, fmt.Errorf("update run %d: %w", run.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func AggregateJobStatus(jobs []*ActionRunJob) Status {
|
||||
allSuccessOrSkipped := len(jobs) != 0
|
||||
allSkipped := len(jobs) != 0
|
||||
var hasFailure, hasCancelled, hasWaiting, hasRunning, hasBlocked bool
|
||||
for _, job := range jobs {
|
||||
allSuccessOrSkipped = allSuccessOrSkipped && (job.Status == StatusSuccess || job.Status == StatusSkipped)
|
||||
allSkipped = allSkipped && job.Status == StatusSkipped
|
||||
hasFailure = hasFailure || job.Status == StatusFailure
|
||||
hasCancelled = hasCancelled || job.Status == StatusCancelled
|
||||
hasWaiting = hasWaiting || job.Status == StatusWaiting
|
||||
hasRunning = hasRunning || job.Status == StatusRunning
|
||||
hasBlocked = hasBlocked || job.Status == StatusBlocked
|
||||
}
|
||||
switch {
|
||||
case allSkipped:
|
||||
return StatusSkipped
|
||||
case allSuccessOrSkipped:
|
||||
return StatusSuccess
|
||||
case hasCancelled:
|
||||
return StatusCancelled
|
||||
case hasRunning:
|
||||
return StatusRunning
|
||||
case hasWaiting:
|
||||
return StatusWaiting
|
||||
case hasFailure:
|
||||
return StatusFailure
|
||||
case hasBlocked:
|
||||
return StatusBlocked
|
||||
default:
|
||||
return StatusUnknown // it shouldn't happen
|
||||
}
|
||||
}
|
||||
110
models/actions/run_job_list.go
Normal file
110
models/actions/run_job_list.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type ActionJobList []*ActionRunJob
|
||||
|
||||
func (jobs ActionJobList) GetRunIDs() []int64 {
|
||||
return container.FilterSlice(jobs, func(j *ActionRunJob) (int64, bool) {
|
||||
return j.RunID, j.RunID != 0
|
||||
})
|
||||
}
|
||||
|
||||
func (jobs ActionJobList) LoadRepos(ctx context.Context) error {
|
||||
repoIDs := container.FilterSlice(jobs, func(j *ActionRunJob) (int64, bool) {
|
||||
return j.RepoID, j.RepoID != 0 && j.Repo == nil
|
||||
})
|
||||
if len(repoIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repos := make(map[int64]*repo_model.Repository, len(repoIDs))
|
||||
if err := db.GetEngine(ctx).In("id", repoIDs).Find(&repos); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, j := range jobs {
|
||||
if j.RepoID > 0 && j.Repo == nil {
|
||||
j.Repo = repos[j.RepoID]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
|
||||
if withRepo {
|
||||
if err := jobs.LoadRepos(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
runIDs := jobs.GetRunIDs()
|
||||
runs := make(map[int64]*ActionRun, len(runIDs))
|
||||
if err := db.GetEngine(ctx).In("id", runIDs).Find(&runs); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, j := range jobs {
|
||||
if j.RunID > 0 && j.Run == nil {
|
||||
j.Run = runs[j.RunID]
|
||||
j.Run.Repo = j.Repo
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (jobs ActionJobList) LoadAttributes(ctx context.Context, withRepo bool) error {
|
||||
return jobs.LoadRuns(ctx, withRepo)
|
||||
}
|
||||
|
||||
type FindRunJobOptions struct {
|
||||
db.ListOptions
|
||||
RunID int64
|
||||
RepoID int64
|
||||
OwnerID int64
|
||||
CommitSHA string
|
||||
Statuses []Status
|
||||
UpdatedBefore timeutil.TimeStamp
|
||||
}
|
||||
|
||||
func (opts FindRunJobOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RunID > 0 {
|
||||
cond = cond.And(builder.Eq{"`action_run_job`.run_id": opts.RunID})
|
||||
}
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"`action_run_job`.repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.CommitSHA != "" {
|
||||
cond = cond.And(builder.Eq{"`action_run_job`.commit_sha": opts.CommitSHA})
|
||||
}
|
||||
if len(opts.Statuses) > 0 {
|
||||
cond = cond.And(builder.In("`action_run_job`.status", opts.Statuses))
|
||||
}
|
||||
if opts.UpdatedBefore > 0 {
|
||||
cond = cond.And(builder.Lt{"`action_run_job`.updated": opts.UpdatedBefore})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindRunJobOptions) ToJoins() []db.JoinFunc {
|
||||
if opts.OwnerID > 0 {
|
||||
return []db.JoinFunc{
|
||||
func(sess db.Engine) error {
|
||||
sess.Join("INNER", "repository", "repository.id = repo_id AND repository.owner_id = ?", opts.OwnerID)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
85
models/actions/run_job_status_test.go
Normal file
85
models/actions/run_job_status_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAggregateJobStatus(t *testing.T) {
|
||||
testStatuses := func(expected Status, statuses []Status) {
|
||||
t.Helper()
|
||||
var jobs []*ActionRunJob
|
||||
for _, v := range statuses {
|
||||
jobs = append(jobs, &ActionRunJob{Status: v})
|
||||
}
|
||||
actual := AggregateJobStatus(jobs)
|
||||
if !assert.Equal(t, expected, actual) {
|
||||
var statusStrings []string
|
||||
for _, s := range statuses {
|
||||
statusStrings = append(statusStrings, s.String())
|
||||
}
|
||||
t.Errorf("AggregateJobStatus(%v) = %v, want %v", statusStrings, statusNames[actual], statusNames[expected])
|
||||
}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
statuses []Status
|
||||
expected Status
|
||||
}{
|
||||
// unknown cases, maybe it shouldn't happen in real world
|
||||
{[]Status{}, StatusUnknown},
|
||||
{[]Status{StatusUnknown, StatusSuccess}, StatusUnknown},
|
||||
{[]Status{StatusUnknown, StatusSkipped}, StatusUnknown},
|
||||
{[]Status{StatusUnknown, StatusFailure}, StatusFailure},
|
||||
{[]Status{StatusUnknown, StatusCancelled}, StatusCancelled},
|
||||
{[]Status{StatusUnknown, StatusWaiting}, StatusWaiting},
|
||||
{[]Status{StatusUnknown, StatusRunning}, StatusRunning},
|
||||
{[]Status{StatusUnknown, StatusBlocked}, StatusBlocked},
|
||||
|
||||
// success with other status
|
||||
{[]Status{StatusSuccess}, StatusSuccess},
|
||||
{[]Status{StatusSuccess, StatusSkipped}, StatusSuccess}, // skipped doesn't affect success
|
||||
{[]Status{StatusSuccess, StatusFailure}, StatusFailure},
|
||||
{[]Status{StatusSuccess, StatusCancelled}, StatusCancelled},
|
||||
{[]Status{StatusSuccess, StatusWaiting}, StatusWaiting},
|
||||
{[]Status{StatusSuccess, StatusRunning}, StatusRunning},
|
||||
{[]Status{StatusSuccess, StatusBlocked}, StatusBlocked},
|
||||
|
||||
// any cancelled, then cancelled
|
||||
{[]Status{StatusCancelled}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusSuccess}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusSkipped}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusFailure}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusWaiting}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusRunning}, StatusCancelled},
|
||||
{[]Status{StatusCancelled, StatusBlocked}, StatusCancelled},
|
||||
|
||||
// failure with other status, usually fail fast, but "running" wins to match GitHub's behavior
|
||||
// another reason that we can't make "failure" wins over "running": it would cause a weird behavior that user cannot cancel a workflow or get current running workflows correctly by filter after a job fail.
|
||||
{[]Status{StatusFailure}, StatusFailure},
|
||||
{[]Status{StatusFailure, StatusSuccess}, StatusFailure},
|
||||
{[]Status{StatusFailure, StatusSkipped}, StatusFailure},
|
||||
{[]Status{StatusFailure, StatusCancelled}, StatusCancelled},
|
||||
{[]Status{StatusFailure, StatusWaiting}, StatusWaiting},
|
||||
{[]Status{StatusFailure, StatusRunning}, StatusRunning},
|
||||
{[]Status{StatusFailure, StatusBlocked}, StatusFailure},
|
||||
|
||||
// skipped with other status
|
||||
// "all skipped" is also considered as "mergeable" by "services/actions.toCommitStatus", the same as GitHub
|
||||
{[]Status{StatusSkipped}, StatusSkipped},
|
||||
{[]Status{StatusSkipped, StatusSuccess}, StatusSuccess},
|
||||
{[]Status{StatusSkipped, StatusFailure}, StatusFailure},
|
||||
{[]Status{StatusSkipped, StatusCancelled}, StatusCancelled},
|
||||
{[]Status{StatusSkipped, StatusWaiting}, StatusWaiting},
|
||||
{[]Status{StatusSkipped, StatusRunning}, StatusRunning},
|
||||
{[]Status{StatusSkipped, StatusBlocked}, StatusBlocked},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
testStatuses(c.expected, c.statuses)
|
||||
}
|
||||
}
|
||||
150
models/actions/run_list.go
Normal file
150
models/actions/run_list.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/translation"
|
||||
webhook_module "code.gitea.io/gitea/modules/webhook"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type RunList []*ActionRun
|
||||
|
||||
// GetUserIDs returns a slice of user's id
|
||||
func (runs RunList) GetUserIDs() []int64 {
|
||||
return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
|
||||
return run.TriggerUserID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (runs RunList) GetRepoIDs() []int64 {
|
||||
return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
|
||||
return run.RepoID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (runs RunList) LoadTriggerUser(ctx context.Context) error {
|
||||
userIDs := runs.GetUserIDs()
|
||||
users := make(map[int64]*user_model.User, len(userIDs))
|
||||
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, run := range runs {
|
||||
if run.TriggerUserID == user_model.ActionsUserID {
|
||||
run.TriggerUser = user_model.NewActionsUser()
|
||||
} else {
|
||||
run.TriggerUser = users[run.TriggerUserID]
|
||||
if run.TriggerUser == nil {
|
||||
run.TriggerUser = user_model.NewGhostUser()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (runs RunList) LoadRepos(ctx context.Context) error {
|
||||
repoIDs := runs.GetRepoIDs()
|
||||
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, run := range runs {
|
||||
run.Repo = repos[run.RepoID]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FindRunOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
OwnerID int64
|
||||
WorkflowID string
|
||||
Ref string // the commit/tag/… that caused this workflow
|
||||
TriggerUserID int64
|
||||
TriggerEvent webhook_module.HookEventType
|
||||
Approved bool // not util.OptionalBool, it works only when it's true
|
||||
Status []Status
|
||||
CommitSHA string
|
||||
}
|
||||
|
||||
func (opts FindRunOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"`action_run`.repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.WorkflowID != "" {
|
||||
cond = cond.And(builder.Eq{"`action_run`.workflow_id": opts.WorkflowID})
|
||||
}
|
||||
if opts.TriggerUserID > 0 {
|
||||
cond = cond.And(builder.Eq{"`action_run`.trigger_user_id": opts.TriggerUserID})
|
||||
}
|
||||
if opts.Approved {
|
||||
cond = cond.And(builder.Gt{"`action_run`.approved_by": 0})
|
||||
}
|
||||
if len(opts.Status) > 0 {
|
||||
cond = cond.And(builder.In("`action_run`.status", opts.Status))
|
||||
}
|
||||
if opts.Ref != "" {
|
||||
cond = cond.And(builder.Eq{"`action_run`.ref": opts.Ref})
|
||||
}
|
||||
if opts.TriggerEvent != "" {
|
||||
cond = cond.And(builder.Eq{"`action_run`.trigger_event": opts.TriggerEvent})
|
||||
}
|
||||
if opts.CommitSHA != "" {
|
||||
cond = cond.And(builder.Eq{"`action_run`.commit_sha": opts.CommitSHA})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindRunOptions) ToJoins() []db.JoinFunc {
|
||||
if opts.OwnerID > 0 {
|
||||
return []db.JoinFunc{func(sess db.Engine) error {
|
||||
sess.Join("INNER", "repository", "repository.id = repo_id AND repository.owner_id = ?", opts.OwnerID)
|
||||
return nil
|
||||
}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opts FindRunOptions) ToOrders() string {
|
||||
return "`action_run`.`id` DESC"
|
||||
}
|
||||
|
||||
type StatusInfo struct {
|
||||
Status int
|
||||
DisplayedStatus string
|
||||
}
|
||||
|
||||
// GetStatusInfoList returns a slice of StatusInfo
|
||||
func GetStatusInfoList(ctx context.Context, lang translation.Locale) []StatusInfo {
|
||||
// same as those in aggregateJobStatus
|
||||
allStatus := []Status{StatusSuccess, StatusFailure, StatusWaiting, StatusRunning}
|
||||
statusInfoList := make([]StatusInfo, 0, 4)
|
||||
for _, s := range allStatus {
|
||||
statusInfoList = append(statusInfoList, StatusInfo{
|
||||
Status: int(s),
|
||||
DisplayedStatus: s.LocaleString(lang),
|
||||
})
|
||||
}
|
||||
return statusInfoList
|
||||
}
|
||||
|
||||
// GetActors returns a slice of Actors
|
||||
func GetActors(ctx context.Context, repoID int64) ([]*user_model.User, error) {
|
||||
actors := make([]*user_model.User, 0, 10)
|
||||
|
||||
return actors, db.GetEngine(ctx).Where(builder.In("id", builder.Select("`action_run`.trigger_user_id").From("`action_run`").
|
||||
GroupBy("`action_run`.trigger_user_id").
|
||||
Where(builder.Eq{"`action_run`.repo_id": repoID}))).
|
||||
Cols("id", "name", "full_name", "avatar", "avatar_email", "use_custom_avatar").
|
||||
OrderBy(user_model.GetOrderByName()).
|
||||
Find(&actors)
|
||||
}
|
||||
35
models/actions/run_test.go
Normal file
35
models/actions/run_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2025 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpdateRepoRunsNumbers(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
// update the number to a wrong one, the original is 3
|
||||
_, err := db.GetEngine(t.Context()).ID(4).Cols("num_closed_action_runs").Update(&repo_model.Repository{
|
||||
NumClosedActionRuns: 2,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4})
|
||||
assert.Equal(t, 4, repo.NumActionRuns)
|
||||
assert.Equal(t, 2, repo.NumClosedActionRuns)
|
||||
|
||||
// now update will correct them, only num_actionr_runs and num_closed_action_runs should be updated
|
||||
err = updateRepoRunsNumbers(t.Context(), repo)
|
||||
assert.NoError(t, err)
|
||||
repo = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4})
|
||||
assert.Equal(t, 5, repo.NumActionRuns)
|
||||
assert.Equal(t, 3, repo.NumClosedActionRuns)
|
||||
}
|
||||
388
models/actions/runner.go
Normal file
388
models/actions/runner.go
Normal file
@@ -0,0 +1,388 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/shared/types"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/optional"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/translation"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionRunner represents runner machines
|
||||
//
|
||||
// It can be:
|
||||
// 1. global runner, OwnerID is 0 and RepoID is 0
|
||||
// 2. org/user level runner, OwnerID is org/user ID and RepoID is 0
|
||||
// 3. repo level runner, OwnerID is 0 and RepoID is repo ID
|
||||
//
|
||||
// Please note that it's not acceptable to have both OwnerID and RepoID to be non-zero,
|
||||
// or it will be complicated to find runners belonging to a specific owner.
|
||||
// For example, conditions like `OwnerID = 1` will also return runner {OwnerID: 1, RepoID: 1},
|
||||
// but it's a repo level runner, not an org/user level runner.
|
||||
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level runners.
|
||||
type ActionRunner struct {
|
||||
ID int64
|
||||
UUID string `xorm:"CHAR(36) UNIQUE"`
|
||||
Name string `xorm:"VARCHAR(255)"`
|
||||
Version string `xorm:"VARCHAR(64)"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
Owner *user_model.User `xorm:"-"`
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
Description string `xorm:"TEXT"`
|
||||
Base int // 0 native 1 docker 2 virtual machine
|
||||
RepoRange string // glob match which repositories could use this runner
|
||||
|
||||
Token string `xorm:"-"`
|
||||
TokenHash string `xorm:"UNIQUE"` // sha256 of token
|
||||
TokenSalt string
|
||||
// TokenLastEight string `xorm:"token_last_eight"` // it's unnecessary because we don't find runners by token
|
||||
|
||||
LastOnline timeutil.TimeStamp `xorm:"index"`
|
||||
LastActive timeutil.TimeStamp `xorm:"index"`
|
||||
|
||||
// Store labels defined in state file (default: .runner file) of `act_runner`
|
||||
AgentLabels []string `xorm:"TEXT"`
|
||||
// Store if this is a runner that only ever get one single job assigned
|
||||
Ephemeral bool `xorm:"ephemeral NOT NULL DEFAULT false"`
|
||||
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
Deleted timeutil.TimeStamp `xorm:"deleted"`
|
||||
}
|
||||
|
||||
const (
|
||||
RunnerOfflineTime = time.Minute
|
||||
RunnerIdleTime = 10 * time.Second
|
||||
)
|
||||
|
||||
// BelongsToOwnerName before calling, should guarantee that all attributes are loaded
|
||||
func (r *ActionRunner) BelongsToOwnerName() string {
|
||||
if r.RepoID != 0 {
|
||||
return r.Repo.FullName()
|
||||
}
|
||||
if r.OwnerID != 0 {
|
||||
return r.Owner.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *ActionRunner) BelongsToOwnerType() types.OwnerType {
|
||||
if r.RepoID != 0 {
|
||||
return types.OwnerTypeRepository
|
||||
}
|
||||
if r.OwnerID != 0 {
|
||||
switch r.Owner.Type {
|
||||
case user_model.UserTypeOrganization:
|
||||
return types.OwnerTypeOrganization
|
||||
case user_model.UserTypeIndividual:
|
||||
return types.OwnerTypeIndividual
|
||||
}
|
||||
}
|
||||
return types.OwnerTypeSystemGlobal
|
||||
}
|
||||
|
||||
// if the logic here changed, you should also modify FindRunnerOptions.ToCond
|
||||
func (r *ActionRunner) Status() runnerv1.RunnerStatus {
|
||||
if time.Since(r.LastOnline.AsTime()) > RunnerOfflineTime {
|
||||
return runnerv1.RunnerStatus_RUNNER_STATUS_OFFLINE
|
||||
}
|
||||
if time.Since(r.LastActive.AsTime()) > RunnerIdleTime {
|
||||
return runnerv1.RunnerStatus_RUNNER_STATUS_IDLE
|
||||
}
|
||||
return runnerv1.RunnerStatus_RUNNER_STATUS_ACTIVE
|
||||
}
|
||||
|
||||
func (r *ActionRunner) StatusName() string {
|
||||
return strings.ToLower(strings.TrimPrefix(r.Status().String(), "RUNNER_STATUS_"))
|
||||
}
|
||||
|
||||
func (r *ActionRunner) StatusLocaleName(lang translation.Locale) string {
|
||||
return lang.TrString("actions.runners.status." + r.StatusName())
|
||||
}
|
||||
|
||||
func (r *ActionRunner) IsOnline() bool {
|
||||
status := r.Status()
|
||||
if status == runnerv1.RunnerStatus_RUNNER_STATUS_IDLE || status == runnerv1.RunnerStatus_RUNNER_STATUS_ACTIVE {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EditableInContext checks if the runner is editable by the "context" owner/repo
|
||||
// ownerID == 0 and repoID == 0 means "admin" context, any runner including global runners could be edited
|
||||
// ownerID == 0 and repoID != 0 means "repo" context, any runner belonging to the given repo could be edited
|
||||
// ownerID != 0 and repoID == 0 means "owner(org/user)" context, any runner belonging to the given user/org could be edited
|
||||
// ownerID != 0 and repoID != 0 means "owner" OR "repo" context, legacy behavior, but we should forbid using it
|
||||
func (r *ActionRunner) EditableInContext(ownerID, repoID int64) bool {
|
||||
if ownerID != 0 && repoID != 0 {
|
||||
setting.PanicInDevOrTesting("ownerID and repoID should not be both set")
|
||||
}
|
||||
if ownerID == 0 && repoID == 0 {
|
||||
return true
|
||||
}
|
||||
if ownerID > 0 && r.OwnerID == ownerID {
|
||||
return true
|
||||
}
|
||||
return repoID > 0 && r.RepoID == repoID
|
||||
}
|
||||
|
||||
// LoadAttributes loads the attributes of the runner
|
||||
func (r *ActionRunner) LoadAttributes(ctx context.Context) error {
|
||||
if r.OwnerID > 0 {
|
||||
var user user_model.User
|
||||
has, err := db.GetEngine(ctx).ID(r.OwnerID).Get(&user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if has {
|
||||
r.Owner = &user
|
||||
}
|
||||
}
|
||||
if r.RepoID > 0 {
|
||||
var repo repo_model.Repository
|
||||
has, err := db.GetEngine(ctx).ID(r.RepoID).Get(&repo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if has {
|
||||
r.Repo = &repo
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ActionRunner) GenerateToken() (err error) {
|
||||
r.Token, r.TokenSalt, r.TokenHash, _, err = generateSaltedToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(&ActionRunner{})
|
||||
}
|
||||
|
||||
// FindRunnerOptions
|
||||
// ownerID == 0 and repoID == 0 means any runner including global runners
|
||||
// repoID != 0 and WithAvailable == false means any runner for the given repo
|
||||
// repoID != 0 and WithAvailable == true means any runner for the given repo, parent user/org, and global runners
|
||||
// ownerID != 0 and repoID == 0 and WithAvailable == false means any runner for the given user/org
|
||||
// ownerID != 0 and repoID == 0 and WithAvailable == true means any runner for the given user/org and global runners
|
||||
type FindRunnerOptions struct {
|
||||
db.ListOptions
|
||||
IDs []int64
|
||||
RepoID int64
|
||||
OwnerID int64 // it will be ignored if RepoID is set
|
||||
Sort string
|
||||
Filter string
|
||||
IsOnline optional.Option[bool]
|
||||
WithAvailable bool // not only runners belong to, but also runners can be used
|
||||
}
|
||||
|
||||
func (opts FindRunnerOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
|
||||
if len(opts.IDs) > 0 {
|
||||
if len(opts.IDs) == 1 {
|
||||
cond = cond.And(builder.Eq{"id": opts.IDs[0]})
|
||||
} else {
|
||||
cond = cond.And(builder.In("id", opts.IDs))
|
||||
}
|
||||
}
|
||||
|
||||
if opts.RepoID > 0 {
|
||||
c := builder.NewCond().And(builder.Eq{"repo_id": opts.RepoID})
|
||||
if opts.WithAvailable {
|
||||
c = c.Or(builder.Eq{"owner_id": builder.Select("owner_id").From("repository").Where(builder.Eq{"id": opts.RepoID})})
|
||||
c = c.Or(builder.Eq{"repo_id": 0, "owner_id": 0})
|
||||
}
|
||||
cond = cond.And(c)
|
||||
} else if opts.OwnerID > 0 { // OwnerID is ignored if RepoID is set
|
||||
c := builder.NewCond().And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
if opts.WithAvailable {
|
||||
c = c.Or(builder.Eq{"repo_id": 0, "owner_id": 0})
|
||||
}
|
||||
cond = cond.And(c)
|
||||
}
|
||||
|
||||
if opts.Filter != "" {
|
||||
cond = cond.And(builder.Like{"name", opts.Filter})
|
||||
}
|
||||
|
||||
if opts.IsOnline.Has() {
|
||||
if opts.IsOnline.Value() {
|
||||
cond = cond.And(builder.Gt{"last_online": time.Now().Add(-RunnerOfflineTime).Unix()})
|
||||
} else {
|
||||
cond = cond.And(builder.Lte{"last_online": time.Now().Add(-RunnerOfflineTime).Unix()})
|
||||
}
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindRunnerOptions) ToOrders() string {
|
||||
switch opts.Sort {
|
||||
case "online":
|
||||
return "last_online DESC"
|
||||
case "offline":
|
||||
return "last_online ASC"
|
||||
case "alphabetically":
|
||||
return "name ASC"
|
||||
case "reversealphabetically":
|
||||
return "name DESC"
|
||||
case "newest":
|
||||
return "id DESC"
|
||||
case "oldest":
|
||||
return "id ASC"
|
||||
}
|
||||
return "last_online DESC"
|
||||
}
|
||||
|
||||
// GetRunnerByUUID returns a runner via uuid
|
||||
func GetRunnerByUUID(ctx context.Context, uuid string) (*ActionRunner, error) {
|
||||
var runner ActionRunner
|
||||
has, err := db.GetEngine(ctx).Where("uuid=?", uuid).Get(&runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("runner with uuid %s: %w", uuid, util.ErrNotExist)
|
||||
}
|
||||
return &runner, nil
|
||||
}
|
||||
|
||||
// GetRunnerByID returns a runner via id
|
||||
func GetRunnerByID(ctx context.Context, id int64) (*ActionRunner, error) {
|
||||
var runner ActionRunner
|
||||
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&runner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("runner with id %d: %w", id, util.ErrNotExist)
|
||||
}
|
||||
return &runner, nil
|
||||
}
|
||||
|
||||
// UpdateRunner updates runner's information.
|
||||
func UpdateRunner(ctx context.Context, r *ActionRunner, cols ...string) error {
|
||||
e := db.GetEngine(ctx)
|
||||
r.Name = util.EllipsisDisplayString(r.Name, 255)
|
||||
var err error
|
||||
if len(cols) == 0 {
|
||||
_, err = e.ID(r.ID).AllCols().Update(r)
|
||||
} else {
|
||||
_, err = e.ID(r.ID).Cols(cols...).Update(r)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteRunner deletes a runner by given ID.
|
||||
func DeleteRunner(ctx context.Context, id int64) error {
|
||||
if _, err := GetRunnerByID(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := db.DeleteByID[ActionRunner](ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteEphemeralRunner deletes a ephemeral runner by given ID.
|
||||
func DeleteEphemeralRunner(ctx context.Context, id int64) error {
|
||||
runner, err := GetRunnerByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, util.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !runner.Ephemeral {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = db.DeleteByID[ActionRunner](ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateRunner creates new runner.
|
||||
func CreateRunner(ctx context.Context, t *ActionRunner) error {
|
||||
if t.OwnerID != 0 && t.RepoID != 0 {
|
||||
// It's trying to create a runner that belongs to a repository, but OwnerID has been set accidentally.
|
||||
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
|
||||
t.OwnerID = 0
|
||||
}
|
||||
t.Name = util.EllipsisDisplayString(t.Name, 255)
|
||||
return db.Insert(ctx, t)
|
||||
}
|
||||
|
||||
func CountRunnersWithoutBelongingOwner(ctx context.Context) (int64, error) {
|
||||
// Only affect action runners were a owner ID is set, as actions runners
|
||||
// could also be created on a repository.
|
||||
return db.GetEngine(ctx).Table("action_runner").
|
||||
Join("LEFT", "`user`", "`action_runner`.owner_id = `user`.id").
|
||||
Where("`action_runner`.owner_id != ?", 0).
|
||||
And(builder.IsNull{"`user`.id"}).
|
||||
Count(new(ActionRunner))
|
||||
}
|
||||
|
||||
func FixRunnersWithoutBelongingOwner(ctx context.Context) (int64, error) {
|
||||
subQuery := builder.Select("`action_runner`.id").
|
||||
From("`action_runner`").
|
||||
Join("LEFT", "`user`", "`action_runner`.owner_id = `user`.id").
|
||||
Where(builder.Neq{"`action_runner`.owner_id": 0}).
|
||||
And(builder.IsNull{"`user`.id"})
|
||||
b := builder.Delete(builder.In("id", subQuery)).From("`action_runner`")
|
||||
res, err := db.GetEngine(ctx).Exec(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func CountRunnersWithoutBelongingRepo(ctx context.Context) (int64, error) {
|
||||
return db.GetEngine(ctx).Table("action_runner").
|
||||
Join("LEFT", "`repository`", "`action_runner`.repo_id = `repository`.id").
|
||||
Where("`action_runner`.repo_id != ?", 0).
|
||||
And(builder.IsNull{"`repository`.id"}).
|
||||
Count(new(ActionRunner))
|
||||
}
|
||||
|
||||
func FixRunnersWithoutBelongingRepo(ctx context.Context) (int64, error) {
|
||||
subQuery := builder.Select("`action_runner`.id").
|
||||
From("`action_runner`").
|
||||
Join("LEFT", "`repository`", "`action_runner`.repo_id = `repository`.id").
|
||||
Where(builder.Neq{"`action_runner`.repo_id": 0}).
|
||||
And(builder.IsNull{"`repository`.id"})
|
||||
b := builder.Delete(builder.In("id", subQuery)).From("`action_runner`")
|
||||
res, err := db.GetEngine(ctx).Exec(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func CountWrongRepoLevelRunners(ctx context.Context) (int64, error) {
|
||||
var result int64
|
||||
_, err := db.GetEngine(ctx).SQL("SELECT count(`id`) FROM `action_runner` WHERE `repo_id` > 0 AND `owner_id` > 0").Get(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func UpdateWrongRepoLevelRunners(ctx context.Context) (int64, error) {
|
||||
result, err := db.GetEngine(ctx).Exec("UPDATE `action_runner` SET `owner_id` = 0 WHERE `repo_id` > 0 AND `owner_id` > 0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
65
models/actions/runner_list.go
Normal file
65
models/actions/runner_list.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
)
|
||||
|
||||
type RunnerList []*ActionRunner
|
||||
|
||||
// GetUserIDs returns a slice of user's id
|
||||
func (runners RunnerList) GetUserIDs() []int64 {
|
||||
return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) {
|
||||
return runner.OwnerID, runner.OwnerID != 0
|
||||
})
|
||||
}
|
||||
|
||||
func (runners RunnerList) LoadOwners(ctx context.Context) error {
|
||||
userIDs := runners.GetUserIDs()
|
||||
users := make(map[int64]*user_model.User, len(userIDs))
|
||||
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, runner := range runners {
|
||||
if runner.OwnerID > 0 && runner.Owner == nil {
|
||||
runner.Owner = users[runner.OwnerID]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (runners RunnerList) getRepoIDs() []int64 {
|
||||
return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) {
|
||||
return runner.RepoID, runner.RepoID > 0
|
||||
})
|
||||
}
|
||||
|
||||
func (runners RunnerList) LoadRepos(ctx context.Context) error {
|
||||
repoIDs := runners.getRepoIDs()
|
||||
repos := make(map[int64]*repo_model.Repository, len(repoIDs))
|
||||
if err := db.GetEngine(ctx).In("id", repoIDs).Find(&repos); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, runner := range runners {
|
||||
if runner.RepoID > 0 && runner.Repo == nil {
|
||||
runner.Repo = repos[runner.RepoID]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (runners RunnerList) LoadAttributes(ctx context.Context) error {
|
||||
if err := runners.LoadOwners(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return runners.LoadRepos(ctx)
|
||||
}
|
||||
124
models/actions/runner_token.go
Normal file
124
models/actions/runner_token.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// ActionRunnerToken represents runner tokens
|
||||
//
|
||||
// It can be:
|
||||
// 1. global token, OwnerID is 0 and RepoID is 0
|
||||
// 2. org/user level token, OwnerID is org/user ID and RepoID is 0
|
||||
// 3. repo level token, OwnerID is 0 and RepoID is repo ID
|
||||
//
|
||||
// Please note that it's not acceptable to have both OwnerID and RepoID to be non-zero,
|
||||
// or it will be complicated to find tokens belonging to a specific owner.
|
||||
// For example, conditions like `OwnerID = 1` will also return token {OwnerID: 1, RepoID: 1},
|
||||
// but it's a repo level token, not an org/user level token.
|
||||
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level tokens.
|
||||
type ActionRunnerToken struct {
|
||||
ID int64
|
||||
Token string `xorm:"UNIQUE"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
Owner *user_model.User `xorm:"-"`
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
IsActive bool // true means it can be used
|
||||
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
Deleted timeutil.TimeStamp `xorm:"deleted"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionRunnerToken))
|
||||
}
|
||||
|
||||
// GetRunnerToken returns a action runner via token
|
||||
func GetRunnerToken(ctx context.Context, token string) (*ActionRunnerToken, error) {
|
||||
var runnerToken ActionRunnerToken
|
||||
has, err := db.GetEngine(ctx).Where("token=?", token).Get(&runnerToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf(`runner token "%s...": %w`, util.TruncateRunes(token, 3), util.ErrNotExist)
|
||||
}
|
||||
return &runnerToken, nil
|
||||
}
|
||||
|
||||
// UpdateRunnerToken updates runner token information.
|
||||
func UpdateRunnerToken(ctx context.Context, r *ActionRunnerToken, cols ...string) (err error) {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
if len(cols) == 0 {
|
||||
_, err = e.ID(r.ID).AllCols().Update(r)
|
||||
} else {
|
||||
_, err = e.ID(r.ID).Cols(cols...).Update(r)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// NewRunnerTokenWithValue creates a new active runner token and invalidate all old tokens
|
||||
// ownerID will be ignored and treated as 0 if repoID is non-zero.
|
||||
func NewRunnerTokenWithValue(ctx context.Context, ownerID, repoID int64, token string) (*ActionRunnerToken, error) {
|
||||
if ownerID != 0 && repoID != 0 {
|
||||
// It's trying to create a runner token that belongs to a repository, but OwnerID has been set accidentally.
|
||||
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
|
||||
ownerID = 0
|
||||
}
|
||||
|
||||
runnerToken := &ActionRunnerToken{
|
||||
OwnerID: ownerID,
|
||||
RepoID: repoID,
|
||||
IsActive: true,
|
||||
Token: token,
|
||||
}
|
||||
|
||||
return runnerToken, db.WithTx(ctx, func(ctx context.Context) error {
|
||||
if _, err := db.GetEngine(ctx).Where("owner_id =? AND repo_id = ?", ownerID, repoID).Cols("is_active").Update(&ActionRunnerToken{
|
||||
IsActive: false,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := db.GetEngine(ctx).Insert(runnerToken)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func NewRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) {
|
||||
token, err := util.CryptoRandomString(40)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRunnerTokenWithValue(ctx, ownerID, repoID, token)
|
||||
}
|
||||
|
||||
// GetLatestRunnerToken returns the latest runner token
|
||||
func GetLatestRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) {
|
||||
if ownerID != 0 && repoID != 0 {
|
||||
// It's trying to get a runner token that belongs to a repository, but OwnerID has been set accidentally.
|
||||
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
|
||||
ownerID = 0
|
||||
}
|
||||
|
||||
var runnerToken ActionRunnerToken
|
||||
has, err := db.GetEngine(ctx).Where("owner_id=? AND repo_id=?", ownerID, repoID).
|
||||
OrderBy("id DESC").Get(&runnerToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("runner token: %w", util.ErrNotExist)
|
||||
}
|
||||
return &runnerToken, nil
|
||||
}
|
||||
39
models/actions/runner_token_test.go
Normal file
39
models/actions/runner_token_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetLatestRunnerToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := unittest.AssertExistsAndLoadBean(t, &ActionRunnerToken{ID: 3})
|
||||
expectedToken, err := GetLatestRunnerToken(t.Context(), 1, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedToken, token)
|
||||
}
|
||||
|
||||
func TestNewRunnerToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := NewRunnerToken(t.Context(), 1, 0)
|
||||
assert.NoError(t, err)
|
||||
expectedToken, err := GetLatestRunnerToken(t.Context(), 1, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedToken, token)
|
||||
}
|
||||
|
||||
func TestUpdateRunnerToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := unittest.AssertExistsAndLoadBean(t, &ActionRunnerToken{ID: 3})
|
||||
token.IsActive = true
|
||||
assert.NoError(t, UpdateRunnerToken(t.Context(), token))
|
||||
expectedToken, err := GetLatestRunnerToken(t.Context(), 1, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedToken, token)
|
||||
}
|
||||
127
models/actions/schedule.go
Normal file
127
models/actions/schedule.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
webhook_module "code.gitea.io/gitea/modules/webhook"
|
||||
)
|
||||
|
||||
// ActionSchedule represents a schedule of a workflow file
|
||||
type ActionSchedule struct {
|
||||
ID int64
|
||||
Title string
|
||||
Specs []string
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
WorkflowID string
|
||||
TriggerUserID int64
|
||||
TriggerUser *user_model.User `xorm:"-"`
|
||||
Ref string
|
||||
CommitSHA string
|
||||
Event webhook_module.HookEventType
|
||||
EventPayload string `xorm:"LONGTEXT"`
|
||||
Content []byte
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionSchedule))
|
||||
}
|
||||
|
||||
// GetSchedulesMapByIDs returns the schedules by given id slice.
|
||||
func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) {
|
||||
schedules := make(map[int64]*ActionSchedule, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return schedules, nil
|
||||
}
|
||||
return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules)
|
||||
}
|
||||
|
||||
// CreateScheduleTask creates new schedule task.
|
||||
func CreateScheduleTask(ctx context.Context, rows []*ActionSchedule) error {
|
||||
// Return early if there are no rows to insert
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
// Loop through each schedule row
|
||||
for _, row := range rows {
|
||||
row.Title = util.EllipsisDisplayString(row.Title, 255)
|
||||
// Create new schedule row
|
||||
if err := db.Insert(ctx, row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Loop through each schedule spec and create a new spec row
|
||||
now := time.Now()
|
||||
|
||||
for _, spec := range row.Specs {
|
||||
specRow := &ActionScheduleSpec{
|
||||
RepoID: row.RepoID,
|
||||
ScheduleID: row.ID,
|
||||
Spec: spec,
|
||||
}
|
||||
// Parse the spec and check for errors
|
||||
schedule, err := specRow.Parse()
|
||||
if err != nil {
|
||||
continue // skip to the next spec if there's an error
|
||||
}
|
||||
|
||||
specRow.Next = timeutil.TimeStamp(schedule.Next(now).Unix())
|
||||
|
||||
// Insert the new schedule spec row
|
||||
if err = db.Insert(ctx, specRow); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteScheduleTaskByRepo(ctx context.Context, id int64) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
if _, err := db.GetEngine(ctx).Delete(&ActionSchedule{RepoID: id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := db.GetEngine(ctx).Delete(&ActionScheduleSpec{RepoID: id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func CleanRepoScheduleTasks(ctx context.Context, repo *repo_model.Repository) ([]*ActionRunJob, error) {
|
||||
// If actions disabled when there is schedule task, this will remove the outdated schedule tasks
|
||||
// There is no other place we can do this because the app.ini will be changed manually
|
||||
if err := DeleteScheduleTaskByRepo(ctx, repo.ID); err != nil {
|
||||
return nil, fmt.Errorf("DeleteCronTaskByRepo: %v", err)
|
||||
}
|
||||
// cancel running cron jobs of this repository and delete old schedules
|
||||
jobs, err := CancelPreviousJobs(
|
||||
ctx,
|
||||
repo.ID,
|
||||
repo.DefaultBranch,
|
||||
"",
|
||||
webhook_module.HookEventSchedule,
|
||||
)
|
||||
if err != nil {
|
||||
return jobs, fmt.Errorf("CancelPreviousJobs: %v", err)
|
||||
}
|
||||
return jobs, nil
|
||||
}
|
||||
83
models/actions/schedule_list.go
Normal file
83
models/actions/schedule_list.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type ScheduleList []*ActionSchedule
|
||||
|
||||
// GetUserIDs returns a slice of user's id
|
||||
func (schedules ScheduleList) GetUserIDs() []int64 {
|
||||
return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
|
||||
return schedule.TriggerUserID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (schedules ScheduleList) GetRepoIDs() []int64 {
|
||||
return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
|
||||
return schedule.RepoID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
|
||||
userIDs := schedules.GetUserIDs()
|
||||
users := make(map[int64]*user_model.User, len(userIDs))
|
||||
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, schedule := range schedules {
|
||||
if schedule.TriggerUserID == user_model.ActionsUserID {
|
||||
schedule.TriggerUser = user_model.NewActionsUser()
|
||||
} else {
|
||||
schedule.TriggerUser = users[schedule.TriggerUserID]
|
||||
if schedule.TriggerUser == nil {
|
||||
schedule.TriggerUser = user_model.NewGhostUser()
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
|
||||
repoIDs := schedules.GetRepoIDs()
|
||||
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, schedule := range schedules {
|
||||
schedule.Repo = repos[schedule.RepoID]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FindScheduleOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
OwnerID int64
|
||||
}
|
||||
|
||||
func (opts FindScheduleOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.OwnerID > 0 {
|
||||
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
}
|
||||
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindScheduleOptions) ToOrders() string {
|
||||
return "`id` DESC"
|
||||
}
|
||||
73
models/actions/schedule_spec.go
Normal file
73
models/actions/schedule_spec.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
// ActionScheduleSpec represents a schedule spec of a workflow file
|
||||
type ActionScheduleSpec struct {
|
||||
ID int64
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
ScheduleID int64 `xorm:"index"`
|
||||
Schedule *ActionSchedule `xorm:"-"`
|
||||
|
||||
// Next time the job will run, or the zero time if Cron has not been
|
||||
// started or this entry's schedule is unsatisfiable
|
||||
Next timeutil.TimeStamp `xorm:"index"`
|
||||
// Prev is the last time this job was run, or the zero time if never.
|
||||
Prev timeutil.TimeStamp
|
||||
Spec string
|
||||
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
// Parse parses the spec and returns a cron.Schedule
|
||||
// Unlike the default cron parser, Parse uses UTC timezone as the default if none is specified.
|
||||
func (s *ActionScheduleSpec) Parse() (cron.Schedule, error) {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
|
||||
schedule, err := parser.Parse(s.Spec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the spec has specified a timezone, use it
|
||||
if strings.HasPrefix(s.Spec, "TZ=") || strings.HasPrefix(s.Spec, "CRON_TZ=") {
|
||||
return schedule, nil
|
||||
}
|
||||
|
||||
specSchedule, ok := schedule.(*cron.SpecSchedule)
|
||||
// If it's not a spec schedule, like "@every 5m", timezone is not relevant
|
||||
if !ok {
|
||||
return schedule, nil
|
||||
}
|
||||
|
||||
// Set the timezone to UTC
|
||||
specSchedule.Location = time.UTC
|
||||
return specSchedule, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionScheduleSpec))
|
||||
}
|
||||
|
||||
func UpdateScheduleSpec(ctx context.Context, spec *ActionScheduleSpec, cols ...string) error {
|
||||
sess := db.GetEngine(ctx).ID(spec.ID)
|
||||
if len(cols) > 0 {
|
||||
sess.Cols(cols...)
|
||||
}
|
||||
_, err := sess.Update(spec)
|
||||
return err
|
||||
}
|
||||
97
models/actions/schedule_spec_list.go
Normal file
97
models/actions/schedule_spec_list.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type SpecList []*ActionScheduleSpec
|
||||
|
||||
func (specs SpecList) GetScheduleIDs() []int64 {
|
||||
return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
|
||||
return spec.ScheduleID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (specs SpecList) LoadSchedules(ctx context.Context) error {
|
||||
scheduleIDs := specs.GetScheduleIDs()
|
||||
schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, spec := range specs {
|
||||
spec.Schedule = schedules[spec.ScheduleID]
|
||||
}
|
||||
|
||||
repoIDs := specs.GetRepoIDs()
|
||||
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, spec := range specs {
|
||||
spec.Repo = repos[spec.RepoID]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (specs SpecList) GetRepoIDs() []int64 {
|
||||
return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
|
||||
return spec.RepoID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (specs SpecList) LoadRepos(ctx context.Context) error {
|
||||
repoIDs := specs.GetRepoIDs()
|
||||
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, spec := range specs {
|
||||
spec.Repo = repos[spec.RepoID]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FindSpecOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
Next int64
|
||||
}
|
||||
|
||||
func (opts FindSpecOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
|
||||
}
|
||||
|
||||
if opts.Next > 0 {
|
||||
cond = cond.And(builder.Lte{"next": opts.Next})
|
||||
}
|
||||
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindSpecOptions) ToOrders() string {
|
||||
return "`id` DESC"
|
||||
}
|
||||
|
||||
func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, error) {
|
||||
specs, total, err := db.FindAndCount[ActionScheduleSpec](ctx, opts)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := SpecList(specs).LoadSchedules(ctx); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return specs, total, nil
|
||||
}
|
||||
69
models/actions/schedule_spec_test.go
Normal file
69
models/actions/schedule_spec_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/modules/test"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestActionScheduleSpec_Parse(t *testing.T) {
|
||||
// Mock the local timezone is not UTC
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
defer test.MockVariableValue(&time.Local, tz)()
|
||||
|
||||
now, err := time.Parse(time.RFC3339, "2024-07-31T15:47:55+08:00")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
want string
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "regular",
|
||||
spec: "0 10 * * *",
|
||||
want: "2024-07-31T10:00:00Z",
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
spec: "0 10 * *",
|
||||
want: "",
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "with timezone",
|
||||
spec: "TZ=America/New_York 0 10 * * *",
|
||||
want: "2024-07-31T14:00:00Z",
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "timezone irrelevant",
|
||||
spec: "@every 5m",
|
||||
want: "2024-07-31T07:52:55Z",
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &ActionScheduleSpec{
|
||||
Spec: tt.spec,
|
||||
}
|
||||
got, err := s.Parse()
|
||||
tt.wantErr(t, err)
|
||||
|
||||
if err == nil {
|
||||
assert.Equal(t, tt.want, got.Next(now).UTC().Format(time.RFC3339))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
101
models/actions/status.go
Normal file
101
models/actions/status.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"code.gitea.io/gitea/modules/translation"
|
||||
|
||||
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
|
||||
)
|
||||
|
||||
// Status represents the status of ActionRun, ActionRunJob, ActionTask, or ActionTaskStep
|
||||
type Status int
|
||||
|
||||
const (
|
||||
StatusUnknown Status = iota // 0, consistent with runnerv1.Result_RESULT_UNSPECIFIED
|
||||
StatusSuccess // 1, consistent with runnerv1.Result_RESULT_SUCCESS
|
||||
StatusFailure // 2, consistent with runnerv1.Result_RESULT_FAILURE
|
||||
StatusCancelled // 3, consistent with runnerv1.Result_RESULT_CANCELLED
|
||||
StatusSkipped // 4, consistent with runnerv1.Result_RESULT_SKIPPED
|
||||
StatusWaiting // 5, isn't a runnerv1.Result
|
||||
StatusRunning // 6, isn't a runnerv1.Result
|
||||
StatusBlocked // 7, isn't a runnerv1.Result
|
||||
)
|
||||
|
||||
var statusNames = map[Status]string{
|
||||
StatusUnknown: "unknown",
|
||||
StatusWaiting: "waiting",
|
||||
StatusRunning: "running",
|
||||
StatusSuccess: "success",
|
||||
StatusFailure: "failure",
|
||||
StatusCancelled: "cancelled",
|
||||
StatusSkipped: "skipped",
|
||||
StatusBlocked: "blocked",
|
||||
}
|
||||
|
||||
// String returns the string name of the Status
|
||||
func (s Status) String() string {
|
||||
return statusNames[s]
|
||||
}
|
||||
|
||||
// LocaleString returns the locale string name of the Status
|
||||
func (s Status) LocaleString(lang translation.Locale) string {
|
||||
return lang.TrString("actions.status." + s.String())
|
||||
}
|
||||
|
||||
// IsDone returns whether the Status is final
|
||||
func (s Status) IsDone() bool {
|
||||
return s.In(StatusSuccess, StatusFailure, StatusCancelled, StatusSkipped)
|
||||
}
|
||||
|
||||
// HasRun returns whether the Status is a result of running
|
||||
func (s Status) HasRun() bool {
|
||||
return s.In(StatusSuccess, StatusFailure)
|
||||
}
|
||||
|
||||
func (s Status) IsUnknown() bool {
|
||||
return s == StatusUnknown
|
||||
}
|
||||
|
||||
func (s Status) IsSuccess() bool {
|
||||
return s == StatusSuccess
|
||||
}
|
||||
|
||||
func (s Status) IsFailure() bool {
|
||||
return s == StatusFailure
|
||||
}
|
||||
|
||||
func (s Status) IsCancelled() bool {
|
||||
return s == StatusCancelled
|
||||
}
|
||||
|
||||
func (s Status) IsSkipped() bool {
|
||||
return s == StatusSkipped
|
||||
}
|
||||
|
||||
func (s Status) IsWaiting() bool {
|
||||
return s == StatusWaiting
|
||||
}
|
||||
|
||||
func (s Status) IsRunning() bool {
|
||||
return s == StatusRunning
|
||||
}
|
||||
|
||||
func (s Status) IsBlocked() bool {
|
||||
return s == StatusBlocked
|
||||
}
|
||||
|
||||
// In returns whether s is one of the given statuses
|
||||
func (s Status) In(statuses ...Status) bool {
|
||||
return slices.Contains(statuses, s)
|
||||
}
|
||||
|
||||
func (s Status) AsResult() runnerv1.Result {
|
||||
if s.IsDone() {
|
||||
return runnerv1.Result(s)
|
||||
}
|
||||
return runnerv1.Result_RESULT_UNSPECIFIED
|
||||
}
|
||||
526
models/actions/task.go
Normal file
526
models/actions/task.go
Normal file
@@ -0,0 +1,526 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unit"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/nektos/act/pkg/jobparser"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionTask represents a distribution of job
|
||||
type ActionTask struct {
|
||||
ID int64
|
||||
JobID int64
|
||||
Job *ActionRunJob `xorm:"-"`
|
||||
Steps []*ActionTaskStep `xorm:"-"`
|
||||
Attempt int64
|
||||
RunnerID int64 `xorm:"index"`
|
||||
Status Status `xorm:"index"`
|
||||
Started timeutil.TimeStamp `xorm:"index"`
|
||||
Stopped timeutil.TimeStamp `xorm:"index(stopped_log_expired)"`
|
||||
|
||||
RepoID int64 `xorm:"index"`
|
||||
OwnerID int64 `xorm:"index"`
|
||||
CommitSHA string `xorm:"index"`
|
||||
IsForkPullRequest bool
|
||||
|
||||
Token string `xorm:"-"`
|
||||
TokenHash string `xorm:"UNIQUE"` // sha256 of token
|
||||
TokenSalt string
|
||||
TokenLastEight string `xorm:"index token_last_eight"`
|
||||
|
||||
LogFilename string // file name of log
|
||||
LogInStorage bool // read log from database or from storage
|
||||
LogLength int64 // lines count
|
||||
LogSize int64 // blob size
|
||||
LogIndexes LogIndexes `xorm:"LONGBLOB"` // line number to offset
|
||||
LogExpired bool `xorm:"index(stopped_log_expired)"` // files that are too old will be deleted
|
||||
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated index"`
|
||||
}
|
||||
|
||||
var successfulTokenTaskCache *lru.Cache[string, any]
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionTask), func() error {
|
||||
if setting.SuccessfulTokensCacheSize > 0 {
|
||||
var err error
|
||||
successfulTokenTaskCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to allocate Task cache: %v", err)
|
||||
}
|
||||
} else {
|
||||
successfulTokenTaskCache = nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (task *ActionTask) Duration() time.Duration {
|
||||
return calculateDuration(task.Started, task.Stopped, task.Status)
|
||||
}
|
||||
|
||||
func (task *ActionTask) IsStopped() bool {
|
||||
return task.Stopped > 0
|
||||
}
|
||||
|
||||
func (task *ActionTask) GetRunLink() string {
|
||||
if task.Job == nil || task.Job.Run == nil {
|
||||
return ""
|
||||
}
|
||||
return task.Job.Run.Link()
|
||||
}
|
||||
|
||||
func (task *ActionTask) GetCommitLink() string {
|
||||
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return task.Job.Run.Repo.CommitLink(task.CommitSHA)
|
||||
}
|
||||
|
||||
func (task *ActionTask) GetRepoName() string {
|
||||
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return task.Job.Run.Repo.FullName()
|
||||
}
|
||||
|
||||
func (task *ActionTask) GetRepoLink() string {
|
||||
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
|
||||
return ""
|
||||
}
|
||||
return task.Job.Run.Repo.Link()
|
||||
}
|
||||
|
||||
func (task *ActionTask) LoadJob(ctx context.Context) error {
|
||||
if task.Job == nil {
|
||||
job, err := GetRunJobByID(ctx, task.JobID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task.Job = job
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAttributes load Job Steps if not loaded
|
||||
func (task *ActionTask) LoadAttributes(ctx context.Context) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
if err := task.LoadJob(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := task.Job.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if task.Steps == nil { // be careful, an empty slice (not nil) also means loaded
|
||||
steps, err := GetTaskStepsByTaskID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task.Steps = steps
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (task *ActionTask) GenerateToken() (err error) {
|
||||
task.Token, task.TokenSalt, task.TokenHash, task.TokenLastEight, err = generateSaltedToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func GetTaskByID(ctx context.Context, id int64) (*ActionTask, error) {
|
||||
var task ActionTask
|
||||
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("task with id %d: %w", id, util.ErrNotExist)
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
func GetRunningTaskByToken(ctx context.Context, token string) (*ActionTask, error) {
|
||||
errNotExist := fmt.Errorf("task with token %q: %w", token, util.ErrNotExist)
|
||||
if token == "" {
|
||||
return nil, errNotExist
|
||||
}
|
||||
// A token is defined as being SHA1 sum these are 40 hexadecimal bytes long
|
||||
if len(token) != 40 {
|
||||
return nil, errNotExist
|
||||
}
|
||||
for _, x := range []byte(token) {
|
||||
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
|
||||
return nil, errNotExist
|
||||
}
|
||||
}
|
||||
|
||||
lastEight := token[len(token)-8:]
|
||||
|
||||
if id := getTaskIDFromCache(token); id > 0 {
|
||||
task := &ActionTask{
|
||||
TokenLastEight: lastEight,
|
||||
}
|
||||
// Re-get the task from the db in case it has been deleted in the intervening period
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if has {
|
||||
return task, nil
|
||||
}
|
||||
successfulTokenTaskCache.Remove(token)
|
||||
}
|
||||
|
||||
var tasks []*ActionTask
|
||||
err := db.GetEngine(ctx).Where("token_last_eight = ? AND status = ?", lastEight, StatusRunning).Find(&tasks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(tasks) == 0 {
|
||||
return nil, errNotExist
|
||||
}
|
||||
|
||||
for _, t := range tasks {
|
||||
tempHash := auth_model.HashToken(token, t.TokenSalt)
|
||||
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
|
||||
if successfulTokenTaskCache != nil {
|
||||
successfulTokenTaskCache.Add(token, t.ID)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return nil, errNotExist
|
||||
}
|
||||
|
||||
func CreateTaskForRunner(ctx context.Context, runner *ActionRunner) (*ActionTask, bool, error) {
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
jobCond := builder.NewCond()
|
||||
if runner.RepoID != 0 {
|
||||
jobCond = builder.Eq{"repo_id": runner.RepoID}
|
||||
} else if runner.OwnerID != 0 {
|
||||
jobCond = builder.In("repo_id", builder.Select("`repository`.id").From("repository").
|
||||
Join("INNER", "repo_unit", "`repository`.id = `repo_unit`.repo_id").
|
||||
Where(builder.Eq{"`repository`.owner_id": runner.OwnerID, "`repo_unit`.type": unit.TypeActions}))
|
||||
}
|
||||
if jobCond.IsValid() {
|
||||
jobCond = builder.In("run_id", builder.Select("id").From("action_run").Where(jobCond))
|
||||
}
|
||||
|
||||
var jobs []*ActionRunJob
|
||||
if err := e.Where("task_id=? AND status=?", 0, StatusWaiting).And(jobCond).Asc("updated", "id").Find(&jobs); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// TODO: a more efficient way to filter labels
|
||||
var job *ActionRunJob
|
||||
log.Trace("runner labels: %v", runner.AgentLabels)
|
||||
for _, v := range jobs {
|
||||
if isSubset(runner.AgentLabels, v.RunsOn) {
|
||||
job = v
|
||||
break
|
||||
}
|
||||
}
|
||||
if job == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err := job.LoadAttributes(ctx); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
now := timeutil.TimeStampNow()
|
||||
job.Attempt++
|
||||
job.Started = now
|
||||
job.Status = StatusRunning
|
||||
|
||||
task := &ActionTask{
|
||||
JobID: job.ID,
|
||||
Attempt: job.Attempt,
|
||||
RunnerID: runner.ID,
|
||||
Started: now,
|
||||
Status: StatusRunning,
|
||||
RepoID: job.RepoID,
|
||||
OwnerID: job.OwnerID,
|
||||
CommitSHA: job.CommitSHA,
|
||||
IsForkPullRequest: job.IsForkPullRequest,
|
||||
}
|
||||
if err := task.GenerateToken(); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
parsedWorkflows, err := jobparser.Parse(job.WorkflowPayload)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("parse workflow of job %d: %w", job.ID, err)
|
||||
} else if len(parsedWorkflows) != 1 {
|
||||
return nil, false, fmt.Errorf("workflow of job %d: not single workflow", job.ID)
|
||||
}
|
||||
_, workflowJob := parsedWorkflows[0].Job()
|
||||
|
||||
if _, err := e.Insert(task); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
task.LogFilename = logFileName(job.Run.Repo.FullName(), task.ID)
|
||||
if err := UpdateTask(ctx, task, "log_filename"); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(workflowJob.Steps) > 0 {
|
||||
steps := make([]*ActionTaskStep, len(workflowJob.Steps))
|
||||
for i, v := range workflowJob.Steps {
|
||||
name := util.EllipsisDisplayString(v.String(), 255)
|
||||
steps[i] = &ActionTaskStep{
|
||||
Name: name,
|
||||
TaskID: task.ID,
|
||||
Index: int64(i),
|
||||
RepoID: task.RepoID,
|
||||
Status: StatusWaiting,
|
||||
}
|
||||
}
|
||||
if _, err := e.Insert(steps); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
task.Steps = steps
|
||||
}
|
||||
|
||||
job.TaskID = task.ID
|
||||
if n, err := UpdateRunJob(ctx, job, builder.Eq{"task_id": 0}); err != nil {
|
||||
return nil, false, err
|
||||
} else if n != 1 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
task.Job = job
|
||||
|
||||
if err := committer.Commit(); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return task, true, nil
|
||||
}
|
||||
|
||||
func UpdateTask(ctx context.Context, task *ActionTask, cols ...string) error {
|
||||
sess := db.GetEngine(ctx).ID(task.ID)
|
||||
if len(cols) > 0 {
|
||||
sess.Cols(cols...)
|
||||
}
|
||||
_, err := sess.Update(task)
|
||||
|
||||
// Automatically delete the ephemeral runner if the task is done
|
||||
if err == nil && task.Status.IsDone() && util.SliceContainsString(cols, "status") {
|
||||
return DeleteEphemeralRunner(ctx, task.RunnerID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTaskByState updates the task by the state.
|
||||
// It will always update the task if the state is not final, even there is no change.
|
||||
// So it will update ActionTask.Updated to avoid the task being judged as a zombie task.
|
||||
func UpdateTaskByState(ctx context.Context, runnerID int64, state *runnerv1.TaskState) (*ActionTask, error) {
|
||||
stepStates := map[int64]*runnerv1.StepState{}
|
||||
for _, v := range state.Steps {
|
||||
stepStates[v.Id] = v
|
||||
}
|
||||
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*ActionTask, error) {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
task := &ActionTask{}
|
||||
if has, err := e.ID(state.Id).Get(task); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, util.ErrNotExist
|
||||
} else if runnerID != task.RunnerID {
|
||||
return nil, errors.New("invalid runner for task")
|
||||
}
|
||||
|
||||
if task.Status.IsDone() {
|
||||
// the state is final, do nothing
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// state.Result is not unspecified means the task is finished
|
||||
if state.Result != runnerv1.Result_RESULT_UNSPECIFIED {
|
||||
task.Status = Status(state.Result)
|
||||
task.Stopped = timeutil.TimeStamp(state.StoppedAt.AsTime().Unix())
|
||||
if err := UpdateTask(ctx, task, "status", "stopped"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := UpdateRunJob(ctx, &ActionRunJob{
|
||||
ID: task.JobID,
|
||||
Status: task.Status,
|
||||
Stopped: task.Stopped,
|
||||
}, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Force update ActionTask.Updated to avoid the task being judged as a zombie task
|
||||
task.Updated = timeutil.TimeStampNow()
|
||||
if err := UpdateTask(ctx, task, "updated"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := task.LoadAttributes(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, step := range task.Steps {
|
||||
var result runnerv1.Result
|
||||
if v, ok := stepStates[step.Index]; ok {
|
||||
result = v.Result
|
||||
step.LogIndex = v.LogIndex
|
||||
step.LogLength = v.LogLength
|
||||
step.Started = convertTimestamp(v.StartedAt)
|
||||
step.Stopped = convertTimestamp(v.StoppedAt)
|
||||
}
|
||||
if result != runnerv1.Result_RESULT_UNSPECIFIED {
|
||||
step.Status = Status(result)
|
||||
} else if step.Started != 0 {
|
||||
step.Status = StatusRunning
|
||||
}
|
||||
if _, err := e.ID(step.ID).Update(step); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return task, nil
|
||||
})
|
||||
}
|
||||
|
||||
func StopTask(ctx context.Context, taskID int64, status Status) error {
|
||||
if !status.IsDone() {
|
||||
return fmt.Errorf("cannot stop task with status %v", status)
|
||||
}
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
task := &ActionTask{}
|
||||
if has, err := e.ID(taskID).Get(task); err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
if task.Status.IsDone() {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := timeutil.TimeStampNow()
|
||||
task.Status = status
|
||||
task.Stopped = now
|
||||
if _, err := UpdateRunJob(ctx, &ActionRunJob{
|
||||
ID: task.JobID,
|
||||
Status: task.Status,
|
||||
Stopped: task.Stopped,
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := UpdateTask(ctx, task, "status", "stopped"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := task.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, step := range task.Steps {
|
||||
if !step.Status.IsDone() {
|
||||
step.Status = status
|
||||
if step.Started == 0 {
|
||||
step.Started = now
|
||||
}
|
||||
step.Stopped = now
|
||||
}
|
||||
if _, err := e.ID(step.ID).Update(step); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindOldTasksToExpire(ctx context.Context, olderThan timeutil.TimeStamp, limit int) ([]*ActionTask, error) {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
tasks := make([]*ActionTask, 0, limit)
|
||||
// Check "stopped > 0" to avoid deleting tasks that are still running
|
||||
return tasks, e.Where("stopped > 0 AND stopped < ? AND log_expired = ?", olderThan, false).
|
||||
Limit(limit).
|
||||
Find(&tasks)
|
||||
}
|
||||
|
||||
func isSubset(set, subset []string) bool {
|
||||
m := make(container.Set[string], len(set))
|
||||
for _, v := range set {
|
||||
m.Add(v)
|
||||
}
|
||||
|
||||
for _, v := range subset {
|
||||
if !m.Contains(v) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func convertTimestamp(timestamp *timestamppb.Timestamp) timeutil.TimeStamp {
|
||||
if timestamp.GetSeconds() == 0 && timestamp.GetNanos() == 0 {
|
||||
return timeutil.TimeStamp(0)
|
||||
}
|
||||
return timeutil.TimeStamp(timestamp.AsTime().Unix())
|
||||
}
|
||||
|
||||
func logFileName(repoFullName string, taskID int64) string {
|
||||
ret := fmt.Sprintf("%s/%02x/%d.log", repoFullName, taskID%256, taskID)
|
||||
|
||||
if setting.Actions.LogCompression.IsZstd() {
|
||||
ret += ".zst"
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func getTaskIDFromCache(token string) int64 {
|
||||
if successfulTokenTaskCache == nil {
|
||||
return 0
|
||||
}
|
||||
tInterface, ok := successfulTokenTaskCache.Get(token)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
t, ok := tInterface.(int64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return t
|
||||
}
|
||||
91
models/actions/task_list.go
Normal file
91
models/actions/task_list.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type TaskList []*ActionTask
|
||||
|
||||
func (tasks TaskList) GetJobIDs() []int64 {
|
||||
return container.FilterSlice(tasks, func(t *ActionTask) (int64, bool) {
|
||||
return t.JobID, t.JobID != 0
|
||||
})
|
||||
}
|
||||
|
||||
func (tasks TaskList) LoadJobs(ctx context.Context) error {
|
||||
jobIDs := tasks.GetJobIDs()
|
||||
jobs := make(map[int64]*ActionRunJob, len(jobIDs))
|
||||
if err := db.GetEngine(ctx).In("id", jobIDs).Find(&jobs); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range tasks {
|
||||
if t.JobID > 0 && t.Job == nil {
|
||||
t.Job = jobs[t.JobID]
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Replace with "ActionJobList(maps.Values(jobs))" once available
|
||||
var jobsList ActionJobList = make([]*ActionRunJob, 0, len(jobs))
|
||||
for _, j := range jobs {
|
||||
jobsList = append(jobsList, j)
|
||||
}
|
||||
return jobsList.LoadAttributes(ctx, true)
|
||||
}
|
||||
|
||||
func (tasks TaskList) LoadAttributes(ctx context.Context) error {
|
||||
return tasks.LoadJobs(ctx)
|
||||
}
|
||||
|
||||
type FindTaskOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
JobID int64
|
||||
OwnerID int64
|
||||
CommitSHA string
|
||||
Status Status
|
||||
UpdatedBefore timeutil.TimeStamp
|
||||
StartedBefore timeutil.TimeStamp
|
||||
RunnerID int64
|
||||
}
|
||||
|
||||
func (opts FindTaskOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.RepoID > 0 {
|
||||
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.JobID > 0 {
|
||||
cond = cond.And(builder.Eq{"job_id": opts.JobID})
|
||||
}
|
||||
if opts.OwnerID > 0 {
|
||||
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
}
|
||||
if opts.CommitSHA != "" {
|
||||
cond = cond.And(builder.Eq{"commit_sha": opts.CommitSHA})
|
||||
}
|
||||
if opts.Status > StatusUnknown {
|
||||
cond = cond.And(builder.Eq{"status": opts.Status})
|
||||
}
|
||||
if opts.UpdatedBefore > 0 {
|
||||
cond = cond.And(builder.Lt{"updated": opts.UpdatedBefore})
|
||||
}
|
||||
if opts.StartedBefore > 0 {
|
||||
cond = cond.And(builder.Lt{"started": opts.StartedBefore})
|
||||
}
|
||||
if opts.RunnerID > 0 {
|
||||
cond = cond.And(builder.Eq{"runner_id": opts.RunnerID})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindTaskOptions) ToOrders() string {
|
||||
return "`id` DESC"
|
||||
}
|
||||
55
models/actions/task_output.go
Normal file
55
models/actions/task_output.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
)
|
||||
|
||||
// ActionTaskOutput represents an output of ActionTask.
|
||||
// So the outputs are bound to a task, that means when a completed job has been rerun,
|
||||
// the outputs of the job will be reset because the task is new.
|
||||
// It's by design, to avoid the outputs of the old task to be mixed with the new task.
|
||||
type ActionTaskOutput struct {
|
||||
ID int64
|
||||
TaskID int64 `xorm:"INDEX UNIQUE(task_id_output_key)"`
|
||||
OutputKey string `xorm:"VARCHAR(255) UNIQUE(task_id_output_key)"`
|
||||
OutputValue string `xorm:"MEDIUMTEXT"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionTaskOutput))
|
||||
}
|
||||
|
||||
// FindTaskOutputByTaskID returns the outputs of the task.
|
||||
func FindTaskOutputByTaskID(ctx context.Context, taskID int64) ([]*ActionTaskOutput, error) {
|
||||
var outputs []*ActionTaskOutput
|
||||
return outputs, db.GetEngine(ctx).Where("task_id=?", taskID).Find(&outputs)
|
||||
}
|
||||
|
||||
// FindTaskOutputKeyByTaskID returns the keys of the outputs of the task.
|
||||
func FindTaskOutputKeyByTaskID(ctx context.Context, taskID int64) ([]string, error) {
|
||||
var keys []string
|
||||
return keys, db.GetEngine(ctx).Table(ActionTaskOutput{}).Where("task_id=?", taskID).Cols("output_key").Find(&keys)
|
||||
}
|
||||
|
||||
// InsertTaskOutputIfNotExist inserts a new task output if it does not exist.
|
||||
func InsertTaskOutputIfNotExist(ctx context.Context, taskID int64, key, value string) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
sess := db.GetEngine(ctx)
|
||||
if exist, err := sess.Exist(&ActionTaskOutput{TaskID: taskID, OutputKey: key}); err != nil {
|
||||
return err
|
||||
} else if exist {
|
||||
return nil
|
||||
}
|
||||
_, err := sess.Insert(&ActionTaskOutput{
|
||||
TaskID: taskID,
|
||||
OutputKey: key,
|
||||
OutputValue: value,
|
||||
})
|
||||
return err
|
||||
})
|
||||
}
|
||||
41
models/actions/task_step.go
Normal file
41
models/actions/task_step.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
)
|
||||
|
||||
// ActionTaskStep represents a step of ActionTask
|
||||
type ActionTaskStep struct {
|
||||
ID int64
|
||||
Name string `xorm:"VARCHAR(255)"`
|
||||
TaskID int64 `xorm:"index unique(task_index)"`
|
||||
Index int64 `xorm:"index unique(task_index)"`
|
||||
RepoID int64 `xorm:"index"`
|
||||
Status Status `xorm:"index"`
|
||||
LogIndex int64
|
||||
LogLength int64
|
||||
Started timeutil.TimeStamp
|
||||
Stopped timeutil.TimeStamp
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
Updated timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
func (step *ActionTaskStep) Duration() time.Duration {
|
||||
return calculateDuration(step.Started, step.Stopped, step.Status)
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionTaskStep))
|
||||
}
|
||||
|
||||
func GetTaskStepsByTaskID(ctx context.Context, taskID int64) ([]*ActionTaskStep, error) {
|
||||
var steps []*ActionTaskStep
|
||||
return steps, db.GetEngine(ctx).Where("task_id=?", taskID).OrderBy("`index` ASC").Find(&steps)
|
||||
}
|
||||
101
models/actions/tasks_version.go
Normal file
101
models/actions/tasks_version.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
)
|
||||
|
||||
// ActionTasksVersion
|
||||
// If both ownerID and repoID is zero, its scope is global.
|
||||
// If ownerID is not zero and repoID is zero, its scope is org (there is no user-level runner currently).
|
||||
// If ownerID is zero and repoID is not zero, its scope is repo.
|
||||
type ActionTasksVersion struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
OwnerID int64 `xorm:"UNIQUE(owner_repo)"`
|
||||
RepoID int64 `xorm:"INDEX UNIQUE(owner_repo)"`
|
||||
Version int64
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionTasksVersion))
|
||||
}
|
||||
|
||||
func GetTasksVersionByScope(ctx context.Context, ownerID, repoID int64) (int64, error) {
|
||||
var tasksVersion ActionTasksVersion
|
||||
has, err := db.GetEngine(ctx).Where("owner_id = ? AND repo_id = ?", ownerID, repoID).Get(&tasksVersion)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if !has {
|
||||
return 0, nil
|
||||
}
|
||||
return tasksVersion.Version, err
|
||||
}
|
||||
|
||||
func insertTasksVersion(ctx context.Context, ownerID, repoID int64) (*ActionTasksVersion, error) {
|
||||
tasksVersion := &ActionTasksVersion{
|
||||
OwnerID: ownerID,
|
||||
RepoID: repoID,
|
||||
Version: 1,
|
||||
}
|
||||
if _, err := db.GetEngine(ctx).Insert(tasksVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tasksVersion, nil
|
||||
}
|
||||
|
||||
func increaseTasksVersionByScope(ctx context.Context, ownerID, repoID int64) error {
|
||||
result, err := db.GetEngine(ctx).Exec("UPDATE action_tasks_version SET version = version + 1 WHERE owner_id = ? AND repo_id = ?", ownerID, repoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if affected == 0 {
|
||||
// if update sql does not affect any rows, the database may be broken,
|
||||
// so re-insert the row of version data here.
|
||||
if _, err := insertTasksVersion(ctx, ownerID, repoID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func IncreaseTaskVersion(ctx context.Context, ownerID, repoID int64) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
// 1. increase global
|
||||
if err := increaseTasksVersionByScope(ctx, 0, 0); err != nil {
|
||||
log.Error("IncreaseTasksVersionByScope(Global): %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. increase owner
|
||||
if ownerID > 0 {
|
||||
if err := increaseTasksVersionByScope(ctx, ownerID, 0); err != nil {
|
||||
log.Error("IncreaseTasksVersionByScope(Owner): %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. increase repo
|
||||
if repoID > 0 {
|
||||
if err := increaseTasksVersionByScope(ctx, 0, repoID); err != nil {
|
||||
log.Error("IncreaseTasksVersionByScope(Repo): %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
103
models/actions/utils.go
Normal file
103
models/actions/utils.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
func generateSaltedToken() (string, string, string, string, error) {
|
||||
salt, err := util.CryptoRandomString(10)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
buf, err := util.CryptoRandomBytes(20)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
token := hex.EncodeToString(buf)
|
||||
hash := auth_model.HashToken(token, salt)
|
||||
return token, salt, hash, token[len(token)-8:], nil
|
||||
}
|
||||
|
||||
/*
|
||||
LogIndexes is the index for mapping log line number to buffer offset.
|
||||
Because it uses varint encoding, it is impossible to predict its size.
|
||||
But we can make a simple estimate with an assumption that each log line has 200 byte, then:
|
||||
| lines | file size | index size |
|
||||
|-----------|---------------------|--------------------|
|
||||
| 100 | 20 KiB(20000) | 258 B(258) |
|
||||
| 1000 | 195 KiB(200000) | 2.9 KiB(2958) |
|
||||
| 10000 | 1.9 MiB(2000000) | 34 KiB(34715) |
|
||||
| 100000 | 19 MiB(20000000) | 386 KiB(394715) |
|
||||
| 1000000 | 191 MiB(200000000) | 4.1 MiB(4323626) |
|
||||
| 10000000 | 1.9 GiB(2000000000) | 47 MiB(49323626) |
|
||||
| 100000000 | 19 GiB(20000000000) | 490 MiB(513424280) |
|
||||
*/
|
||||
type LogIndexes []int64
|
||||
|
||||
func (indexes *LogIndexes) FromDB(b []byte) error {
|
||||
reader := bytes.NewReader(b)
|
||||
for {
|
||||
v, err := binary.ReadVarint(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("binary ReadVarint: %w", err)
|
||||
}
|
||||
*indexes = append(*indexes, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (indexes *LogIndexes) ToDB() ([]byte, error) {
|
||||
buf, i := make([]byte, binary.MaxVarintLen64*len(*indexes)), 0
|
||||
for _, v := range *indexes {
|
||||
n := binary.PutVarint(buf[i:], v)
|
||||
i += n
|
||||
}
|
||||
return buf[:i], nil
|
||||
}
|
||||
|
||||
var timeSince = time.Since
|
||||
|
||||
func calculateDuration(started, stopped timeutil.TimeStamp, status Status) time.Duration {
|
||||
if started == 0 {
|
||||
return 0
|
||||
}
|
||||
s := started.AsTime()
|
||||
if status.IsDone() {
|
||||
return stopped.AsTime().Sub(s)
|
||||
}
|
||||
return timeSince(s).Truncate(time.Second)
|
||||
}
|
||||
|
||||
// best effort function to convert an action schedule to action run, to be used in GenerateGiteaContext
|
||||
func (s *ActionSchedule) ToActionRun() *ActionRun {
|
||||
return &ActionRun{
|
||||
Title: s.Title,
|
||||
RepoID: s.RepoID,
|
||||
Repo: s.Repo,
|
||||
OwnerID: s.OwnerID,
|
||||
WorkflowID: s.WorkflowID,
|
||||
TriggerUserID: s.TriggerUserID,
|
||||
TriggerUser: s.TriggerUser,
|
||||
Ref: s.Ref,
|
||||
CommitSHA: s.CommitSHA,
|
||||
Event: s.Event,
|
||||
EventPayload: s.EventPayload,
|
||||
Created: s.Created,
|
||||
Updated: s.Updated,
|
||||
}
|
||||
}
|
||||
90
models/actions/utils_test.go
Normal file
90
models/actions/utils_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogIndexes_ToDB(t *testing.T) {
|
||||
tests := []struct {
|
||||
indexes LogIndexes
|
||||
}{
|
||||
{
|
||||
indexes: []int64{1, 2, 0, -1, -2, math.MaxInt64, math.MinInt64},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got, err := tt.indexes.ToDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
indexes := LogIndexes{}
|
||||
require.NoError(t, indexes.FromDB(got))
|
||||
|
||||
assert.Equal(t, tt.indexes, indexes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_calculateDuration(t *testing.T) {
|
||||
oldTimeSince := timeSince
|
||||
defer func() {
|
||||
timeSince = oldTimeSince
|
||||
}()
|
||||
|
||||
timeSince = func(t time.Time) time.Duration {
|
||||
return timeutil.TimeStamp(1000).AsTime().Sub(t)
|
||||
}
|
||||
type args struct {
|
||||
started timeutil.TimeStamp
|
||||
stopped timeutil.TimeStamp
|
||||
status Status
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "unknown",
|
||||
args: args{
|
||||
started: 0,
|
||||
stopped: 0,
|
||||
status: StatusUnknown,
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "running",
|
||||
args: args{
|
||||
started: 500,
|
||||
stopped: 0,
|
||||
status: StatusRunning,
|
||||
},
|
||||
want: 500 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "done",
|
||||
args: args{
|
||||
started: 500,
|
||||
stopped: 600,
|
||||
status: StatusSuccess,
|
||||
},
|
||||
want: 100 * time.Second,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, calculateDuration(tt.args.started, tt.args.stopped, tt.args.status), "calculateDuration(%v, %v, %v)", tt.args.started, tt.args.stopped, tt.args.status)
|
||||
})
|
||||
}
|
||||
}
|
||||
184
models/actions/variable.go
Normal file
184
models/actions/variable.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionVariable represents a variable that can be used in actions
|
||||
//
|
||||
// It can be:
|
||||
// 1. global variable, OwnerID is 0 and RepoID is 0
|
||||
// 2. org/user level variable, OwnerID is org/user ID and RepoID is 0
|
||||
// 3. repo level variable, OwnerID is 0 and RepoID is repo ID
|
||||
//
|
||||
// Please note that it's not acceptable to have both OwnerID and RepoID to be non-zero,
|
||||
// or it will be complicated to find variables belonging to a specific owner.
|
||||
// For example, conditions like `OwnerID = 1` will also return variable {OwnerID: 1, RepoID: 1},
|
||||
// but it's a repo level variable, not an org/user level variable.
|
||||
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level variables.
|
||||
type ActionVariable struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
OwnerID int64 `xorm:"UNIQUE(owner_repo_name)"`
|
||||
RepoID int64 `xorm:"INDEX UNIQUE(owner_repo_name)"`
|
||||
Name string `xorm:"UNIQUE(owner_repo_name) NOT NULL"`
|
||||
Data string `xorm:"LONGTEXT NOT NULL"`
|
||||
Description string `xorm:"TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created NOT NULL"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
const (
|
||||
VariableDataMaxLength = 65536
|
||||
VariableDescriptionMaxLength = 4096
|
||||
)
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(ActionVariable))
|
||||
}
|
||||
|
||||
func InsertVariable(ctx context.Context, ownerID, repoID int64, name, data, description string) (*ActionVariable, error) {
|
||||
if ownerID != 0 && repoID != 0 {
|
||||
// It's trying to create a variable that belongs to a repository, but OwnerID has been set accidentally.
|
||||
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
|
||||
ownerID = 0
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(data) > VariableDataMaxLength {
|
||||
return nil, util.NewInvalidArgumentErrorf("data too long")
|
||||
}
|
||||
|
||||
description = util.TruncateRunes(description, VariableDescriptionMaxLength)
|
||||
|
||||
variable := &ActionVariable{
|
||||
OwnerID: ownerID,
|
||||
RepoID: repoID,
|
||||
Name: strings.ToUpper(name),
|
||||
Data: data,
|
||||
Description: description,
|
||||
}
|
||||
return variable, db.Insert(ctx, variable)
|
||||
}
|
||||
|
||||
type FindVariablesOpts struct {
|
||||
db.ListOptions
|
||||
IDs []int64
|
||||
RepoID int64
|
||||
OwnerID int64 // it will be ignored if RepoID is set
|
||||
Name string
|
||||
}
|
||||
|
||||
func (opts FindVariablesOpts) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
|
||||
if len(opts.IDs) > 0 {
|
||||
if len(opts.IDs) == 1 {
|
||||
cond = cond.And(builder.Eq{"id": opts.IDs[0]})
|
||||
} else {
|
||||
cond = cond.And(builder.In("id", opts.IDs))
|
||||
}
|
||||
}
|
||||
|
||||
// Since we now support instance-level variables,
|
||||
// there is no need to check for null values for `owner_id` and `repo_id`
|
||||
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
|
||||
if opts.RepoID != 0 { // if RepoID is set
|
||||
// ignore OwnerID and treat it as 0
|
||||
cond = cond.And(builder.Eq{"owner_id": 0})
|
||||
} else {
|
||||
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
}
|
||||
|
||||
if opts.Name != "" {
|
||||
cond = cond.And(builder.Eq{"name": strings.ToUpper(opts.Name)})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func FindVariables(ctx context.Context, opts FindVariablesOpts) ([]*ActionVariable, error) {
|
||||
return db.Find[ActionVariable](ctx, opts)
|
||||
}
|
||||
|
||||
func UpdateVariableCols(ctx context.Context, variable *ActionVariable, cols ...string) (bool, error) {
|
||||
if utf8.RuneCountInString(variable.Data) > VariableDataMaxLength {
|
||||
return false, util.NewInvalidArgumentErrorf("data too long")
|
||||
}
|
||||
|
||||
variable.Description = util.TruncateRunes(variable.Description, VariableDescriptionMaxLength)
|
||||
|
||||
variable.Name = strings.ToUpper(variable.Name)
|
||||
count, err := db.GetEngine(ctx).
|
||||
ID(variable.ID).
|
||||
Cols(cols...).
|
||||
Update(variable)
|
||||
return count != 0, err
|
||||
}
|
||||
|
||||
func DeleteVariable(ctx context.Context, id int64) error {
|
||||
if _, err := db.DeleteByID[ActionVariable](ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetVariablesOfRun(ctx context.Context, run *ActionRun) (map[string]string, error) {
|
||||
variables := map[string]string{}
|
||||
|
||||
if err := run.LoadRepo(ctx); err != nil {
|
||||
log.Error("LoadRepo: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Global
|
||||
globalVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{})
|
||||
if err != nil {
|
||||
log.Error("find global variables: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Org / User level
|
||||
ownerVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{OwnerID: run.Repo.OwnerID})
|
||||
if err != nil {
|
||||
log.Error("find variables of org: %d, error: %v", run.Repo.OwnerID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Repo level
|
||||
repoVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{RepoID: run.RepoID})
|
||||
if err != nil {
|
||||
log.Error("find variables of repo: %d, error: %v", run.RepoID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Level precedence: Repo > Org / User > Global
|
||||
for _, v := range append(globalVariables, append(ownerVariables, repoVariables...)...) {
|
||||
variables[v.Name] = v.Data
|
||||
}
|
||||
|
||||
return variables, nil
|
||||
}
|
||||
|
||||
func CountWrongRepoLevelVariables(ctx context.Context) (int64, error) {
|
||||
var result int64
|
||||
_, err := db.GetEngine(ctx).SQL("SELECT count(`id`) FROM `action_variable` WHERE `repo_id` > 0 AND `owner_id` > 0").Get(&result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func UpdateWrongRepoLevelVariables(ctx context.Context) (int64, error) {
|
||||
result, err := db.GetEngine(ctx).Exec("UPDATE `action_variable` SET `owner_id` = 0 WHERE `repo_id` > 0 AND `owner_id` > 0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
616
models/activities/action.go
Normal file
616
models/activities/action.go
Normal file
@@ -0,0 +1,616 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/models/organization"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/git"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/structs"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// ActionType represents the type of an action.
|
||||
type ActionType int
|
||||
|
||||
// Possible action types.
|
||||
const (
|
||||
ActionCreateRepo ActionType = iota + 1 // 1
|
||||
ActionRenameRepo // 2
|
||||
ActionStarRepo // 3
|
||||
ActionWatchRepo // 4
|
||||
ActionCommitRepo // 5
|
||||
ActionCreateIssue // 6
|
||||
ActionCreatePullRequest // 7
|
||||
ActionTransferRepo // 8
|
||||
ActionPushTag // 9
|
||||
ActionCommentIssue // 10
|
||||
ActionMergePullRequest // 11
|
||||
ActionCloseIssue // 12
|
||||
ActionReopenIssue // 13
|
||||
ActionClosePullRequest // 14
|
||||
ActionReopenPullRequest // 15
|
||||
ActionDeleteTag // 16
|
||||
ActionDeleteBranch // 17
|
||||
ActionMirrorSyncPush // 18
|
||||
ActionMirrorSyncCreate // 19
|
||||
ActionMirrorSyncDelete // 20
|
||||
ActionApprovePullRequest // 21
|
||||
ActionRejectPullRequest // 22
|
||||
ActionCommentPull // 23
|
||||
ActionPublishRelease // 24
|
||||
ActionPullReviewDismissed // 25
|
||||
ActionPullRequestReadyForReview // 26
|
||||
ActionAutoMergePullRequest // 27
|
||||
)
|
||||
|
||||
func (at ActionType) String() string {
|
||||
switch at {
|
||||
case ActionCreateRepo:
|
||||
return "create_repo"
|
||||
case ActionRenameRepo:
|
||||
return "rename_repo"
|
||||
case ActionStarRepo:
|
||||
return "star_repo" // will not displayed in feeds.tmpl
|
||||
case ActionWatchRepo:
|
||||
return "watch_repo" // will not displayed in feeds.tmpl
|
||||
case ActionCommitRepo:
|
||||
return "commit_repo"
|
||||
case ActionCreateIssue:
|
||||
return "create_issue"
|
||||
case ActionCreatePullRequest:
|
||||
return "create_pull_request"
|
||||
case ActionTransferRepo:
|
||||
return "transfer_repo"
|
||||
case ActionPushTag:
|
||||
return "push_tag"
|
||||
case ActionCommentIssue:
|
||||
return "comment_issue"
|
||||
case ActionMergePullRequest:
|
||||
return "merge_pull_request"
|
||||
case ActionCloseIssue:
|
||||
return "close_issue"
|
||||
case ActionReopenIssue:
|
||||
return "reopen_issue"
|
||||
case ActionClosePullRequest:
|
||||
return "close_pull_request"
|
||||
case ActionReopenPullRequest:
|
||||
return "reopen_pull_request"
|
||||
case ActionDeleteTag:
|
||||
return "delete_tag"
|
||||
case ActionDeleteBranch:
|
||||
return "delete_branch"
|
||||
case ActionMirrorSyncPush:
|
||||
return "mirror_sync_push"
|
||||
case ActionMirrorSyncCreate:
|
||||
return "mirror_sync_create"
|
||||
case ActionMirrorSyncDelete:
|
||||
return "mirror_sync_delete"
|
||||
case ActionApprovePullRequest:
|
||||
return "approve_pull_request"
|
||||
case ActionRejectPullRequest:
|
||||
return "reject_pull_request"
|
||||
case ActionCommentPull:
|
||||
return "comment_pull"
|
||||
case ActionPublishRelease:
|
||||
return "publish_release"
|
||||
case ActionPullReviewDismissed:
|
||||
return "pull_review_dismissed"
|
||||
case ActionPullRequestReadyForReview:
|
||||
return "pull_request_ready_for_review"
|
||||
case ActionAutoMergePullRequest:
|
||||
return "auto_merge_pull_request"
|
||||
default:
|
||||
return "action-" + strconv.Itoa(int(at))
|
||||
}
|
||||
}
|
||||
|
||||
func (at ActionType) InActions(actions ...string) bool {
|
||||
return slices.Contains(actions, at.String())
|
||||
}
|
||||
|
||||
// Action represents user operation type and other information to
|
||||
// repository. It implemented interface base.Actioner so that can be
|
||||
// used in template render.
|
||||
type Action struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UserID int64 `xorm:"INDEX"` // Receiver user id.
|
||||
OpType ActionType
|
||||
ActUserID int64 // Action user id.
|
||||
ActUser *user_model.User `xorm:"-"`
|
||||
RepoID int64
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
CommentID int64 `xorm:"INDEX"`
|
||||
Comment *issues_model.Comment `xorm:"-"`
|
||||
Issue *issues_model.Issue `xorm:"-"` // get the issue id from content
|
||||
IsDeleted bool `xorm:"NOT NULL DEFAULT false"`
|
||||
RefName string
|
||||
IsPrivate bool `xorm:"NOT NULL DEFAULT false"`
|
||||
Content string `xorm:"TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Action))
|
||||
}
|
||||
|
||||
// TableIndices implements xorm's TableIndices interface
|
||||
func (a *Action) TableIndices() []*schemas.Index {
|
||||
repoIndex := schemas.NewIndex("r_u_d", schemas.IndexType)
|
||||
repoIndex.AddColumn("repo_id", "user_id", "is_deleted")
|
||||
|
||||
actUserIndex := schemas.NewIndex("au_r_c_u_d", schemas.IndexType)
|
||||
actUserIndex.AddColumn("act_user_id", "repo_id", "created_unix", "user_id", "is_deleted")
|
||||
|
||||
cudIndex := schemas.NewIndex("c_u_d", schemas.IndexType)
|
||||
cudIndex.AddColumn("created_unix", "user_id", "is_deleted")
|
||||
|
||||
cuIndex := schemas.NewIndex("c_u", schemas.IndexType)
|
||||
cuIndex.AddColumn("user_id", "is_deleted")
|
||||
|
||||
actUserUserIndex := schemas.NewIndex("au_c_u", schemas.IndexType)
|
||||
actUserUserIndex.AddColumn("act_user_id", "created_unix", "user_id")
|
||||
|
||||
indices := []*schemas.Index{actUserIndex, repoIndex, cudIndex, cuIndex, actUserUserIndex}
|
||||
|
||||
return indices
|
||||
}
|
||||
|
||||
// GetOpType gets the ActionType of this action.
|
||||
func (a *Action) GetOpType() ActionType {
|
||||
return a.OpType
|
||||
}
|
||||
|
||||
// LoadActUser loads a.ActUser
|
||||
func (a *Action) LoadActUser(ctx context.Context) {
|
||||
if a.ActUser != nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
a.ActUser, err = user_model.GetPossibleUserByID(ctx, a.ActUserID)
|
||||
if err == nil {
|
||||
return
|
||||
} else if user_model.IsErrUserNotExist(err) {
|
||||
a.ActUser = user_model.NewGhostUser()
|
||||
} else {
|
||||
log.Error("GetUserByID(%d): %v", a.ActUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Action) LoadRepo(ctx context.Context) error {
|
||||
if a.Repo != nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
a.Repo, err = repo_model.GetRepositoryByID(ctx, a.RepoID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetActFullName gets the action's user full name.
|
||||
func (a *Action) GetActFullName(ctx context.Context) string {
|
||||
a.LoadActUser(ctx)
|
||||
return a.ActUser.FullName
|
||||
}
|
||||
|
||||
// GetActUserName gets the action's user name.
|
||||
func (a *Action) GetActUserName(ctx context.Context) string {
|
||||
a.LoadActUser(ctx)
|
||||
return a.ActUser.Name
|
||||
}
|
||||
|
||||
// ShortActUserName gets the action's user name trimmed to max 20
|
||||
// chars.
|
||||
func (a *Action) ShortActUserName(ctx context.Context) string {
|
||||
return util.EllipsisDisplayString(a.GetActUserName(ctx), 20)
|
||||
}
|
||||
|
||||
// GetActDisplayName gets the action's display name based on DEFAULT_SHOW_FULL_NAME, or falls back to the username if it is blank.
|
||||
func (a *Action) GetActDisplayName(ctx context.Context) string {
|
||||
if setting.UI.DefaultShowFullName {
|
||||
trimmedFullName := strings.TrimSpace(a.GetActFullName(ctx))
|
||||
if len(trimmedFullName) > 0 {
|
||||
return trimmedFullName
|
||||
}
|
||||
}
|
||||
return a.ShortActUserName(ctx)
|
||||
}
|
||||
|
||||
// GetActDisplayNameTitle gets the action's display name used for the title (tooltip) based on DEFAULT_SHOW_FULL_NAME
|
||||
func (a *Action) GetActDisplayNameTitle(ctx context.Context) string {
|
||||
if setting.UI.DefaultShowFullName {
|
||||
return a.ShortActUserName(ctx)
|
||||
}
|
||||
return a.GetActFullName(ctx)
|
||||
}
|
||||
|
||||
// GetRepoUserName returns the name of the action repository owner.
|
||||
func (a *Action) GetRepoUserName(ctx context.Context) string {
|
||||
_ = a.LoadRepo(ctx)
|
||||
if a.Repo == nil {
|
||||
return "(non-existing-repo)"
|
||||
}
|
||||
return a.Repo.OwnerName
|
||||
}
|
||||
|
||||
// ShortRepoUserName returns the name of the action repository owner
|
||||
// trimmed to max 20 chars.
|
||||
func (a *Action) ShortRepoUserName(ctx context.Context) string {
|
||||
return util.EllipsisDisplayString(a.GetRepoUserName(ctx), 20)
|
||||
}
|
||||
|
||||
// GetRepoName returns the name of the action repository.
|
||||
func (a *Action) GetRepoName(ctx context.Context) string {
|
||||
_ = a.LoadRepo(ctx)
|
||||
if a.Repo == nil {
|
||||
return "(non-existing-repo)"
|
||||
}
|
||||
return a.Repo.Name
|
||||
}
|
||||
|
||||
// ShortRepoName returns the name of the action repository
|
||||
// trimmed to max 33 chars.
|
||||
func (a *Action) ShortRepoName(ctx context.Context) string {
|
||||
return util.EllipsisDisplayString(a.GetRepoName(ctx), 33)
|
||||
}
|
||||
|
||||
// GetRepoPath returns the virtual path to the action repository.
|
||||
func (a *Action) GetRepoPath(ctx context.Context) string {
|
||||
return path.Join(a.GetRepoUserName(ctx), a.GetRepoName(ctx))
|
||||
}
|
||||
|
||||
// ShortRepoPath returns the virtual path to the action repository
|
||||
// trimmed to max 20 + 1 + 33 chars.
|
||||
func (a *Action) ShortRepoPath(ctx context.Context) string {
|
||||
return path.Join(a.ShortRepoUserName(ctx), a.ShortRepoName(ctx))
|
||||
}
|
||||
|
||||
// GetRepoLink returns relative link to action repository.
|
||||
func (a *Action) GetRepoLink(ctx context.Context) string {
|
||||
// path.Join will skip empty strings
|
||||
return path.Join(setting.AppSubURL, "/", url.PathEscape(a.GetRepoUserName(ctx)), url.PathEscape(a.GetRepoName(ctx)))
|
||||
}
|
||||
|
||||
// GetRepoAbsoluteLink returns the absolute link to action repository.
|
||||
func (a *Action) GetRepoAbsoluteLink(ctx context.Context) string {
|
||||
return setting.AppURL + url.PathEscape(a.GetRepoUserName(ctx)) + "/" + url.PathEscape(a.GetRepoName(ctx))
|
||||
}
|
||||
|
||||
func (a *Action) loadComment(ctx context.Context) (err error) {
|
||||
if a.CommentID == 0 || a.Comment != nil {
|
||||
return nil
|
||||
}
|
||||
a.Comment, err = issues_model.GetCommentByID(ctx, a.CommentID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCommentHTMLURL returns link to action comment.
|
||||
func (a *Action) GetCommentHTMLURL(ctx context.Context) string {
|
||||
if a == nil {
|
||||
return "#"
|
||||
}
|
||||
_ = a.loadComment(ctx)
|
||||
if a.Comment != nil {
|
||||
return a.Comment.HTMLURL(ctx)
|
||||
}
|
||||
|
||||
if err := a.LoadIssue(ctx); err != nil || a.Issue == nil {
|
||||
return "#"
|
||||
}
|
||||
if err := a.Issue.LoadRepo(ctx); err != nil {
|
||||
return "#"
|
||||
}
|
||||
|
||||
return a.Issue.HTMLURL(ctx)
|
||||
}
|
||||
|
||||
// GetCommentLink returns link to action comment.
|
||||
func (a *Action) GetCommentLink(ctx context.Context) string {
|
||||
if a == nil {
|
||||
return "#"
|
||||
}
|
||||
_ = a.loadComment(ctx)
|
||||
if a.Comment != nil {
|
||||
return a.Comment.Link(ctx)
|
||||
}
|
||||
|
||||
if err := a.LoadIssue(ctx); err != nil || a.Issue == nil {
|
||||
return "#"
|
||||
}
|
||||
if err := a.Issue.LoadRepo(ctx); err != nil {
|
||||
return "#"
|
||||
}
|
||||
|
||||
return a.Issue.Link()
|
||||
}
|
||||
|
||||
// GetBranch returns the action's repository branch.
|
||||
func (a *Action) GetBranch() string {
|
||||
return strings.TrimPrefix(a.RefName, git.BranchPrefix)
|
||||
}
|
||||
|
||||
// GetRefLink returns the action's ref link.
|
||||
func (a *Action) GetRefLink(ctx context.Context) string {
|
||||
return a.GetRepoLink(ctx) + "/src/" + git.RefName(a.RefName).RefWebLinkPath()
|
||||
}
|
||||
|
||||
// GetTag returns the action's repository tag.
|
||||
func (a *Action) GetTag() string {
|
||||
return strings.TrimPrefix(a.RefName, git.TagPrefix)
|
||||
}
|
||||
|
||||
// GetContent returns the action's content.
|
||||
func (a *Action) GetContent() string {
|
||||
return a.Content
|
||||
}
|
||||
|
||||
// GetCreate returns the action creation time.
|
||||
func (a *Action) GetCreate() time.Time {
|
||||
return a.CreatedUnix.AsTime()
|
||||
}
|
||||
|
||||
func (a *Action) IsIssueEvent() bool {
|
||||
return a.OpType.InActions("comment_issue", "approve_pull_request", "reject_pull_request", "comment_pull", "merge_pull_request")
|
||||
}
|
||||
|
||||
// GetIssueInfos returns a list of associated information with the action.
|
||||
func (a *Action) GetIssueInfos() []string {
|
||||
// make sure it always returns 3 elements, because there are some access to the a[1] and a[2] without checking the length
|
||||
ret := strings.SplitN(a.Content, "|", 3)
|
||||
for len(ret) < 3 {
|
||||
ret = append(ret, "")
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (a *Action) getIssueIndex() int64 {
|
||||
infos := a.GetIssueInfos()
|
||||
if len(infos) == 0 {
|
||||
return 0
|
||||
}
|
||||
index, _ := strconv.ParseInt(infos[0], 10, 64)
|
||||
return index
|
||||
}
|
||||
|
||||
func (a *Action) LoadIssue(ctx context.Context) error {
|
||||
if a.Issue != nil {
|
||||
return nil
|
||||
}
|
||||
if index := a.getIssueIndex(); index > 0 {
|
||||
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.Issue = issue
|
||||
a.Issue.Repo = a.Repo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIssueTitle returns the title of first issue associated with the action.
|
||||
func (a *Action) GetIssueTitle(ctx context.Context) string {
|
||||
if err := a.LoadIssue(ctx); err != nil {
|
||||
log.Error("LoadIssue: %v", err)
|
||||
return "<500 when get issue>"
|
||||
}
|
||||
if a.Issue == nil {
|
||||
return "<Issue not found>"
|
||||
}
|
||||
return a.Issue.Title
|
||||
}
|
||||
|
||||
// GetIssueContent returns the content of first issue associated with this action.
|
||||
func (a *Action) GetIssueContent(ctx context.Context) string {
|
||||
if err := a.LoadIssue(ctx); err != nil {
|
||||
log.Error("LoadIssue: %v", err)
|
||||
return "<500 when get issue>"
|
||||
}
|
||||
if a.Issue == nil {
|
||||
return "<Content not found>"
|
||||
}
|
||||
return a.Issue.Content
|
||||
}
|
||||
|
||||
// GetFeedsOptions options for retrieving feeds
|
||||
type GetFeedsOptions struct {
|
||||
db.ListOptions
|
||||
RequestedUser *user_model.User // the user we want activity for
|
||||
RequestedTeam *organization.Team // the team we want activity for
|
||||
RequestedRepo *repo_model.Repository // the repo we want activity for
|
||||
Actor *user_model.User // the user viewing the activity
|
||||
IncludePrivate bool // include private actions
|
||||
OnlyPerformedBy bool // only actions performed by requested user
|
||||
IncludeDeleted bool // include deleted actions
|
||||
Date string // the day we want activity for: YYYY-MM-DD
|
||||
DontCount bool // do counting in GetFeeds
|
||||
}
|
||||
|
||||
// ActivityReadable return whether doer can read activities of user
|
||||
func ActivityReadable(user, doer *user_model.User) bool {
|
||||
return !user.KeepActivityPrivate ||
|
||||
doer != nil && (doer.IsAdmin || user.ID == doer.ID)
|
||||
}
|
||||
|
||||
func FeedDateCond(opts GetFeedsOptions) builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.Date == "" {
|
||||
return cond
|
||||
}
|
||||
|
||||
dateLow, err := time.ParseInLocation("2006-01-02", opts.Date, setting.DefaultUILocation)
|
||||
if err != nil {
|
||||
log.Warn("Unable to parse %s, filter not applied: %v", opts.Date, err)
|
||||
} else {
|
||||
dateHigh := dateLow.Add(86399000000000) // 23h59m59s
|
||||
|
||||
cond = cond.And(builder.Gte{"`action`.created_unix": dateLow.Unix()})
|
||||
cond = cond.And(builder.Lte{"`action`.created_unix": dateHigh.Unix()})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func ActivityQueryCondition(ctx context.Context, opts GetFeedsOptions) (builder.Cond, error) {
|
||||
cond := builder.NewCond()
|
||||
|
||||
if opts.RequestedTeam != nil && opts.RequestedUser == nil {
|
||||
org, err := user_model.GetUserByID(ctx, opts.RequestedTeam.OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts.RequestedUser = org
|
||||
}
|
||||
|
||||
// check activity visibility for actor ( similar to activityReadable() )
|
||||
if opts.Actor == nil {
|
||||
cond = cond.And(builder.In("act_user_id",
|
||||
builder.Select("`user`.id").Where(
|
||||
builder.Eq{"keep_activity_private": false, "visibility": structs.VisibleTypePublic},
|
||||
).From("`user`"),
|
||||
))
|
||||
} else if !opts.Actor.IsAdmin {
|
||||
uidCond := builder.Select("`user`.id").From("`user`").Where(
|
||||
builder.Eq{"keep_activity_private": false}.
|
||||
And(builder.In("visibility", structs.VisibleTypePublic, structs.VisibleTypeLimited))).
|
||||
Or(builder.Eq{"id": opts.Actor.ID})
|
||||
|
||||
if opts.RequestedUser != nil {
|
||||
if opts.RequestedUser.IsOrganization() {
|
||||
// An organization can always see the activities whose `act_user_id` is the same as its id.
|
||||
uidCond = uidCond.Or(builder.Eq{"id": opts.RequestedUser.ID})
|
||||
} else {
|
||||
// A user can always see the activities of the organizations to which the user belongs.
|
||||
uidCond = uidCond.Or(
|
||||
builder.Eq{"type": user_model.UserTypeOrganization}.
|
||||
And(builder.In("`user`.id", builder.Select("org_id").
|
||||
Where(builder.Eq{"uid": opts.RequestedUser.ID}).
|
||||
From("team_user"))),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
cond = cond.And(builder.In("act_user_id", uidCond))
|
||||
}
|
||||
|
||||
// check readable repositories by doer/actor
|
||||
if opts.Actor == nil || !opts.Actor.IsAdmin {
|
||||
cond = cond.And(builder.In("repo_id", repo_model.AccessibleRepoIDsQuery(opts.Actor)))
|
||||
}
|
||||
|
||||
if opts.RequestedRepo != nil {
|
||||
// repo's actions could have duplicate items, see the comment of NotifyWatchers
|
||||
// so here we only filter the "original items", aka: user_id == act_user_id
|
||||
cond = cond.And(
|
||||
builder.Eq{"`action`.repo_id": opts.RequestedRepo.ID},
|
||||
builder.Expr("`action`.user_id = `action`.act_user_id"),
|
||||
)
|
||||
}
|
||||
|
||||
if opts.RequestedTeam != nil {
|
||||
env := repo_model.AccessibleTeamReposEnv(organization.OrgFromUser(opts.RequestedUser), opts.RequestedTeam)
|
||||
teamRepoIDs, err := env.RepoIDs(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetTeamRepositories: %w", err)
|
||||
}
|
||||
cond = cond.And(builder.In("repo_id", teamRepoIDs))
|
||||
}
|
||||
|
||||
if opts.RequestedUser != nil {
|
||||
cond = cond.And(builder.Eq{"user_id": opts.RequestedUser.ID})
|
||||
|
||||
if opts.OnlyPerformedBy {
|
||||
cond = cond.And(builder.Eq{"act_user_id": opts.RequestedUser.ID})
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.IncludePrivate {
|
||||
cond = cond.And(builder.Eq{"`action`.is_private": false})
|
||||
}
|
||||
if !opts.IncludeDeleted {
|
||||
cond = cond.And(builder.Eq{"is_deleted": false})
|
||||
}
|
||||
|
||||
cond = cond.And(FeedDateCond(opts))
|
||||
|
||||
return cond, nil
|
||||
}
|
||||
|
||||
// DeleteOldActions deletes all old actions from database.
|
||||
func DeleteOldActions(ctx context.Context, olderThan time.Duration) (err error) {
|
||||
if olderThan <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteIssueActions delete all actions related with issueID
|
||||
func DeleteIssueActions(ctx context.Context, repoID, issueID, issueIndex int64) error {
|
||||
// delete actions assigned to this issue
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
// MariaDB has a performance bug: https://jira.mariadb.org/browse/MDEV-16289
|
||||
// so here it uses "DELETE ... WHERE IN" with pre-queried IDs.
|
||||
var lastCommentID int64
|
||||
commentIDs := make([]int64, 0, db.DefaultMaxInSize)
|
||||
for {
|
||||
commentIDs = commentIDs[:0]
|
||||
err := e.Select("`id`").Table(&issues_model.Comment{}).
|
||||
Where(builder.Eq{"issue_id": issueID}).And("`id` > ?", lastCommentID).
|
||||
OrderBy("`id`").Limit(db.DefaultMaxInSize).
|
||||
Find(&commentIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(commentIDs) == 0 {
|
||||
break
|
||||
} else if _, err = db.GetEngine(ctx).In("comment_id", commentIDs).Delete(&Action{}); err != nil {
|
||||
return err
|
||||
}
|
||||
lastCommentID = commentIDs[len(commentIDs)-1]
|
||||
}
|
||||
|
||||
_, err := e.Where("repo_id = ?", repoID).
|
||||
In("op_type", ActionCreateIssue, ActionCreatePullRequest).
|
||||
Where("content LIKE ?", strconv.FormatInt(issueIndex, 10)+"|%"). // "IssueIndex|content..."
|
||||
Delete(&Action{})
|
||||
return err
|
||||
}
|
||||
|
||||
// CountActionCreatedUnixString count actions where created_unix is an empty string
|
||||
func CountActionCreatedUnixString(ctx context.Context) (int64, error) {
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
return db.GetEngine(ctx).Where(`created_unix = ''`).Count(new(Action))
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// FixActionCreatedUnixString set created_unix to zero if it is an empty string
|
||||
func FixActionCreatedUnixString(ctx context.Context) (int64, error) {
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
res, err := db.GetEngine(ctx).Exec(`UPDATE action SET created_unix = 0 WHERE created_unix = ''`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
290
models/activities/action_list.go
Normal file
290
models/activities/action_list.go
Normal file
@@ -0,0 +1,290 @@
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ActionList defines a list of actions
|
||||
type ActionList []*Action
|
||||
|
||||
func (actions ActionList) getUserIDs() []int64 {
|
||||
return container.FilterSlice(actions, func(action *Action) (int64, bool) {
|
||||
return action.ActUserID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_model.User, error) {
|
||||
if len(actions) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
userIDs := actions.getUserIDs()
|
||||
userMaps := make(map[int64]*user_model.User, len(userIDs))
|
||||
err := db.GetEngine(ctx).
|
||||
In("id", userIDs).
|
||||
Find(&userMaps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find user: %w", err)
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
action.ActUser = userMaps[action.ActUserID]
|
||||
}
|
||||
return userMaps, nil
|
||||
}
|
||||
|
||||
func (actions ActionList) getRepoIDs() []int64 {
|
||||
return container.FilterSlice(actions, func(action *Action) (int64, bool) {
|
||||
return action.RepoID, true
|
||||
})
|
||||
}
|
||||
|
||||
func (actions ActionList) LoadRepositories(ctx context.Context) error {
|
||||
if len(actions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repoIDs := actions.getRepoIDs()
|
||||
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
|
||||
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
|
||||
if err != nil {
|
||||
return fmt.Errorf("find repository: %w", err)
|
||||
}
|
||||
for _, action := range actions {
|
||||
action.Repo = repoMaps[action.RepoID]
|
||||
}
|
||||
repos := repo_model.RepositoryList(util.ValuesOfMap(repoMaps))
|
||||
return repos.LoadUnits(ctx)
|
||||
}
|
||||
|
||||
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
|
||||
if userMap == nil {
|
||||
userMap = make(map[int64]*user_model.User)
|
||||
}
|
||||
|
||||
missingUserIDs := container.FilterSlice(actions, func(action *Action) (int64, bool) {
|
||||
if action.Repo == nil {
|
||||
return 0, false
|
||||
}
|
||||
_, alreadyLoaded := userMap[action.Repo.OwnerID]
|
||||
return action.Repo.OwnerID, !alreadyLoaded
|
||||
})
|
||||
if len(missingUserIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.GetEngine(ctx).
|
||||
In("id", missingUserIDs).
|
||||
Find(&userMap); err != nil {
|
||||
return fmt.Errorf("find user: %w", err)
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
if action.Repo != nil {
|
||||
action.Repo.Owner = userMap[action.Repo.OwnerID]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAttributes loads all attributes
|
||||
func (actions ActionList) LoadAttributes(ctx context.Context) error {
|
||||
// the load sequence cannot be changed because of the dependencies
|
||||
userMap, err := actions.LoadActUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := actions.LoadRepositories(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := actions.loadRepoOwner(ctx, userMap); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := actions.LoadIssues(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return actions.LoadComments(ctx)
|
||||
}
|
||||
|
||||
func (actions ActionList) LoadComments(ctx context.Context) error {
|
||||
if len(actions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
commentIDs := make([]int64, 0, len(actions))
|
||||
for _, action := range actions {
|
||||
if action.CommentID > 0 {
|
||||
commentIDs = append(commentIDs, action.CommentID)
|
||||
}
|
||||
}
|
||||
if len(commentIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
commentsMap := make(map[int64]*issues_model.Comment, len(commentIDs))
|
||||
if err := db.GetEngine(ctx).In("id", commentIDs).Find(&commentsMap); err != nil {
|
||||
return fmt.Errorf("find comment: %w", err)
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
if action.CommentID > 0 {
|
||||
action.Comment = commentsMap[action.CommentID]
|
||||
if action.Comment != nil {
|
||||
action.Comment.Issue = action.Issue
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (actions ActionList) LoadIssues(ctx context.Context) error {
|
||||
if len(actions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conditions := builder.NewCond()
|
||||
issueNum := 0
|
||||
for _, action := range actions {
|
||||
if action.IsIssueEvent() {
|
||||
infos := action.GetIssueInfos()
|
||||
if len(infos) == 0 {
|
||||
continue
|
||||
}
|
||||
index, _ := strconv.ParseInt(infos[0], 10, 64)
|
||||
if index > 0 {
|
||||
conditions = conditions.Or(builder.Eq{
|
||||
"repo_id": action.RepoID,
|
||||
"`index`": index,
|
||||
})
|
||||
issueNum++
|
||||
}
|
||||
}
|
||||
}
|
||||
if !conditions.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
issuesMap := make(map[string]*issues_model.Issue, issueNum)
|
||||
issues := make([]*issues_model.Issue, 0, issueNum)
|
||||
if err := db.GetEngine(ctx).Where(conditions).Find(&issues); err != nil {
|
||||
return fmt.Errorf("find issue: %w", err)
|
||||
}
|
||||
for _, issue := range issues {
|
||||
issuesMap[fmt.Sprintf("%d-%d", issue.RepoID, issue.Index)] = issue
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
if !action.IsIssueEvent() {
|
||||
continue
|
||||
}
|
||||
if index := action.getIssueIndex(); index > 0 {
|
||||
if issue, ok := issuesMap[fmt.Sprintf("%d-%d", action.RepoID, index)]; ok {
|
||||
action.Issue = issue
|
||||
action.Issue.Repo = action.Repo
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFeeds returns actions according to the provided options
|
||||
func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, int64, error) {
|
||||
if opts.RequestedUser == nil && opts.RequestedTeam == nil && opts.RequestedRepo == nil {
|
||||
return nil, 0, errors.New("need at least one of these filters: RequestedUser, RequestedTeam, RequestedRepo")
|
||||
}
|
||||
|
||||
var err error
|
||||
var cond builder.Cond
|
||||
// if the actor is the requested user or is an administrator, we can skip the ActivityQueryCondition
|
||||
if opts.Actor != nil && opts.RequestedUser != nil && (opts.Actor.IsAdmin || opts.Actor.ID == opts.RequestedUser.ID) {
|
||||
cond = builder.Eq{
|
||||
"user_id": opts.RequestedUser.ID,
|
||||
}.And(
|
||||
FeedDateCond(opts),
|
||||
)
|
||||
|
||||
if !opts.IncludeDeleted {
|
||||
cond = cond.And(builder.Eq{"is_deleted": false})
|
||||
}
|
||||
|
||||
if !opts.IncludePrivate {
|
||||
cond = cond.And(builder.Eq{"is_private": false})
|
||||
}
|
||||
if opts.OnlyPerformedBy {
|
||||
cond = cond.And(builder.Eq{"act_user_id": opts.RequestedUser.ID})
|
||||
}
|
||||
} else {
|
||||
cond, err = ActivityQueryCondition(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
actions := make([]*Action, 0, opts.PageSize)
|
||||
var count int64
|
||||
opts.SetDefaultValues()
|
||||
|
||||
if opts.Page < 10 { // TODO: why it's 10 but other values? It's an experience value.
|
||||
sess := db.GetEngine(ctx).Where(cond)
|
||||
sess = db.SetSessionPagination(sess, &opts)
|
||||
|
||||
if opts.DontCount {
|
||||
err = sess.Desc("`action`.created_unix").Find(&actions)
|
||||
} else {
|
||||
count, err = sess.Desc("`action`.created_unix").FindAndCount(&actions)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("FindAndCount: %w", err)
|
||||
}
|
||||
} else {
|
||||
// First, only query which IDs are necessary, and only then query all actions to speed up the overall query
|
||||
sess := db.GetEngine(ctx).Where(cond).Select("`action`.id")
|
||||
sess = db.SetSessionPagination(sess, &opts)
|
||||
|
||||
actionIDs := make([]int64, 0, opts.PageSize)
|
||||
if err := sess.Table("action").Desc("`action`.created_unix").Find(&actionIDs); err != nil {
|
||||
return nil, 0, fmt.Errorf("Find(actionsIDs): %w", err)
|
||||
}
|
||||
|
||||
if !opts.DontCount {
|
||||
count, err = db.GetEngine(ctx).Where(cond).
|
||||
Table("action").
|
||||
Cols("`action`.id").Count()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("Count: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.GetEngine(ctx).In("`action`.id", actionIDs).Desc("`action`.created_unix").Find(&actions); err != nil {
|
||||
return nil, 0, fmt.Errorf("Find: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ActionList(actions).LoadAttributes(ctx); err != nil {
|
||||
return nil, 0, fmt.Errorf("LoadAttributes: %w", err)
|
||||
}
|
||||
|
||||
return actions, count, nil
|
||||
}
|
||||
|
||||
func CountUserFeeds(ctx context.Context, userID int64) (int64, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", userID).
|
||||
And("is_deleted = ?", false).
|
||||
Count(&Action{})
|
||||
}
|
||||
159
models/activities/action_test.go
Normal file
159
models/activities/action_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
activities_model "code.gitea.io/gitea/models/activities"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issue_model "code.gitea.io/gitea/models/issues"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/test"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAction_GetRepoPath(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||
owner := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
|
||||
action := &activities_model.Action{RepoID: repo.ID}
|
||||
assert.Equal(t, path.Join(owner.Name, repo.Name), action.GetRepoPath(t.Context()))
|
||||
}
|
||||
|
||||
func TestAction_GetRepoLink(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
|
||||
owner := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
|
||||
comment := unittest.AssertExistsAndLoadBean(t, &issue_model.Comment{ID: 2})
|
||||
action := &activities_model.Action{RepoID: repo.ID, CommentID: comment.ID}
|
||||
defer test.MockVariableValue(&setting.AppURL, "https://try.gitea.io/suburl/")()
|
||||
defer test.MockVariableValue(&setting.AppSubURL, "/suburl")()
|
||||
expected := path.Join(setting.AppSubURL, owner.Name, repo.Name)
|
||||
assert.Equal(t, expected, action.GetRepoLink(t.Context()))
|
||||
assert.Equal(t, repo.HTMLURL(), action.GetRepoAbsoluteLink(t.Context()))
|
||||
assert.Equal(t, comment.HTMLURL(t.Context()), action.GetCommentHTMLURL(t.Context()))
|
||||
}
|
||||
|
||||
func TestActivityReadable(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
user *user_model.User
|
||||
doer *user_model.User
|
||||
result bool
|
||||
}{{
|
||||
desc: "user should see own activity",
|
||||
user: &user_model.User{ID: 1},
|
||||
doer: &user_model.User{ID: 1},
|
||||
result: true,
|
||||
}, {
|
||||
desc: "anon should see activity if public",
|
||||
user: &user_model.User{ID: 1},
|
||||
result: true,
|
||||
}, {
|
||||
desc: "anon should NOT see activity",
|
||||
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
|
||||
result: false,
|
||||
}, {
|
||||
desc: "user should see own activity if private too",
|
||||
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
|
||||
doer: &user_model.User{ID: 1},
|
||||
result: true,
|
||||
}, {
|
||||
desc: "other user should NOT see activity",
|
||||
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
|
||||
doer: &user_model.User{ID: 2},
|
||||
result: false,
|
||||
}, {
|
||||
desc: "admin should see activity",
|
||||
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
|
||||
doer: &user_model.User{ID: 2, IsAdmin: true},
|
||||
result: true,
|
||||
}}
|
||||
for _, test := range tt {
|
||||
assert.Equal(t, test.result, activities_model.ActivityReadable(test.user, test.doer), test.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsistencyUpdateAction(t *testing.T) {
|
||||
if !setting.Database.Type.IsSQLite3() {
|
||||
t.Skip("Test is only for SQLite database.")
|
||||
}
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
id := 8
|
||||
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
|
||||
ID: int64(id),
|
||||
})
|
||||
_, err := db.GetEngine(t.Context()).Exec(`UPDATE action SET created_unix = '' WHERE id = ?`, id)
|
||||
assert.NoError(t, err)
|
||||
actions := make([]*activities_model.Action, 0, 1)
|
||||
//
|
||||
// XORM returns an error when created_unix is a string
|
||||
//
|
||||
err = db.GetEngine(t.Context()).Where("id = ?", id).Find(&actions)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "type string to a int64: invalid syntax")
|
||||
}
|
||||
//
|
||||
// Get rid of incorrectly set created_unix
|
||||
//
|
||||
count, err := activities_model.CountActionCreatedUnixString(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, count)
|
||||
count, err = activities_model.FixActionCreatedUnixString(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, count)
|
||||
|
||||
count, err = activities_model.CountActionCreatedUnixString(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, count)
|
||||
count, err = activities_model.FixActionCreatedUnixString(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, count)
|
||||
|
||||
//
|
||||
// XORM must be happy now
|
||||
//
|
||||
assert.NoError(t, db.GetEngine(t.Context()).Where("id = ?", id).Find(&actions))
|
||||
unittest.CheckConsistencyFor(t, &activities_model.Action{})
|
||||
}
|
||||
|
||||
func TestDeleteIssueActions(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
// load an issue
|
||||
issue := unittest.AssertExistsAndLoadBean(t, &issue_model.Issue{ID: 4})
|
||||
assert.NotEqual(t, issue.ID, issue.Index) // it needs to use different ID/Index to test the DeleteIssueActions to delete some actions by IssueIndex
|
||||
|
||||
// insert a comment
|
||||
err := db.Insert(t.Context(), &issue_model.Comment{Type: issue_model.CommentTypeComment, IssueID: issue.ID})
|
||||
assert.NoError(t, err)
|
||||
comment := unittest.AssertExistsAndLoadBean(t, &issue_model.Comment{Type: issue_model.CommentTypeComment, IssueID: issue.ID})
|
||||
|
||||
// truncate action table and insert some actions
|
||||
err = db.TruncateBeans(t.Context(), &activities_model.Action{})
|
||||
assert.NoError(t, err)
|
||||
err = db.Insert(t.Context(), &activities_model.Action{
|
||||
OpType: activities_model.ActionCommentIssue,
|
||||
CommentID: comment.ID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
err = db.Insert(t.Context(), &activities_model.Action{
|
||||
OpType: activities_model.ActionCreateIssue,
|
||||
RepoID: issue.RepoID,
|
||||
Content: fmt.Sprintf("%d|content...", issue.Index),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// assert that the actions exist, then delete them
|
||||
unittest.AssertCount(t, &activities_model.Action{}, 2)
|
||||
assert.NoError(t, activities_model.DeleteIssueActions(t.Context(), issue.RepoID, issue.ID, issue.Index))
|
||||
unittest.AssertCount(t, &activities_model.Action{}, 0)
|
||||
}
|
||||
17
models/activities/main_test.go
Normal file
17
models/activities/main_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
_ "code.gitea.io/gitea/models"
|
||||
_ "code.gitea.io/gitea/models/actions"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
418
models/activities/notification.go
Normal file
418
models/activities/notification.go
Normal file
@@ -0,0 +1,418 @@
|
||||
// Copyright 2016 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/models/organization"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type (
|
||||
// NotificationStatus is the status of the notification (read or unread)
|
||||
NotificationStatus uint8
|
||||
// NotificationSource is the source of the notification (issue, PR, commit, etc)
|
||||
NotificationSource uint8
|
||||
)
|
||||
|
||||
const (
|
||||
// NotificationStatusUnread represents an unread notification
|
||||
NotificationStatusUnread NotificationStatus = iota + 1
|
||||
// NotificationStatusRead represents a read notification
|
||||
NotificationStatusRead
|
||||
// NotificationStatusPinned represents a pinned notification
|
||||
NotificationStatusPinned
|
||||
)
|
||||
|
||||
const (
|
||||
// NotificationSourceIssue is a notification of an issue
|
||||
NotificationSourceIssue NotificationSource = iota + 1
|
||||
// NotificationSourcePullRequest is a notification of a pull request
|
||||
NotificationSourcePullRequest
|
||||
// NotificationSourceCommit is a notification of a commit
|
||||
NotificationSourceCommit
|
||||
// NotificationSourceRepository is a notification for a repository
|
||||
NotificationSourceRepository
|
||||
)
|
||||
|
||||
// Notification represents a notification
|
||||
type Notification struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UserID int64 `xorm:"NOT NULL"`
|
||||
RepoID int64 `xorm:"NOT NULL"`
|
||||
|
||||
Status NotificationStatus `xorm:"SMALLINT NOT NULL"`
|
||||
Source NotificationSource `xorm:"SMALLINT NOT NULL"`
|
||||
|
||||
IssueID int64 `xorm:"NOT NULL"`
|
||||
CommitID string
|
||||
CommentID int64
|
||||
|
||||
UpdatedBy int64 `xorm:"NOT NULL"`
|
||||
|
||||
Issue *issues_model.Issue `xorm:"-"`
|
||||
Repository *repo_model.Repository `xorm:"-"`
|
||||
Comment *issues_model.Comment `xorm:"-"`
|
||||
User *user_model.User `xorm:"-"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created NOT NULL"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated NOT NULL"`
|
||||
}
|
||||
|
||||
// TableIndices implements xorm's TableIndices interface
|
||||
func (n *Notification) TableIndices() []*schemas.Index {
|
||||
indices := make([]*schemas.Index, 0, 8)
|
||||
usuuIndex := schemas.NewIndex("u_s_uu", schemas.IndexType)
|
||||
usuuIndex.AddColumn("user_id", "status", "updated_unix")
|
||||
indices = append(indices, usuuIndex)
|
||||
|
||||
// Add the individual indices that were previously defined in struct tags
|
||||
userIDIndex := schemas.NewIndex("idx_notification_user_id", schemas.IndexType)
|
||||
userIDIndex.AddColumn("user_id")
|
||||
indices = append(indices, userIDIndex)
|
||||
|
||||
repoIDIndex := schemas.NewIndex("idx_notification_repo_id", schemas.IndexType)
|
||||
repoIDIndex.AddColumn("repo_id")
|
||||
indices = append(indices, repoIDIndex)
|
||||
|
||||
statusIndex := schemas.NewIndex("idx_notification_status", schemas.IndexType)
|
||||
statusIndex.AddColumn("status")
|
||||
indices = append(indices, statusIndex)
|
||||
|
||||
sourceIndex := schemas.NewIndex("idx_notification_source", schemas.IndexType)
|
||||
sourceIndex.AddColumn("source")
|
||||
indices = append(indices, sourceIndex)
|
||||
|
||||
issueIDIndex := schemas.NewIndex("idx_notification_issue_id", schemas.IndexType)
|
||||
issueIDIndex.AddColumn("issue_id")
|
||||
indices = append(indices, issueIDIndex)
|
||||
|
||||
commitIDIndex := schemas.NewIndex("idx_notification_commit_id", schemas.IndexType)
|
||||
commitIDIndex.AddColumn("commit_id")
|
||||
indices = append(indices, commitIDIndex)
|
||||
|
||||
updatedByIndex := schemas.NewIndex("idx_notification_updated_by", schemas.IndexType)
|
||||
updatedByIndex.AddColumn("updated_by")
|
||||
indices = append(indices, updatedByIndex)
|
||||
|
||||
return indices
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Notification))
|
||||
}
|
||||
|
||||
// CreateRepoTransferNotification creates notification for the user a repository was transferred to
|
||||
func CreateRepoTransferNotification(ctx context.Context, doer, newOwner *user_model.User, repo *repo_model.Repository) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
var notify []*Notification
|
||||
|
||||
if newOwner.IsOrganization() {
|
||||
users, err := organization.GetUsersWhoCanCreateOrgRepo(ctx, newOwner.ID)
|
||||
if err != nil || len(users) == 0 {
|
||||
return err
|
||||
}
|
||||
for i := range users {
|
||||
notify = append(notify, &Notification{
|
||||
UserID: i,
|
||||
RepoID: repo.ID,
|
||||
Status: NotificationStatusUnread,
|
||||
UpdatedBy: doer.ID,
|
||||
Source: NotificationSourceRepository,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
notify = []*Notification{{
|
||||
UserID: newOwner.ID,
|
||||
RepoID: repo.ID,
|
||||
Status: NotificationStatusUnread,
|
||||
UpdatedBy: doer.ID,
|
||||
Source: NotificationSourceRepository,
|
||||
}}
|
||||
}
|
||||
|
||||
return db.Insert(ctx, notify)
|
||||
})
|
||||
}
|
||||
|
||||
func createIssueNotification(ctx context.Context, userID int64, issue *issues_model.Issue, commentID, updatedByID int64) error {
|
||||
notification := &Notification{
|
||||
UserID: userID,
|
||||
RepoID: issue.RepoID,
|
||||
Status: NotificationStatusUnread,
|
||||
IssueID: issue.ID,
|
||||
CommentID: commentID,
|
||||
UpdatedBy: updatedByID,
|
||||
}
|
||||
|
||||
if issue.IsPull {
|
||||
notification.Source = NotificationSourcePullRequest
|
||||
} else {
|
||||
notification.Source = NotificationSourceIssue
|
||||
}
|
||||
|
||||
return db.Insert(ctx, notification)
|
||||
}
|
||||
|
||||
func updateIssueNotification(ctx context.Context, userID, issueID, commentID, updatedByID int64) error {
|
||||
notification, err := GetIssueNotification(ctx, userID, issueID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTICE: Only update comment id when the before notification on this issue is read, otherwise you may miss some old comments.
|
||||
// But we need update update_by so that the notification will be reorder
|
||||
var cols []string
|
||||
if notification.Status == NotificationStatusRead {
|
||||
notification.Status = NotificationStatusUnread
|
||||
notification.CommentID = commentID
|
||||
cols = []string{"status", "update_by", "comment_id"}
|
||||
} else {
|
||||
notification.UpdatedBy = updatedByID
|
||||
cols = []string{"update_by"}
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(notification.ID).Cols(cols...).Update(notification)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetIssueNotification return the notification about an issue
|
||||
func GetIssueNotification(ctx context.Context, userID, issueID int64) (*Notification, error) {
|
||||
notification := new(Notification)
|
||||
_, err := db.GetEngine(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
And("issue_id = ?", issueID).
|
||||
Get(notification)
|
||||
return notification, err
|
||||
}
|
||||
|
||||
// LoadAttributes load Repo Issue User and Comment if not loaded
|
||||
func (n *Notification) LoadAttributes(ctx context.Context) (err error) {
|
||||
if err = n.loadRepo(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = n.loadIssue(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = n.loadUser(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = n.loadComment(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *Notification) loadRepo(ctx context.Context) (err error) {
|
||||
if n.Repository == nil {
|
||||
n.Repository, err = repo_model.GetRepositoryByID(ctx, n.RepoID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getRepositoryByID [%d]: %w", n.RepoID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notification) loadIssue(ctx context.Context) (err error) {
|
||||
if n.Issue == nil && n.IssueID != 0 {
|
||||
n.Issue, err = issues_model.GetIssueByID(ctx, n.IssueID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getIssueByID [%d]: %w", n.IssueID, err)
|
||||
}
|
||||
return n.Issue.LoadAttributes(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notification) loadComment(ctx context.Context) (err error) {
|
||||
if n.Comment == nil && n.CommentID != 0 {
|
||||
n.Comment, err = issues_model.GetCommentByID(ctx, n.CommentID)
|
||||
if err != nil {
|
||||
if issues_model.IsErrCommentNotExist(err) {
|
||||
return issues_model.ErrCommentNotExist{
|
||||
ID: n.CommentID,
|
||||
IssueID: n.IssueID,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notification) loadUser(ctx context.Context) (err error) {
|
||||
if n.User == nil {
|
||||
n.User, err = user_model.GetUserByID(ctx, n.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getUserByID [%d]: %w", n.UserID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRepo returns the repo of the notification
|
||||
func (n *Notification) GetRepo(ctx context.Context) (*repo_model.Repository, error) {
|
||||
return n.Repository, n.loadRepo(ctx)
|
||||
}
|
||||
|
||||
// GetIssue returns the issue of the notification
|
||||
func (n *Notification) GetIssue(ctx context.Context) (*issues_model.Issue, error) {
|
||||
return n.Issue, n.loadIssue(ctx)
|
||||
}
|
||||
|
||||
// HTMLURL formats a URL-string to the notification
|
||||
func (n *Notification) HTMLURL(ctx context.Context) string {
|
||||
switch n.Source {
|
||||
case NotificationSourceIssue, NotificationSourcePullRequest:
|
||||
if n.Comment != nil {
|
||||
return n.Comment.HTMLURL(ctx)
|
||||
}
|
||||
return n.Issue.HTMLURL(ctx)
|
||||
case NotificationSourceCommit:
|
||||
return n.Repository.HTMLURL(ctx) + "/commit/" + url.PathEscape(n.CommitID)
|
||||
case NotificationSourceRepository:
|
||||
return n.Repository.HTMLURL(ctx)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Link formats a relative URL-string to the notification
|
||||
func (n *Notification) Link(ctx context.Context) string {
|
||||
switch n.Source {
|
||||
case NotificationSourceIssue, NotificationSourcePullRequest:
|
||||
if n.Comment != nil {
|
||||
return n.Comment.Link(ctx)
|
||||
}
|
||||
return n.Issue.Link()
|
||||
case NotificationSourceCommit:
|
||||
return n.Repository.Link() + "/commit/" + url.PathEscape(n.CommitID)
|
||||
case NotificationSourceRepository:
|
||||
return n.Repository.Link()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// APIURL formats a URL-string to the notification
|
||||
func (n *Notification) APIURL() string {
|
||||
return setting.AppURL + "api/v1/notifications/threads/" + strconv.FormatInt(n.ID, 10)
|
||||
}
|
||||
|
||||
func notificationExists(notifications []*Notification, issueID, userID int64) bool {
|
||||
for _, notification := range notifications {
|
||||
if notification.IssueID == issueID && notification.UserID == userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// UserIDCount is a simple coalition of UserID and Count
|
||||
type UserIDCount struct {
|
||||
UserID int64
|
||||
Count int64
|
||||
}
|
||||
|
||||
// GetUIDsAndNotificationCounts returns the unread counts for every user between the two provided times.
|
||||
// It must return all user IDs which appear during the period, including count=0 for users who have read all.
|
||||
func GetUIDsAndNotificationCounts(ctx context.Context, since, until timeutil.TimeStamp) ([]UserIDCount, error) {
|
||||
sql := `SELECT user_id, sum(case when status= ? then 1 else 0 end) AS count FROM notification ` +
|
||||
`WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` +
|
||||
`updated_unix < ?) GROUP BY user_id`
|
||||
var res []UserIDCount
|
||||
return res, db.GetEngine(ctx).SQL(sql, NotificationStatusUnread, since, until).Find(&res)
|
||||
}
|
||||
|
||||
// SetIssueReadBy sets issue to be read by given user.
|
||||
func SetIssueReadBy(ctx context.Context, issueID, userID int64) error {
|
||||
if err := issues_model.UpdateIssueUserByRead(ctx, userID, issueID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return setIssueNotificationStatusReadIfUnread(ctx, userID, issueID)
|
||||
}
|
||||
|
||||
func setIssueNotificationStatusReadIfUnread(ctx context.Context, userID, issueID int64) error {
|
||||
notification, err := GetIssueNotification(ctx, userID, issueID)
|
||||
// ignore if not exists
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notification.Status != NotificationStatusUnread {
|
||||
return nil
|
||||
}
|
||||
|
||||
notification.Status = NotificationStatusRead
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(notification.ID).Cols("status").Update(notification)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetRepoReadBy sets repo to be visited by given user.
|
||||
func SetRepoReadBy(ctx context.Context, userID, repoID int64) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Eq{
|
||||
"user_id": userID,
|
||||
"status": NotificationStatusUnread,
|
||||
"source": NotificationSourceRepository,
|
||||
"repo_id": repoID,
|
||||
}).Cols("status").Update(&Notification{Status: NotificationStatusRead})
|
||||
return err
|
||||
}
|
||||
|
||||
// SetNotificationStatus change the notification status
|
||||
func SetNotificationStatus(ctx context.Context, notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) {
|
||||
notification, err := GetNotificationByID(ctx, notificationID)
|
||||
if err != nil {
|
||||
return notification, err
|
||||
}
|
||||
|
||||
if notification.UserID != user.ID {
|
||||
return nil, fmt.Errorf("Can't change notification of another user: %d, %d", notification.UserID, user.ID)
|
||||
}
|
||||
|
||||
notification.Status = status
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(notificationID).Cols("status").Update(notification)
|
||||
return notification, err
|
||||
}
|
||||
|
||||
// GetNotificationByID return notification by ID
|
||||
func GetNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) {
|
||||
notification := new(Notification)
|
||||
ok, err := db.GetEngine(ctx).
|
||||
Where("id = ?", notificationID).
|
||||
Get(notification)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil, db.ErrNotExist{Resource: "notification", ID: notificationID}
|
||||
}
|
||||
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
// UpdateNotificationStatuses updates the statuses of all of a user's notifications that are of the currentStatus type to the desiredStatus
|
||||
func UpdateNotificationStatuses(ctx context.Context, user *user_model.User, currentStatus, desiredStatus NotificationStatus) error {
|
||||
n := &Notification{Status: desiredStatus, UpdatedBy: user.ID}
|
||||
_, err := db.GetEngine(ctx).
|
||||
Where("user_id = ? AND status = ?", user.ID, currentStatus).
|
||||
Cols("status", "updated_by", "updated_unix").
|
||||
Update(n)
|
||||
return err
|
||||
}
|
||||
479
models/activities/notification_list.go
Normal file
479
models/activities/notification_list.go
Normal file
@@ -0,0 +1,479 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
access_model "code.gitea.io/gitea/models/perm/access"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unit"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// FindNotificationOptions represent the filters for notifications. If an ID is 0 it will be ignored.
|
||||
type FindNotificationOptions struct {
|
||||
db.ListOptions
|
||||
UserID int64
|
||||
RepoID int64
|
||||
IssueID int64
|
||||
Status []NotificationStatus
|
||||
Source []NotificationSource
|
||||
UpdatedAfterUnix int64
|
||||
UpdatedBeforeUnix int64
|
||||
}
|
||||
|
||||
// ToCond will convert each condition into a xorm-Cond
|
||||
func (opts FindNotificationOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.UserID != 0 {
|
||||
cond = cond.And(builder.Eq{"notification.user_id": opts.UserID})
|
||||
}
|
||||
if opts.RepoID != 0 {
|
||||
cond = cond.And(builder.Eq{"notification.repo_id": opts.RepoID})
|
||||
}
|
||||
if opts.IssueID != 0 {
|
||||
cond = cond.And(builder.Eq{"notification.issue_id": opts.IssueID})
|
||||
}
|
||||
if len(opts.Status) > 0 {
|
||||
if len(opts.Status) == 1 {
|
||||
cond = cond.And(builder.Eq{"notification.status": opts.Status[0]})
|
||||
} else {
|
||||
cond = cond.And(builder.In("notification.status", opts.Status))
|
||||
}
|
||||
}
|
||||
if len(opts.Source) > 0 {
|
||||
cond = cond.And(builder.In("notification.source", opts.Source))
|
||||
}
|
||||
if opts.UpdatedAfterUnix != 0 {
|
||||
cond = cond.And(builder.Gte{"notification.updated_unix": opts.UpdatedAfterUnix})
|
||||
}
|
||||
if opts.UpdatedBeforeUnix != 0 {
|
||||
cond = cond.And(builder.Lte{"notification.updated_unix": opts.UpdatedBeforeUnix})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts FindNotificationOptions) ToOrders() string {
|
||||
return "notification.updated_unix DESC"
|
||||
}
|
||||
|
||||
// CreateOrUpdateIssueNotifications creates an issue notification
|
||||
// for each watcher, or updates it if already exists
|
||||
// receiverID > 0 just send to receiver, else send to all watcher
|
||||
func CreateOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
return createOrUpdateIssueNotifications(ctx, issueID, commentID, notificationAuthorID, receiverID)
|
||||
})
|
||||
}
|
||||
|
||||
func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error {
|
||||
// init
|
||||
var toNotify container.Set[int64]
|
||||
notifications, err := db.Find[Notification](ctx, FindNotificationOptions{
|
||||
IssueID: issueID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
issue, err := issues_model.GetIssueByID(ctx, issueID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if receiverID > 0 {
|
||||
toNotify = make(container.Set[int64], 1)
|
||||
toNotify.Add(receiverID)
|
||||
} else {
|
||||
toNotify = make(container.Set[int64], 32)
|
||||
issueWatches, err := issues_model.GetIssueWatchersIDs(ctx, issueID, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toNotify.AddMultiple(issueWatches...)
|
||||
if !(issue.IsPull && issues_model.HasWorkInProgressPrefix(issue.Title)) {
|
||||
repoWatches, err := repo_model.GetRepoWatchersIDs(ctx, issue.RepoID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toNotify.AddMultiple(repoWatches...)
|
||||
}
|
||||
issueParticipants, err := issue.GetParticipantIDsByIssue(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
toNotify.AddMultiple(issueParticipants...)
|
||||
|
||||
// dont notify user who cause notification
|
||||
delete(toNotify, notificationAuthorID)
|
||||
// explicit unwatch on issue
|
||||
issueUnWatches, err := issues_model.GetIssueWatchersIDs(ctx, issueID, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, id := range issueUnWatches {
|
||||
toNotify.Remove(id)
|
||||
}
|
||||
}
|
||||
|
||||
err = issue.LoadRepo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// notify
|
||||
for userID := range toNotify {
|
||||
issue.Repo.Units = nil
|
||||
user, err := user_model.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
if user_model.IsErrUserNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if issue.IsPull && !access_model.CheckRepoUnitUser(ctx, issue.Repo, user, unit.TypePullRequests) {
|
||||
continue
|
||||
}
|
||||
if !issue.IsPull && !access_model.CheckRepoUnitUser(ctx, issue.Repo, user, unit.TypeIssues) {
|
||||
continue
|
||||
}
|
||||
|
||||
if notificationExists(notifications, issue.ID, userID) {
|
||||
if err = updateIssueNotification(ctx, userID, issue.ID, commentID, notificationAuthorID); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err = createIssueNotification(ctx, userID, issue, commentID, notificationAuthorID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotificationList contains a list of notifications
|
||||
type NotificationList []*Notification
|
||||
|
||||
// LoadAttributes load Repo Issue User and Comment if not loaded
|
||||
func (nl NotificationList) LoadAttributes(ctx context.Context) error {
|
||||
if _, _, err := nl.LoadRepos(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nl.LoadIssues(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nl.LoadUsers(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nl.LoadComments(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nl NotificationList) getPendingRepoIDs() []int64 {
|
||||
return container.FilterSlice(nl, func(n *Notification) (int64, bool) {
|
||||
if n.Repository != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n.RepoID, true
|
||||
})
|
||||
}
|
||||
|
||||
// LoadRepos loads repositories from database
|
||||
func (nl NotificationList) LoadRepos(ctx context.Context) (repo_model.RepositoryList, []int, error) {
|
||||
if len(nl) == 0 {
|
||||
return repo_model.RepositoryList{}, []int{}, nil
|
||||
}
|
||||
|
||||
repoIDs := nl.getPendingRepoIDs()
|
||||
repos := make(map[int64]*repo_model.Repository, len(repoIDs))
|
||||
left := len(repoIDs)
|
||||
for left > 0 {
|
||||
limit := min(left, db.DefaultMaxInSize)
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", repoIDs[:limit]).
|
||||
Rows(new(repo_model.Repository))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var repo repo_model.Repository
|
||||
err = rows.Scan(&repo)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
repos[repo.ID] = &repo
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
left -= limit
|
||||
repoIDs = repoIDs[limit:]
|
||||
}
|
||||
|
||||
failed := []int{}
|
||||
|
||||
reposList := make(repo_model.RepositoryList, 0, len(repoIDs))
|
||||
for i, notification := range nl {
|
||||
if notification.Repository == nil {
|
||||
notification.Repository = repos[notification.RepoID]
|
||||
}
|
||||
if notification.Repository == nil {
|
||||
log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
|
||||
failed = append(failed, i)
|
||||
continue
|
||||
}
|
||||
var found bool
|
||||
for _, r := range reposList {
|
||||
if r.ID == notification.RepoID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
reposList = append(reposList, notification.Repository)
|
||||
}
|
||||
}
|
||||
return reposList, failed, nil
|
||||
}
|
||||
|
||||
func (nl NotificationList) getPendingIssueIDs() []int64 {
|
||||
ids := make(container.Set[int64], len(nl))
|
||||
for _, notification := range nl {
|
||||
if notification.Issue != nil {
|
||||
continue
|
||||
}
|
||||
ids.Add(notification.IssueID)
|
||||
}
|
||||
return ids.Values()
|
||||
}
|
||||
|
||||
// LoadIssues loads issues from database
|
||||
func (nl NotificationList) LoadIssues(ctx context.Context) ([]int, error) {
|
||||
if len(nl) == 0 {
|
||||
return []int{}, nil
|
||||
}
|
||||
|
||||
issueIDs := nl.getPendingIssueIDs()
|
||||
issues := make(map[int64]*issues_model.Issue, len(issueIDs))
|
||||
left := len(issueIDs)
|
||||
for left > 0 {
|
||||
limit := min(left, db.DefaultMaxInSize)
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", issueIDs[:limit]).
|
||||
Rows(new(issues_model.Issue))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var issue issues_model.Issue
|
||||
err = rows.Scan(&issue)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
issues[issue.ID] = &issue
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
left -= limit
|
||||
issueIDs = issueIDs[limit:]
|
||||
}
|
||||
|
||||
failures := []int{}
|
||||
|
||||
for i, notification := range nl {
|
||||
if notification.Issue == nil {
|
||||
notification.Issue = issues[notification.IssueID]
|
||||
if notification.Issue == nil {
|
||||
if notification.IssueID != 0 {
|
||||
log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
|
||||
failures = append(failures, i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
notification.Issue.Repo = notification.Repository
|
||||
}
|
||||
}
|
||||
return failures, nil
|
||||
}
|
||||
|
||||
// Without returns the notification list without the failures
|
||||
func (nl NotificationList) Without(failures []int) NotificationList {
|
||||
if len(failures) == 0 {
|
||||
return nl
|
||||
}
|
||||
remaining := make([]*Notification, 0, len(nl))
|
||||
last := -1
|
||||
var i int
|
||||
for _, i = range failures {
|
||||
remaining = append(remaining, nl[last+1:i]...)
|
||||
last = i
|
||||
}
|
||||
if len(nl) > i {
|
||||
remaining = append(remaining, nl[i+1:]...)
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (nl NotificationList) getPendingCommentIDs() []int64 {
|
||||
ids := make(container.Set[int64], len(nl))
|
||||
for _, notification := range nl {
|
||||
if notification.CommentID == 0 || notification.Comment != nil {
|
||||
continue
|
||||
}
|
||||
ids.Add(notification.CommentID)
|
||||
}
|
||||
return ids.Values()
|
||||
}
|
||||
|
||||
func (nl NotificationList) getUserIDs() []int64 {
|
||||
ids := make(container.Set[int64], len(nl))
|
||||
for _, notification := range nl {
|
||||
if notification.UserID == 0 || notification.User != nil {
|
||||
continue
|
||||
}
|
||||
ids.Add(notification.UserID)
|
||||
}
|
||||
return ids.Values()
|
||||
}
|
||||
|
||||
// LoadUsers loads users from database
|
||||
func (nl NotificationList) LoadUsers(ctx context.Context) ([]int, error) {
|
||||
if len(nl) == 0 {
|
||||
return []int{}, nil
|
||||
}
|
||||
|
||||
userIDs := nl.getUserIDs()
|
||||
users := make(map[int64]*user_model.User, len(userIDs))
|
||||
left := len(userIDs)
|
||||
for left > 0 {
|
||||
limit := min(left, db.DefaultMaxInSize)
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", userIDs[:limit]).
|
||||
Rows(new(user_model.User))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var user user_model.User
|
||||
err = rows.Scan(&user)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users[user.ID] = &user
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
left -= limit
|
||||
userIDs = userIDs[limit:]
|
||||
}
|
||||
|
||||
failures := []int{}
|
||||
for i, notification := range nl {
|
||||
if notification.UserID > 0 && notification.User == nil && users[notification.UserID] != nil {
|
||||
notification.User = users[notification.UserID]
|
||||
if notification.User == nil {
|
||||
log.Error("Notification[%d]: UserID[%d] failed to load", notification.ID, notification.UserID)
|
||||
failures = append(failures, i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return failures, nil
|
||||
}
|
||||
|
||||
// LoadComments loads comments from database
|
||||
func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) {
|
||||
if len(nl) == 0 {
|
||||
return []int{}, nil
|
||||
}
|
||||
|
||||
commentIDs := nl.getPendingCommentIDs()
|
||||
comments := make(map[int64]*issues_model.Comment, len(commentIDs))
|
||||
left := len(commentIDs)
|
||||
for left > 0 {
|
||||
limit := min(left, db.DefaultMaxInSize)
|
||||
rows, err := db.GetEngine(ctx).
|
||||
In("id", commentIDs[:limit]).
|
||||
Rows(new(issues_model.Comment))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var comment issues_model.Comment
|
||||
err = rows.Scan(&comment)
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
comments[comment.ID] = &comment
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
left -= limit
|
||||
commentIDs = commentIDs[limit:]
|
||||
}
|
||||
|
||||
failures := []int{}
|
||||
for i, notification := range nl {
|
||||
if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
|
||||
notification.Comment = comments[notification.CommentID]
|
||||
if notification.Comment == nil {
|
||||
log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
|
||||
failures = append(failures, i)
|
||||
continue
|
||||
}
|
||||
notification.Comment.Issue = notification.Issue
|
||||
}
|
||||
}
|
||||
return failures, nil
|
||||
}
|
||||
|
||||
// LoadIssuePullRequests loads all issues' pull requests if possible
|
||||
func (nl NotificationList) LoadIssuePullRequests(ctx context.Context) error {
|
||||
issues := make(map[int64]*issues_model.Issue, len(nl))
|
||||
for _, notification := range nl {
|
||||
if notification.Issue != nil && notification.Issue.IsPull && notification.Issue.PullRequest == nil {
|
||||
issues[notification.Issue.ID] = notification.Issue
|
||||
}
|
||||
}
|
||||
|
||||
if len(issues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pulls, err := issues_model.GetPullRequestByIssueIDs(ctx, util.KeysOfMap(issues))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, pull := range pulls {
|
||||
if issue := issues[pull.IssueID]; issue != nil {
|
||||
issue.PullRequest = pull
|
||||
issue.PullRequest.Issue = issue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
140
models/activities/notification_test.go
Normal file
140
models/activities/notification_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
activities_model "code.gitea.io/gitea/models/activities"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateOrUpdateIssueNotifications(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 1})
|
||||
|
||||
assert.NoError(t, activities_model.CreateOrUpdateIssueNotifications(t.Context(), issue.ID, 0, 2, 0))
|
||||
|
||||
// User 9 is inactive, thus notifications for user 1 and 4 are created
|
||||
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: 1, IssueID: issue.ID})
|
||||
assert.Equal(t, activities_model.NotificationStatusUnread, notf.Status)
|
||||
unittest.CheckConsistencyFor(t, &issues_model.Issue{ID: issue.ID})
|
||||
|
||||
notf = unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: 4, IssueID: issue.ID})
|
||||
assert.Equal(t, activities_model.NotificationStatusUnread, notf.Status)
|
||||
}
|
||||
|
||||
func TestNotificationsForUser(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
||||
notfs, err := db.Find[activities_model.Notification](t.Context(), activities_model.FindNotificationOptions{
|
||||
UserID: user.ID,
|
||||
Status: []activities_model.NotificationStatus{
|
||||
activities_model.NotificationStatusRead,
|
||||
activities_model.NotificationStatusUnread,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, notfs, 3) {
|
||||
assert.EqualValues(t, 5, notfs[0].ID)
|
||||
assert.Equal(t, user.ID, notfs[0].UserID)
|
||||
assert.EqualValues(t, 4, notfs[1].ID)
|
||||
assert.Equal(t, user.ID, notfs[1].UserID)
|
||||
assert.EqualValues(t, 2, notfs[2].ID)
|
||||
assert.Equal(t, user.ID, notfs[2].UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotification_GetRepo(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{RepoID: 1})
|
||||
repo, err := notf.GetRepo(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, repo, notf.Repository)
|
||||
assert.Equal(t, notf.RepoID, repo.ID)
|
||||
}
|
||||
|
||||
func TestNotification_GetIssue(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{RepoID: 1})
|
||||
issue, err := notf.GetIssue(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, issue, notf.Issue)
|
||||
assert.Equal(t, notf.IssueID, issue.ID)
|
||||
}
|
||||
|
||||
func TestGetNotificationCount(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
|
||||
cnt, err := db.Count[activities_model.Notification](t.Context(), activities_model.FindNotificationOptions{
|
||||
UserID: user.ID,
|
||||
Status: []activities_model.NotificationStatus{
|
||||
activities_model.NotificationStatusRead,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, cnt)
|
||||
|
||||
cnt, err = db.Count[activities_model.Notification](t.Context(), activities_model.FindNotificationOptions{
|
||||
UserID: user.ID,
|
||||
Status: []activities_model.NotificationStatus{
|
||||
activities_model.NotificationStatusUnread,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
}
|
||||
|
||||
func TestSetNotificationStatus(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
||||
notf := unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead})
|
||||
_, err := activities_model.SetNotificationStatus(t.Context(), notf.ID, user, activities_model.NotificationStatusPinned)
|
||||
assert.NoError(t, err)
|
||||
unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{ID: notf.ID, Status: activities_model.NotificationStatusPinned})
|
||||
|
||||
_, err = activities_model.SetNotificationStatus(t.Context(), 1, user, activities_model.NotificationStatusRead)
|
||||
assert.Error(t, err)
|
||||
_, err = activities_model.SetNotificationStatus(t.Context(), unittest.NonexistentID, user, activities_model.NotificationStatusRead)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUpdateNotificationStatuses(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
|
||||
notfUnread := unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusUnread})
|
||||
notfRead := unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead})
|
||||
notfPinned := unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusPinned})
|
||||
assert.NoError(t, activities_model.UpdateNotificationStatuses(t.Context(), user, activities_model.NotificationStatusUnread, activities_model.NotificationStatusRead))
|
||||
unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{ID: notfUnread.ID, Status: activities_model.NotificationStatusRead})
|
||||
unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{ID: notfRead.ID, Status: activities_model.NotificationStatusRead})
|
||||
unittest.AssertExistsAndLoadBean(t,
|
||||
&activities_model.Notification{ID: notfPinned.ID, Status: activities_model.NotificationStatusPinned})
|
||||
}
|
||||
|
||||
func TestSetIssueReadBy(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
|
||||
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 1})
|
||||
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
return activities_model.SetIssueReadBy(ctx, issue.ID, user.ID)
|
||||
}))
|
||||
|
||||
nt, err := activities_model.GetIssueNotification(t.Context(), user.ID, issue.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, activities_model.NotificationStatusRead, nt.Status)
|
||||
}
|
||||
392
models/activities/repo_activity.go
Normal file
392
models/activities/repo_activity.go
Normal file
@@ -0,0 +1,392 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/git"
|
||||
"code.gitea.io/gitea/modules/gitrepo"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// ActivityAuthorData represents statistical git commit count data
|
||||
type ActivityAuthorData struct {
|
||||
Name string `json:"name"`
|
||||
Login string `json:"login"`
|
||||
AvatarLink string `json:"avatar_link"`
|
||||
HomeLink string `json:"home_link"`
|
||||
Commits int64 `json:"commits"`
|
||||
}
|
||||
|
||||
// ActivityStats represents issue and pull request information.
|
||||
type ActivityStats struct {
|
||||
OpenedPRs issues_model.PullRequestList
|
||||
OpenedPRAuthorCount int64
|
||||
MergedPRs issues_model.PullRequestList
|
||||
MergedPRAuthorCount int64
|
||||
ActiveIssues issues_model.IssueList
|
||||
OpenedIssues issues_model.IssueList
|
||||
OpenedIssueAuthorCount int64
|
||||
ClosedIssues issues_model.IssueList
|
||||
ClosedIssueAuthorCount int64
|
||||
UnresolvedIssues issues_model.IssueList
|
||||
PublishedReleases []*repo_model.Release
|
||||
PublishedReleaseAuthorCount int64
|
||||
Code *git.CodeActivityStats
|
||||
}
|
||||
|
||||
// GetActivityStats return stats for repository at given time range
|
||||
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
|
||||
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
|
||||
if releases {
|
||||
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillReleases: %w", err)
|
||||
}
|
||||
}
|
||||
if prs {
|
||||
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillPullRequests: %w", err)
|
||||
}
|
||||
}
|
||||
if issues {
|
||||
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillIssues: %w", err)
|
||||
}
|
||||
}
|
||||
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
|
||||
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
|
||||
}
|
||||
if code {
|
||||
gitRepo, closer, err := gitrepo.RepositoryFromContextOrOpen(ctx, repo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OpenRepository: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
|
||||
code, err := gitRepo.GetCodeActivityStats(timeFrom, repo.DefaultBranch)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("FillFromGit: %w", err)
|
||||
}
|
||||
stats.Code = code
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetActivityStatsTopAuthors returns top author stats for git commits for all branches
|
||||
func GetActivityStatsTopAuthors(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, count int) ([]*ActivityAuthorData, error) {
|
||||
gitRepo, closer, err := gitrepo.RepositoryFromContextOrOpen(ctx, repo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OpenRepository: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
|
||||
code, err := gitRepo.GetCodeActivityStats(timeFrom, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("FillFromGit: %w", err)
|
||||
}
|
||||
if code.Authors == nil {
|
||||
return nil, nil
|
||||
}
|
||||
users := make(map[int64]*ActivityAuthorData)
|
||||
var unknownUserID int64
|
||||
unknownUserAvatarLink := user_model.NewGhostUser().AvatarLink(ctx)
|
||||
for _, v := range code.Authors {
|
||||
if len(v.Email) == 0 {
|
||||
continue
|
||||
}
|
||||
u, err := user_model.GetUserByEmail(ctx, v.Email)
|
||||
if u == nil || user_model.IsErrUserNotExist(err) {
|
||||
unknownUserID--
|
||||
users[unknownUserID] = &ActivityAuthorData{
|
||||
Name: v.Name,
|
||||
AvatarLink: unknownUserAvatarLink,
|
||||
Commits: v.Commits,
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user, ok := users[u.ID]; !ok {
|
||||
users[u.ID] = &ActivityAuthorData{
|
||||
Name: u.DisplayName(),
|
||||
Login: u.LowerName,
|
||||
AvatarLink: u.AvatarLink(ctx),
|
||||
HomeLink: u.HomeLink(),
|
||||
Commits: v.Commits,
|
||||
}
|
||||
} else {
|
||||
user.Commits += v.Commits
|
||||
}
|
||||
}
|
||||
v := make([]*ActivityAuthorData, 0, len(users))
|
||||
for _, u := range users {
|
||||
v = append(v, u)
|
||||
}
|
||||
|
||||
sort.Slice(v, func(i, j int) bool {
|
||||
return v[i].Commits > v[j].Commits
|
||||
})
|
||||
|
||||
cnt := min(count, len(v))
|
||||
|
||||
return v[:cnt], nil
|
||||
}
|
||||
|
||||
// ActivePRCount returns total active pull request count
|
||||
func (stats *ActivityStats) ActivePRCount() int {
|
||||
return stats.OpenedPRCount() + stats.MergedPRCount()
|
||||
}
|
||||
|
||||
// OpenedPRCount returns opened pull request count
|
||||
func (stats *ActivityStats) OpenedPRCount() int {
|
||||
return len(stats.OpenedPRs)
|
||||
}
|
||||
|
||||
// OpenedPRPerc returns opened pull request percents from total active
|
||||
func (stats *ActivityStats) OpenedPRPerc() int {
|
||||
return int(float32(stats.OpenedPRCount()) / float32(stats.ActivePRCount()) * 100.0)
|
||||
}
|
||||
|
||||
// MergedPRCount returns merged pull request count
|
||||
func (stats *ActivityStats) MergedPRCount() int {
|
||||
return len(stats.MergedPRs)
|
||||
}
|
||||
|
||||
// MergedPRPerc returns merged pull request percent from total active
|
||||
func (stats *ActivityStats) MergedPRPerc() int {
|
||||
return int(float32(stats.MergedPRCount()) / float32(stats.ActivePRCount()) * 100.0)
|
||||
}
|
||||
|
||||
// ActiveIssueCount returns total active issue count
|
||||
func (stats *ActivityStats) ActiveIssueCount() int {
|
||||
return len(stats.ActiveIssues)
|
||||
}
|
||||
|
||||
// OpenedIssueCount returns open issue count
|
||||
func (stats *ActivityStats) OpenedIssueCount() int {
|
||||
return len(stats.OpenedIssues)
|
||||
}
|
||||
|
||||
// OpenedIssuePerc returns open issue count percent from total active
|
||||
func (stats *ActivityStats) OpenedIssuePerc() int {
|
||||
return int(float32(stats.OpenedIssueCount()) / float32(stats.ActiveIssueCount()) * 100.0)
|
||||
}
|
||||
|
||||
// ClosedIssueCount returns closed issue count
|
||||
func (stats *ActivityStats) ClosedIssueCount() int {
|
||||
return len(stats.ClosedIssues)
|
||||
}
|
||||
|
||||
// ClosedIssuePerc returns closed issue count percent from total active
|
||||
func (stats *ActivityStats) ClosedIssuePerc() int {
|
||||
return int(float32(stats.ClosedIssueCount()) / float32(stats.ActiveIssueCount()) * 100.0)
|
||||
}
|
||||
|
||||
// UnresolvedIssueCount returns unresolved issue and pull request count
|
||||
func (stats *ActivityStats) UnresolvedIssueCount() int {
|
||||
return len(stats.UnresolvedIssues)
|
||||
}
|
||||
|
||||
// PublishedReleaseCount returns published release count
|
||||
func (stats *ActivityStats) PublishedReleaseCount() int {
|
||||
return len(stats.PublishedReleases)
|
||||
}
|
||||
|
||||
// FillPullRequests returns pull request information for activity page
|
||||
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Merged pull requests
|
||||
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
|
||||
sess.OrderBy("pull_request.merged_unix DESC")
|
||||
stats.MergedPRs = make(issues_model.PullRequestList, 0)
|
||||
if err = sess.Find(&stats.MergedPRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Merged pull request authors
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.MergedPRAuthorCount = count
|
||||
|
||||
// Opened pull requests
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
|
||||
sess.OrderBy("issue.created_unix ASC")
|
||||
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
|
||||
if err = sess.Find(&stats.OpenedPRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Opened pull request authors
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.OpenedPRAuthorCount = count
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
|
||||
Join("INNER", "issue", "pull_request.issue_id = issue.id")
|
||||
|
||||
if merged {
|
||||
sess.And("pull_request.has_merged = ?", true)
|
||||
sess.And("pull_request.merged_unix >= ?", fromTime.Unix())
|
||||
} else {
|
||||
sess.And("issue.is_closed = ?", false)
|
||||
sess.And("issue.created_unix >= ?", fromTime.Unix())
|
||||
}
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
// FillIssues returns issue information for activity page
|
||||
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Closed issues
|
||||
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
|
||||
sess.OrderBy("issue.closed_unix DESC")
|
||||
stats.ClosedIssues = make(issues_model.IssueList, 0)
|
||||
if err = sess.Find(&stats.ClosedIssues); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Closed issue authors
|
||||
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.ClosedIssueAuthorCount = count
|
||||
|
||||
// New issues
|
||||
sess = newlyCreatedIssues(ctx, repoID, fromTime)
|
||||
sess.OrderBy("issue.created_unix ASC")
|
||||
stats.OpenedIssues = make(issues_model.IssueList, 0)
|
||||
if err = sess.Find(&stats.OpenedIssues); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Active issues
|
||||
sess = activeIssues(ctx, repoID, fromTime)
|
||||
sess.OrderBy("issue.created_unix ASC")
|
||||
stats.ActiveIssues = make(issues_model.IssueList, 0)
|
||||
if err = sess.Find(&stats.ActiveIssues); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Opened issue authors
|
||||
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.OpenedIssueAuthorCount = count
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FillUnresolvedIssues returns unresolved issue and pull request information for activity page
|
||||
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
|
||||
// Check if we need to select anything
|
||||
if !issues && !prs {
|
||||
return nil
|
||||
}
|
||||
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
|
||||
if !issues || !prs {
|
||||
sess.And("issue.is_pull = ?", prs)
|
||||
}
|
||||
sess.OrderBy("issue.updated_unix DESC")
|
||||
stats.UnresolvedIssues = make(issues_model.IssueList, 0)
|
||||
return sess.Find(&stats.UnresolvedIssues)
|
||||
}
|
||||
|
||||
func newlyCreatedIssues(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
|
||||
And("issue.is_pull = ?", false). // Retain the is_pull check to exclude pull requests
|
||||
And("issue.created_unix >= ?", fromTime.Unix()) // Include all issues created after fromTime
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
func activeIssues(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
|
||||
And("issue.is_pull = ?", false).
|
||||
And(builder.Or(
|
||||
builder.Gte{"issue.created_unix": fromTime.Unix()},
|
||||
builder.Gte{"issue.closed_unix": fromTime.Unix()},
|
||||
))
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
|
||||
And("issue.is_closed = ?", closed)
|
||||
|
||||
if !unresolved {
|
||||
sess.And("issue.is_pull = ?", false)
|
||||
if closed {
|
||||
sess.And("issue.closed_unix >= ?", fromTime.Unix())
|
||||
} else {
|
||||
sess.And("issue.created_unix >= ?", fromTime.Unix())
|
||||
}
|
||||
} else {
|
||||
sess.And("issue.created_unix < ?", fromTime.Unix())
|
||||
sess.And("issue.updated_unix >= ?", fromTime.Unix())
|
||||
}
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
// FillReleases returns release information for activity page
|
||||
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Published releases list
|
||||
sess := releasesForActivityStatement(ctx, repoID, fromTime)
|
||||
sess.OrderBy("`release`.created_unix DESC")
|
||||
stats.PublishedReleases = make([]*repo_model.Release, 0)
|
||||
if err = sess.Find(&stats.PublishedReleases); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Published releases authors
|
||||
sess = releasesForActivityStatement(ctx, repoID, fromTime)
|
||||
if _, err = sess.Select("count(distinct `release`.publisher_id) as `count`").Table("release").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.PublishedReleaseAuthorCount = count
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
|
||||
return db.GetEngine(ctx).Where("`release`.repo_id = ?", repoID).
|
||||
And("`release`.is_draft = ?", false).
|
||||
And("`release`.created_unix >= ?", fromTime.Unix())
|
||||
}
|
||||
135
models/activities/statistic.go
Normal file
135
models/activities/statistic.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
asymkey_model "code.gitea.io/gitea/models/asymkey"
|
||||
"code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
git_model "code.gitea.io/gitea/models/git"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/models/organization"
|
||||
access_model "code.gitea.io/gitea/models/perm/access"
|
||||
project_model "code.gitea.io/gitea/models/project"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/models/webhook"
|
||||
"code.gitea.io/gitea/modules/optional"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/structs"
|
||||
)
|
||||
|
||||
// Statistic contains the database statistics
|
||||
type Statistic struct {
|
||||
Counter struct {
|
||||
UsersActive, UsersNotActive,
|
||||
Org, PublicKey,
|
||||
Repo, Watch, Star, Access,
|
||||
Issue, IssueClosed, IssueOpen,
|
||||
Comment, Oauth, Follow,
|
||||
Mirror, Release, AuthSource, Webhook,
|
||||
Milestone, Label, HookTask,
|
||||
Team, UpdateTask, Project,
|
||||
ProjectColumn, Attachment,
|
||||
Branches, Tags, CommitStatus int64
|
||||
IssueByLabel []IssueByLabelCount
|
||||
IssueByRepository []IssueByRepositoryCount
|
||||
}
|
||||
}
|
||||
|
||||
// IssueByLabelCount contains the number of issue group by label
|
||||
type IssueByLabelCount struct {
|
||||
Count int64
|
||||
Label string
|
||||
}
|
||||
|
||||
// IssueByRepositoryCount contains the number of issue group by repository
|
||||
type IssueByRepositoryCount struct {
|
||||
Count int64
|
||||
OwnerName string
|
||||
Repository string
|
||||
}
|
||||
|
||||
// GetStatistic returns the database statistics
|
||||
func GetStatistic(ctx context.Context) (stats Statistic) {
|
||||
e := db.GetEngine(ctx)
|
||||
|
||||
// Number of active users
|
||||
usersActiveOpts := user_model.CountUserFilter{
|
||||
IsActive: optional.Some(true),
|
||||
}
|
||||
stats.Counter.UsersActive = user_model.CountUsers(ctx, &usersActiveOpts)
|
||||
|
||||
// Number of inactive users
|
||||
usersNotActiveOpts := user_model.CountUserFilter{
|
||||
IsActive: optional.Some(false),
|
||||
}
|
||||
stats.Counter.UsersNotActive = user_model.CountUsers(ctx, &usersNotActiveOpts)
|
||||
|
||||
stats.Counter.Org, _ = db.Count[organization.Organization](ctx, organization.FindOrgOptions{IncludeVisibility: structs.VisibleTypePrivate})
|
||||
stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey))
|
||||
stats.Counter.Repo, _ = repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{})
|
||||
stats.Counter.Watch, _ = e.Count(new(repo_model.Watch))
|
||||
stats.Counter.Star, _ = e.Count(new(repo_model.Star))
|
||||
stats.Counter.Access, _ = e.Count(new(access_model.Access))
|
||||
stats.Counter.Branches, _ = e.Count(new(git_model.Branch))
|
||||
stats.Counter.Tags, _ = e.Where("is_draft=?", false).Count(new(repo_model.Release))
|
||||
stats.Counter.CommitStatus, _ = e.Count(new(git_model.CommitStatus))
|
||||
|
||||
type IssueCount struct {
|
||||
Count int64
|
||||
IsClosed bool
|
||||
}
|
||||
|
||||
if setting.Metrics.EnabledIssueByLabel {
|
||||
stats.Counter.IssueByLabel = []IssueByLabelCount{}
|
||||
|
||||
_ = e.Select("COUNT(*) AS count, l.name AS label").
|
||||
Join("LEFT", "label l", "l.id=il.label_id").
|
||||
Table("issue_label il").
|
||||
GroupBy("l.name").
|
||||
Find(&stats.Counter.IssueByLabel)
|
||||
}
|
||||
|
||||
if setting.Metrics.EnabledIssueByRepository {
|
||||
stats.Counter.IssueByRepository = []IssueByRepositoryCount{}
|
||||
|
||||
_ = e.Select("COUNT(*) AS count, r.owner_name, r.name AS repository").
|
||||
Join("LEFT", "repository r", "r.id=i.repo_id").
|
||||
Table("issue i").
|
||||
GroupBy("r.owner_name, r.name").
|
||||
Find(&stats.Counter.IssueByRepository)
|
||||
}
|
||||
|
||||
var issueCounts []IssueCount
|
||||
|
||||
_ = e.Select("COUNT(*) AS count, is_closed").Table("issue").GroupBy("is_closed").Find(&issueCounts)
|
||||
for _, c := range issueCounts {
|
||||
if c.IsClosed {
|
||||
stats.Counter.IssueClosed = c.Count
|
||||
} else {
|
||||
stats.Counter.IssueOpen = c.Count
|
||||
}
|
||||
}
|
||||
|
||||
stats.Counter.Issue = stats.Counter.IssueClosed + stats.Counter.IssueOpen
|
||||
|
||||
stats.Counter.Comment, _ = e.Count(new(issues_model.Comment))
|
||||
stats.Counter.Oauth = 0
|
||||
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
|
||||
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
|
||||
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
|
||||
stats.Counter.AuthSource, _ = db.Count[auth.Source](ctx, auth.FindSourcesOptions{})
|
||||
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
|
||||
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
|
||||
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
|
||||
stats.Counter.HookTask, _ = e.Count(new(webhook.HookTask))
|
||||
stats.Counter.Team, _ = e.Count(new(organization.Team))
|
||||
stats.Counter.Attachment, _ = e.Count(new(repo_model.Attachment))
|
||||
stats.Counter.Project, _ = e.Count(new(project_model.Project))
|
||||
stats.Counter.ProjectColumn, _ = e.Count(new(project_model.Column))
|
||||
return stats
|
||||
}
|
||||
82
models/activities/user_heatmap.go
Normal file
82
models/activities/user_heatmap.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/organization"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
)
|
||||
|
||||
// UserHeatmapData represents the data needed to create a heatmap
|
||||
type UserHeatmapData struct {
|
||||
Timestamp timeutil.TimeStamp `json:"timestamp"`
|
||||
Contributions int64 `json:"contributions"`
|
||||
}
|
||||
|
||||
// GetUserHeatmapDataByUser returns an array of UserHeatmapData
|
||||
func GetUserHeatmapDataByUser(ctx context.Context, user, doer *user_model.User) ([]*UserHeatmapData, error) {
|
||||
return getUserHeatmapData(ctx, user, nil, doer)
|
||||
}
|
||||
|
||||
// GetUserHeatmapDataByUserTeam returns an array of UserHeatmapData
|
||||
func GetUserHeatmapDataByUserTeam(ctx context.Context, user *user_model.User, team *organization.Team, doer *user_model.User) ([]*UserHeatmapData, error) {
|
||||
return getUserHeatmapData(ctx, user, team, doer)
|
||||
}
|
||||
|
||||
func getUserHeatmapData(ctx context.Context, user *user_model.User, team *organization.Team, doer *user_model.User) ([]*UserHeatmapData, error) {
|
||||
hdata := make([]*UserHeatmapData, 0)
|
||||
|
||||
if !ActivityReadable(user, doer) {
|
||||
return hdata, nil
|
||||
}
|
||||
|
||||
// Group by 15 minute intervals which will allow the client to accurately shift the timestamp to their timezone.
|
||||
// The interval is based on the fact that there are timezones such as UTC +5:30 and UTC +12:45.
|
||||
groupBy := "created_unix / 900 * 900"
|
||||
groupByName := "timestamp" // We need this extra case because mssql doesn't allow grouping by alias
|
||||
switch {
|
||||
case setting.Database.Type.IsMySQL():
|
||||
groupBy = "created_unix DIV 900 * 900"
|
||||
case setting.Database.Type.IsMSSQL():
|
||||
groupByName = groupBy
|
||||
}
|
||||
|
||||
cond, err := ActivityQueryCondition(ctx, GetFeedsOptions{
|
||||
RequestedUser: user,
|
||||
RequestedTeam: team,
|
||||
Actor: doer,
|
||||
IncludePrivate: true, // don't filter by private, as we already filter by repo access
|
||||
IncludeDeleted: true,
|
||||
// * Heatmaps for individual users only include actions that the user themself did.
|
||||
// * For organizations actions by all users that were made in owned
|
||||
// repositories are counted.
|
||||
OnlyPerformedBy: !user.IsOrganization(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hdata, db.GetEngine(ctx).
|
||||
Select(groupBy+" AS timestamp, count(user_id) as contributions").
|
||||
Table("action").
|
||||
Where(cond).
|
||||
And("created_unix > ?", timeutil.TimeStampNow()-(366+7)*86400). // (366+7) days to include the first week for the heatmap
|
||||
GroupBy(groupByName).
|
||||
OrderBy("timestamp").
|
||||
Find(&hdata)
|
||||
}
|
||||
|
||||
// GetTotalContributionsInHeatmap returns the total number of contributions in a heatmap
|
||||
func GetTotalContributionsInHeatmap(hdata []*UserHeatmapData) int64 {
|
||||
var total int64
|
||||
for _, v := range hdata {
|
||||
total += v.Contributions
|
||||
}
|
||||
return total
|
||||
}
|
||||
97
models/activities/user_heatmap_test.go
Normal file
97
models/activities/user_heatmap_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package activities_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
activities_model "code.gitea.io/gitea/models/activities"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetUserHeatmapDataByUser(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
userID int64
|
||||
doerID int64
|
||||
CountResult int
|
||||
JSONResult string
|
||||
}{
|
||||
{
|
||||
"self looks at action in private repo",
|
||||
2, 2, 1, `[{"timestamp":1603227600,"contributions":1}]`,
|
||||
},
|
||||
{
|
||||
"admin looks at action in private repo",
|
||||
2, 1, 1, `[{"timestamp":1603227600,"contributions":1}]`,
|
||||
},
|
||||
{
|
||||
"other user looks at action in private repo",
|
||||
2, 3, 0, `[]`,
|
||||
},
|
||||
{
|
||||
"nobody looks at action in private repo",
|
||||
2, 0, 0, `[]`,
|
||||
},
|
||||
{
|
||||
"collaborator looks at action in private repo",
|
||||
16, 15, 1, `[{"timestamp":1603267200,"contributions":1}]`,
|
||||
},
|
||||
{
|
||||
"no action action not performed by target user",
|
||||
3, 3, 0, `[]`,
|
||||
},
|
||||
{
|
||||
"multiple actions performed with two grouped together",
|
||||
10, 10, 3, `[{"timestamp":1603009800,"contributions":1},{"timestamp":1603010700,"contributions":2}]`,
|
||||
},
|
||||
}
|
||||
// Prepare
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
// Mock time
|
||||
timeutil.MockSet(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC))
|
||||
defer timeutil.MockUnset()
|
||||
|
||||
for _, tc := range testCases {
|
||||
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: tc.userID})
|
||||
|
||||
var doer *user_model.User
|
||||
if tc.doerID != 0 {
|
||||
doer = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: tc.doerID})
|
||||
}
|
||||
|
||||
// get the action for comparison
|
||||
actions, count, err := activities_model.GetFeeds(t.Context(), activities_model.GetFeedsOptions{
|
||||
RequestedUser: user,
|
||||
Actor: doer,
|
||||
IncludePrivate: true,
|
||||
OnlyPerformedBy: true,
|
||||
IncludeDeleted: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get the heatmap and compare
|
||||
heatmap, err := activities_model.GetUserHeatmapDataByUser(t.Context(), user, doer)
|
||||
var contributions int
|
||||
for _, hm := range heatmap {
|
||||
contributions += int(hm.Contributions)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, actions, contributions, "invalid action count: did the test data became too old?")
|
||||
assert.Equal(t, count, int64(contributions))
|
||||
assert.Equal(t, tc.CountResult, contributions, "testcase '%s'", tc.desc)
|
||||
|
||||
// Test JSON rendering
|
||||
jsonData, err := json.Marshal(heatmap)
|
||||
assert.NoError(t, err)
|
||||
assert.JSONEq(t, tc.JSONResult, string(jsonData))
|
||||
}
|
||||
}
|
||||
212
models/admin/task.go
Normal file
212
models/admin/task.go
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright 2019 Gitea. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/migration"
|
||||
"code.gitea.io/gitea/modules/secret"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/structs"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// Task represents a task
|
||||
type Task struct {
|
||||
ID int64
|
||||
DoerID int64 `xorm:"index"` // operator
|
||||
Doer *user_model.User `xorm:"-"`
|
||||
OwnerID int64 `xorm:"index"` // repo owner id, when creating, the repoID maybe zero
|
||||
Owner *user_model.User `xorm:"-"`
|
||||
RepoID int64 `xorm:"index"`
|
||||
Repo *repo_model.Repository `xorm:"-"`
|
||||
Type structs.TaskType
|
||||
Status structs.TaskStatus `xorm:"index"`
|
||||
StartTime timeutil.TimeStamp
|
||||
EndTime timeutil.TimeStamp
|
||||
PayloadContent string `xorm:"TEXT"`
|
||||
Message string `xorm:"TEXT"` // if task failed, saved the error reason, it could be a JSON string of TranslatableMessage or a plain message
|
||||
Created timeutil.TimeStamp `xorm:"created"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Task))
|
||||
}
|
||||
|
||||
// TranslatableMessage represents JSON struct that can be translated with a Locale
|
||||
type TranslatableMessage struct {
|
||||
Format string
|
||||
Args []any `json:",omitempty"`
|
||||
}
|
||||
|
||||
// LoadRepo loads repository of the task
|
||||
func (task *Task) LoadRepo(ctx context.Context) error {
|
||||
if task.Repo != nil {
|
||||
return nil
|
||||
}
|
||||
var repo repo_model.Repository
|
||||
has, err := db.GetEngine(ctx).ID(task.RepoID).Get(&repo)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return repo_model.ErrRepoNotExist{
|
||||
ID: task.RepoID,
|
||||
}
|
||||
}
|
||||
task.Repo = &repo
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDoer loads do user
|
||||
func (task *Task) LoadDoer(ctx context.Context) error {
|
||||
if task.Doer != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var doer user_model.User
|
||||
has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user_model.ErrUserNotExist{
|
||||
UID: task.DoerID,
|
||||
}
|
||||
}
|
||||
task.Doer = &doer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadOwner loads owner user
|
||||
func (task *Task) LoadOwner(ctx context.Context) error {
|
||||
if task.Owner != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var owner user_model.User
|
||||
has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
return user_model.ErrUserNotExist{
|
||||
UID: task.OwnerID,
|
||||
}
|
||||
}
|
||||
task.Owner = &owner
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCols updates some columns
|
||||
func (task *Task) UpdateCols(ctx context.Context, cols ...string) error {
|
||||
_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task)
|
||||
return err
|
||||
}
|
||||
|
||||
// MigrateConfig returns task config when migrate repository
|
||||
func (task *Task) MigrateConfig() (*migration.MigrateOptions, error) {
|
||||
if task.Type == structs.TaskTypeMigrateRepo {
|
||||
var opts migration.MigrateOptions
|
||||
err := json.Unmarshal([]byte(task.PayloadContent), &opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// decrypt credentials
|
||||
if opts.CloneAddrEncrypted != "" {
|
||||
if opts.CloneAddr, err = secret.DecryptSecret(setting.SecretKey, opts.CloneAddrEncrypted); err != nil {
|
||||
log.Error("Unable to decrypt CloneAddr, maybe SECRET_KEY is wrong: %v", err)
|
||||
}
|
||||
}
|
||||
if opts.AuthPasswordEncrypted != "" {
|
||||
if opts.AuthPassword, err = secret.DecryptSecret(setting.SecretKey, opts.AuthPasswordEncrypted); err != nil {
|
||||
log.Error("Unable to decrypt AuthPassword, maybe SECRET_KEY is wrong: %v", err)
|
||||
}
|
||||
}
|
||||
if opts.AuthTokenEncrypted != "" {
|
||||
if opts.AuthToken, err = secret.DecryptSecret(setting.SecretKey, opts.AuthTokenEncrypted); err != nil {
|
||||
log.Error("Unable to decrypt AuthToken, maybe SECRET_KEY is wrong: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &opts, nil
|
||||
}
|
||||
return nil, fmt.Errorf("Task type is %s, not Migrate Repo", task.Type.Name())
|
||||
}
|
||||
|
||||
// ErrTaskDoesNotExist represents a "TaskDoesNotExist" kind of error.
|
||||
type ErrTaskDoesNotExist struct {
|
||||
ID int64
|
||||
RepoID int64
|
||||
Type structs.TaskType
|
||||
}
|
||||
|
||||
// IsErrTaskDoesNotExist checks if an error is a ErrTaskDoesNotExist.
|
||||
func IsErrTaskDoesNotExist(err error) bool {
|
||||
_, ok := err.(ErrTaskDoesNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrTaskDoesNotExist) Error() string {
|
||||
return fmt.Sprintf("task does not exist [id: %d, repo_id: %d, type: %d]",
|
||||
err.ID, err.RepoID, err.Type)
|
||||
}
|
||||
|
||||
func (err ErrTaskDoesNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// GetMigratingTask returns the migrating task by repo's id
|
||||
func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) {
|
||||
task := Task{
|
||||
RepoID: repoID,
|
||||
Type: structs.TaskTypeMigrateRepo,
|
||||
}
|
||||
has, err := db.GetEngine(ctx).Get(&task)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrTaskDoesNotExist{0, repoID, task.Type}
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// CreateTask creates a task on database
|
||||
func CreateTask(ctx context.Context, task *Task) error {
|
||||
return db.Insert(ctx, task)
|
||||
}
|
||||
|
||||
// FinishMigrateTask updates database when migrate task finished
|
||||
func FinishMigrateTask(ctx context.Context, task *Task) error {
|
||||
task.Status = structs.TaskStatusFinished
|
||||
task.EndTime = timeutil.TimeStampNow()
|
||||
|
||||
// delete credentials when we're done, they're a liability.
|
||||
conf, err := task.MigrateConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conf.AuthPassword = ""
|
||||
conf.AuthToken = ""
|
||||
conf.CloneAddr = util.SanitizeCredentialURLs(conf.CloneAddr)
|
||||
conf.AuthPasswordEncrypted = ""
|
||||
conf.AuthTokenEncrypted = ""
|
||||
conf.CloneAddrEncrypted = ""
|
||||
confBytes, err := json.Marshal(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
task.PayloadContent = string(confBytes)
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
|
||||
return err
|
||||
}
|
||||
319
models/asymkey/error.go
Normal file
319
models/asymkey/error.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// ErrKeyUnableVerify represents a "KeyUnableVerify" kind of error.
|
||||
type ErrKeyUnableVerify struct {
|
||||
Result string
|
||||
}
|
||||
|
||||
// IsErrKeyUnableVerify checks if an error is a ErrKeyUnableVerify.
|
||||
func IsErrKeyUnableVerify(err error) bool {
|
||||
_, ok := err.(ErrKeyUnableVerify)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrKeyUnableVerify) Error() string {
|
||||
return fmt.Sprintf("Unable to verify key content [result: %s]", err.Result)
|
||||
}
|
||||
|
||||
// ErrKeyIsPrivate is returned when the provided key is a private key not a public key
|
||||
var ErrKeyIsPrivate = util.ErrorWrap(util.ErrInvalidArgument, "the provided key is a private key")
|
||||
|
||||
// ErrKeyNotExist represents a "KeyNotExist" kind of error.
|
||||
type ErrKeyNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrKeyNotExist checks if an error is a ErrKeyNotExist.
|
||||
func IsErrKeyNotExist(err error) bool {
|
||||
_, ok := err.(ErrKeyNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrKeyNotExist) Error() string {
|
||||
return fmt.Sprintf("public key does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
func (err ErrKeyNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrKeyAlreadyExist represents a "KeyAlreadyExist" kind of error.
|
||||
type ErrKeyAlreadyExist struct {
|
||||
OwnerID int64
|
||||
Fingerprint string
|
||||
Content string
|
||||
}
|
||||
|
||||
// IsErrKeyAlreadyExist checks if an error is a ErrKeyAlreadyExist.
|
||||
func IsErrKeyAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrKeyAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrKeyAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("public key already exists [owner_id: %d, finger_print: %s, content: %s]",
|
||||
err.OwnerID, err.Fingerprint, err.Content)
|
||||
}
|
||||
|
||||
func (err ErrKeyAlreadyExist) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrKeyNameAlreadyUsed represents a "KeyNameAlreadyUsed" kind of error.
|
||||
type ErrKeyNameAlreadyUsed struct {
|
||||
OwnerID int64
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrKeyNameAlreadyUsed checks if an error is a ErrKeyNameAlreadyUsed.
|
||||
func IsErrKeyNameAlreadyUsed(err error) bool {
|
||||
_, ok := err.(ErrKeyNameAlreadyUsed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrKeyNameAlreadyUsed) Error() string {
|
||||
return fmt.Sprintf("public key already exists [owner_id: %d, name: %s]", err.OwnerID, err.Name)
|
||||
}
|
||||
|
||||
func (err ErrKeyNameAlreadyUsed) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrGPGNoEmailFound represents a "ErrGPGNoEmailFound" kind of error.
|
||||
type ErrGPGNoEmailFound struct {
|
||||
FailedEmails []string
|
||||
ID string
|
||||
}
|
||||
|
||||
// IsErrGPGNoEmailFound checks if an error is a ErrGPGNoEmailFound.
|
||||
func IsErrGPGNoEmailFound(err error) bool {
|
||||
_, ok := err.(ErrGPGNoEmailFound)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGNoEmailFound) Error() string {
|
||||
return fmt.Sprintf("none of the emails attached to the GPG key could be found: %v", err.FailedEmails)
|
||||
}
|
||||
|
||||
// ErrGPGInvalidTokenSignature represents a "ErrGPGInvalidTokenSignature" kind of error.
|
||||
type ErrGPGInvalidTokenSignature struct {
|
||||
Wrapped error
|
||||
ID string
|
||||
}
|
||||
|
||||
// IsErrGPGInvalidTokenSignature checks if an error is a ErrGPGInvalidTokenSignature.
|
||||
func IsErrGPGInvalidTokenSignature(err error) bool {
|
||||
_, ok := err.(ErrGPGInvalidTokenSignature)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGInvalidTokenSignature) Error() string {
|
||||
return "the provided signature does not sign the token with the provided key"
|
||||
}
|
||||
|
||||
// ErrGPGKeyParsing represents a "ErrGPGKeyParsing" kind of error.
|
||||
type ErrGPGKeyParsing struct {
|
||||
ParseError error
|
||||
}
|
||||
|
||||
// IsErrGPGKeyParsing checks if an error is a ErrGPGKeyParsing.
|
||||
func IsErrGPGKeyParsing(err error) bool {
|
||||
_, ok := err.(ErrGPGKeyParsing)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyParsing) Error() string {
|
||||
return "failed to parse gpg key " + err.ParseError.Error()
|
||||
}
|
||||
|
||||
// ErrGPGKeyNotExist represents a "GPGKeyNotExist" kind of error.
|
||||
type ErrGPGKeyNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrGPGKeyNotExist checks if an error is a ErrGPGKeyNotExist.
|
||||
func IsErrGPGKeyNotExist(err error) bool {
|
||||
_, ok := err.(ErrGPGKeyNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyNotExist) Error() string {
|
||||
return fmt.Sprintf("public gpg key does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrGPGKeyImportNotExist represents a "GPGKeyImportNotExist" kind of error.
|
||||
type ErrGPGKeyImportNotExist struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
// IsErrGPGKeyImportNotExist checks if an error is a ErrGPGKeyImportNotExist.
|
||||
func IsErrGPGKeyImportNotExist(err error) bool {
|
||||
_, ok := err.(ErrGPGKeyImportNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyImportNotExist) Error() string {
|
||||
return fmt.Sprintf("public gpg key import does not exist [id: %s]", err.ID)
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyImportNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrGPGKeyIDAlreadyUsed represents a "GPGKeyIDAlreadyUsed" kind of error.
|
||||
type ErrGPGKeyIDAlreadyUsed struct {
|
||||
KeyID string
|
||||
}
|
||||
|
||||
// IsErrGPGKeyIDAlreadyUsed checks if an error is a ErrKeyNameAlreadyUsed.
|
||||
func IsErrGPGKeyIDAlreadyUsed(err error) bool {
|
||||
_, ok := err.(ErrGPGKeyIDAlreadyUsed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyIDAlreadyUsed) Error() string {
|
||||
return fmt.Sprintf("public key already exists [key_id: %s]", err.KeyID)
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyIDAlreadyUsed) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrGPGKeyAccessDenied represents a "GPGKeyAccessDenied" kind of Error.
|
||||
type ErrGPGKeyAccessDenied struct {
|
||||
UserID int64
|
||||
KeyID int64
|
||||
}
|
||||
|
||||
// IsErrGPGKeyAccessDenied checks if an error is a ErrGPGKeyAccessDenied.
|
||||
func IsErrGPGKeyAccessDenied(err error) bool {
|
||||
_, ok := err.(ErrGPGKeyAccessDenied)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error pretty-prints an error of type ErrGPGKeyAccessDenied.
|
||||
func (err ErrGPGKeyAccessDenied) Error() string {
|
||||
return fmt.Sprintf("user does not have access to the key [user_id: %d, key_id: %d]",
|
||||
err.UserID, err.KeyID)
|
||||
}
|
||||
|
||||
func (err ErrGPGKeyAccessDenied) Unwrap() error {
|
||||
return util.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// ErrKeyAccessDenied represents a "KeyAccessDenied" kind of error.
|
||||
type ErrKeyAccessDenied struct {
|
||||
UserID int64
|
||||
RepoID int64
|
||||
KeyID int64
|
||||
Note string
|
||||
}
|
||||
|
||||
// IsErrKeyAccessDenied checks if an error is a ErrKeyAccessDenied.
|
||||
func IsErrKeyAccessDenied(err error) bool {
|
||||
_, ok := err.(ErrKeyAccessDenied)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrKeyAccessDenied) Error() string {
|
||||
return fmt.Sprintf("user does not have access to the key [user_id: %d, repo_id: %d, key_id: %d, note: %s]",
|
||||
err.UserID, err.RepoID, err.KeyID, err.Note)
|
||||
}
|
||||
|
||||
func (err ErrKeyAccessDenied) Unwrap() error {
|
||||
return util.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// ErrDeployKeyNotExist represents a "DeployKeyNotExist" kind of error.
|
||||
type ErrDeployKeyNotExist struct {
|
||||
ID int64
|
||||
KeyID int64
|
||||
RepoID int64
|
||||
}
|
||||
|
||||
// IsErrDeployKeyNotExist checks if an error is a ErrDeployKeyNotExist.
|
||||
func IsErrDeployKeyNotExist(err error) bool {
|
||||
_, ok := err.(ErrDeployKeyNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyNotExist) Error() string {
|
||||
return fmt.Sprintf("Deploy key does not exist [id: %d, key_id: %d, repo_id: %d]", err.ID, err.KeyID, err.RepoID)
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrDeployKeyAlreadyExist represents a "DeployKeyAlreadyExist" kind of error.
|
||||
type ErrDeployKeyAlreadyExist struct {
|
||||
KeyID int64
|
||||
RepoID int64
|
||||
}
|
||||
|
||||
// IsErrDeployKeyAlreadyExist checks if an error is a ErrDeployKeyAlreadyExist.
|
||||
func IsErrDeployKeyAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrDeployKeyAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("public key already exists [key_id: %d, repo_id: %d]", err.KeyID, err.RepoID)
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyAlreadyExist) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrDeployKeyNameAlreadyUsed represents a "DeployKeyNameAlreadyUsed" kind of error.
|
||||
type ErrDeployKeyNameAlreadyUsed struct {
|
||||
RepoID int64
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrDeployKeyNameAlreadyUsed checks if an error is a ErrDeployKeyNameAlreadyUsed.
|
||||
func IsErrDeployKeyNameAlreadyUsed(err error) bool {
|
||||
_, ok := err.(ErrDeployKeyNameAlreadyUsed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyNameAlreadyUsed) Error() string {
|
||||
return fmt.Sprintf("public key with name already exists [repo_id: %d, name: %s]", err.RepoID, err.Name)
|
||||
}
|
||||
|
||||
func (err ErrDeployKeyNameAlreadyUsed) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrSSHInvalidTokenSignature represents a "ErrSSHInvalidTokenSignature" kind of error.
|
||||
type ErrSSHInvalidTokenSignature struct {
|
||||
Wrapped error
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
// IsErrSSHInvalidTokenSignature checks if an error is a ErrSSHInvalidTokenSignature.
|
||||
func IsErrSSHInvalidTokenSignature(err error) bool {
|
||||
_, ok := err.(ErrSSHInvalidTokenSignature)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSSHInvalidTokenSignature) Error() string {
|
||||
return "the provided signature does not sign the token with the provided key"
|
||||
}
|
||||
|
||||
func (err ErrSSHInvalidTokenSignature) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
242
models/asymkey/gpg_key.go
Normal file
242
models/asymkey/gpg_key.go
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// GPGKey represents a GPG key.
|
||||
type GPGKey struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
OwnerID int64 `xorm:"INDEX NOT NULL"`
|
||||
KeyID string `xorm:"INDEX CHAR(16) NOT NULL"`
|
||||
PrimaryKeyID string `xorm:"CHAR(16)"`
|
||||
Content string `xorm:"MEDIUMTEXT NOT NULL"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
ExpiredUnix timeutil.TimeStamp
|
||||
AddedUnix timeutil.TimeStamp
|
||||
SubsKey []*GPGKey `xorm:"-"`
|
||||
Emails []*user_model.EmailAddress
|
||||
Verified bool `xorm:"NOT NULL DEFAULT false"`
|
||||
CanSign bool
|
||||
CanEncryptComms bool
|
||||
CanEncryptStorage bool
|
||||
CanCertify bool
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(GPGKey))
|
||||
}
|
||||
|
||||
// BeforeInsert will be invoked by XORM before inserting a record
|
||||
func (key *GPGKey) BeforeInsert() {
|
||||
key.AddedUnix = timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
func (key *GPGKey) LoadSubKeys(ctx context.Context) error {
|
||||
if err := db.GetEngine(ctx).Where("primary_key_id=?", key.KeyID).Find(&key.SubsKey); err != nil {
|
||||
return fmt.Errorf("find Sub GPGkeys[%s]: %v", key.KeyID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PaddedKeyID show KeyID padded to 16 characters
|
||||
func (key *GPGKey) PaddedKeyID() string {
|
||||
return PaddedKeyID(key.KeyID)
|
||||
}
|
||||
|
||||
// PaddedKeyID show KeyID padded to 16 characters
|
||||
func PaddedKeyID(keyID string) string {
|
||||
if len(keyID) > 15 {
|
||||
return keyID
|
||||
}
|
||||
zeros := "0000000000000000"
|
||||
return zeros[0:16-len(keyID)] + keyID
|
||||
}
|
||||
|
||||
type FindGPGKeyOptions struct {
|
||||
db.ListOptions
|
||||
OwnerID int64
|
||||
KeyID string
|
||||
IncludeSubKeys bool
|
||||
}
|
||||
|
||||
func (opts FindGPGKeyOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if !opts.IncludeSubKeys {
|
||||
cond = cond.And(builder.Eq{"primary_key_id": ""})
|
||||
}
|
||||
|
||||
if opts.OwnerID > 0 {
|
||||
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
}
|
||||
if opts.KeyID != "" {
|
||||
cond = cond.And(builder.Eq{"key_id": opts.KeyID})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func GetGPGKeyForUserByID(ctx context.Context, ownerID, keyID int64) (*GPGKey, error) {
|
||||
key := new(GPGKey)
|
||||
has, err := db.GetEngine(ctx).Where("id=? AND owner_id=?", keyID, ownerID).Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrGPGKeyNotExist{keyID}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GPGKeyToEntity retrieve the imported key and the traducted entity
|
||||
func GPGKeyToEntity(ctx context.Context, k *GPGKey) (*openpgp.Entity, error) {
|
||||
impKey, err := GetGPGImportByKeyID(ctx, k.KeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys, err := CheckArmoredGPGKeyString(impKey.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys[0], err
|
||||
}
|
||||
|
||||
// parseSubGPGKey parse a sub Key
|
||||
func parseSubGPGKey(ownerID int64, primaryID string, pubkey *packet.PublicKey, expiry time.Time) (*GPGKey, error) {
|
||||
content, err := Base64EncPubKey(pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GPGKey{
|
||||
OwnerID: ownerID,
|
||||
KeyID: pubkey.KeyIdString(),
|
||||
PrimaryKeyID: primaryID,
|
||||
Content: content,
|
||||
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
|
||||
ExpiredUnix: timeutil.TimeStamp(expiry.Unix()),
|
||||
CanSign: pubkey.CanSign(),
|
||||
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanCertify: pubkey.PubKeyAlgo.CanSign(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseGPGKey parse a PrimaryKey entity (primary key + subs keys + self-signature)
|
||||
func parseGPGKey(ctx context.Context, ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, error) {
|
||||
pubkey := e.PrimaryKey
|
||||
expiry := getExpiryTime(e)
|
||||
|
||||
// Parse Subkeys
|
||||
subkeys := make([]*GPGKey, len(e.Subkeys))
|
||||
for i, k := range e.Subkeys {
|
||||
subkeyExpiry := expiry
|
||||
if k.Sig.KeyLifetimeSecs != nil {
|
||||
subkeyExpiry = k.PublicKey.CreationTime.Add(time.Duration(*k.Sig.KeyLifetimeSecs) * time.Second)
|
||||
}
|
||||
subs, err := parseSubGPGKey(ownerID, pubkey.KeyIdString(), k.PublicKey, subkeyExpiry)
|
||||
if err != nil {
|
||||
return nil, ErrGPGKeyParsing{ParseError: err}
|
||||
}
|
||||
subkeys[i] = subs
|
||||
}
|
||||
|
||||
// Check emails
|
||||
userEmails, err := user_model.GetEmailAddresses(ctx, ownerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emails := make([]*user_model.EmailAddress, 0, len(e.Identities))
|
||||
for _, ident := range e.Identities {
|
||||
if ident.Revoked(time.Now()) {
|
||||
continue
|
||||
}
|
||||
email := strings.ToLower(strings.TrimSpace(ident.UserId.Email))
|
||||
for _, e := range userEmails {
|
||||
if e.IsActivated && e.LowerEmail == email {
|
||||
emails = append(emails, e)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !verified {
|
||||
// In the case no email as been found
|
||||
if len(emails) == 0 {
|
||||
failedEmails := make([]string, 0, len(e.Identities))
|
||||
for _, ident := range e.Identities {
|
||||
failedEmails = append(failedEmails, ident.UserId.Email)
|
||||
}
|
||||
return nil, ErrGPGNoEmailFound{failedEmails, e.PrimaryKey.KeyIdString()}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := Base64EncPubKey(pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GPGKey{
|
||||
OwnerID: ownerID,
|
||||
KeyID: pubkey.KeyIdString(),
|
||||
PrimaryKeyID: "",
|
||||
Content: content,
|
||||
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
|
||||
ExpiredUnix: timeutil.TimeStamp(expiry.Unix()),
|
||||
Emails: emails,
|
||||
SubsKey: subkeys,
|
||||
Verified: verified,
|
||||
CanSign: pubkey.CanSign(),
|
||||
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanCertify: pubkey.PubKeyAlgo.CanSign(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// deleteGPGKey does the actual key deletion
|
||||
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
|
||||
if keyID == "" {
|
||||
return 0, errors.New("empty KeyId forbidden") // Should never happen but just to be sure
|
||||
}
|
||||
// Delete imported key
|
||||
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
|
||||
}
|
||||
|
||||
// DeleteGPGKey deletes GPG key information in database.
|
||||
func DeleteGPGKey(ctx context.Context, doer *user_model.User, id int64) (err error) {
|
||||
key, err := GetGPGKeyForUserByID(ctx, doer.ID, id)
|
||||
if err != nil {
|
||||
if IsErrGPGKeyNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("GetPublicKeyByID: %w", err)
|
||||
}
|
||||
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
_, err = deleteGPGKey(ctx, key.KeyID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func FindGPGKeyWithSubKeys(ctx context.Context, keyID string) ([]*GPGKey, error) {
|
||||
return db.Find[GPGKey](ctx, FindGPGKeyOptions{
|
||||
KeyID: keyID,
|
||||
IncludeSubKeys: true,
|
||||
})
|
||||
}
|
||||
161
models/asymkey/gpg_key_add.go
Normal file
161
models/asymkey/gpg_key_add.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
)
|
||||
|
||||
// __________________ ________ ____ __.
|
||||
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
|
||||
// / \ ___ | ___/ \ ___ | <_/ __ < | |
|
||||
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
|
||||
// \______ /|____| \______ / |____|__ \___ > ____|
|
||||
// \/ \/ \/ \/\/
|
||||
// _____ .___ .___
|
||||
// / _ \ __| _/__| _/
|
||||
// / /_\ \ / __ |/ __ |
|
||||
// / | \/ /_/ / /_/ |
|
||||
// \____|__ /\____ \____ |
|
||||
// \/ \/ \/
|
||||
|
||||
// This file contains functions relating to adding GPG Keys
|
||||
|
||||
// addGPGKey add key, import and subkeys to database
|
||||
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
|
||||
// Add GPGKeyImport
|
||||
if err = db.Insert(ctx, &GPGKeyImport{
|
||||
KeyID: key.KeyID,
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG primary key.
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG subs key.
|
||||
for _, subkey := range key.SubsKey {
|
||||
if err := addGPGSubKey(ctx, subkey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addGPGSubKey add subkeys to database
|
||||
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
|
||||
// Save GPG primary key.
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
// Save GPG subs key.
|
||||
for _, subkey := range key.SubsKey {
|
||||
if err := addGPGSubKey(ctx, subkey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddGPGKey adds new public key to database.
|
||||
func AddGPGKey(ctx context.Context, ownerID int64, content, token, signature string) ([]*GPGKey, error) {
|
||||
ekeys, err := CheckArmoredGPGKeyString(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.WithTx2(ctx, func(ctx context.Context) ([]*GPGKey, error) {
|
||||
keys := make([]*GPGKey, 0, len(ekeys))
|
||||
|
||||
verified := false
|
||||
// Handle provided signature
|
||||
if signature != "" {
|
||||
signer, err := openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token), strings.NewReader(signature), nil)
|
||||
if err != nil {
|
||||
signer, err = openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token+"\n"), strings.NewReader(signature), nil)
|
||||
}
|
||||
if err != nil {
|
||||
signer, err = openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token+"\r\n"), strings.NewReader(signature), nil)
|
||||
}
|
||||
if err != nil {
|
||||
log.Debug("AddGPGKey CheckArmoredDetachedSignature failed: %v", err)
|
||||
return nil, ErrGPGInvalidTokenSignature{
|
||||
ID: ekeys[0].PrimaryKey.KeyIdString(),
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
ekeys = []*openpgp.Entity{signer}
|
||||
verified = true
|
||||
}
|
||||
|
||||
if len(ekeys) > 1 {
|
||||
id2key := map[string]*openpgp.Entity{}
|
||||
newEKeys := make([]*openpgp.Entity, 0, len(ekeys))
|
||||
for _, ekey := range ekeys {
|
||||
id := ekey.PrimaryKey.KeyIdString()
|
||||
if original, has := id2key[id]; has {
|
||||
// Coalesce this with the other one
|
||||
for _, subkey := range ekey.Subkeys {
|
||||
if subkey.PublicKey == nil {
|
||||
continue
|
||||
}
|
||||
found := false
|
||||
|
||||
for _, originalSubkey := range original.Subkeys {
|
||||
if originalSubkey.PublicKey == nil {
|
||||
continue
|
||||
}
|
||||
if originalSubkey.PublicKey.KeyId == subkey.PublicKey.KeyId {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
original.Subkeys = append(original.Subkeys, subkey)
|
||||
}
|
||||
}
|
||||
for name, identity := range ekey.Identities {
|
||||
if _, has := original.Identities[name]; has {
|
||||
continue
|
||||
}
|
||||
original.Identities[name] = identity
|
||||
}
|
||||
continue
|
||||
}
|
||||
id2key[id] = ekey
|
||||
newEKeys = append(newEKeys, ekey)
|
||||
}
|
||||
ekeys = newEKeys
|
||||
}
|
||||
|
||||
for _, ekey := range ekeys {
|
||||
// Key ID cannot be duplicated.
|
||||
has, err := db.GetEngine(ctx).Where("key_id=?", ekey.PrimaryKey.KeyIdString()).
|
||||
Get(new(GPGKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if has {
|
||||
return nil, ErrGPGKeyIDAlreadyUsed{ekey.PrimaryKey.KeyIdString()}
|
||||
}
|
||||
|
||||
key, err := parseGPGKey(ctx, ownerID, ekey, verified)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = addGPGKey(ctx, key, content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys, nil
|
||||
})
|
||||
}
|
||||
184
models/asymkey/gpg_key_commit_verification.go
Normal file
184
models/asymkey/gpg_key_commit_verification.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
)
|
||||
|
||||
// This file provides functions relating commit verification
|
||||
|
||||
// CommitVerification represents a commit validation of signature
|
||||
type CommitVerification struct {
|
||||
Verified bool
|
||||
Warning bool
|
||||
Reason string
|
||||
SigningUser *user_model.User // if Verified, then SigningUser is non-nil
|
||||
CommittingUser *user_model.User // if Verified, then CommittingUser is non-nil
|
||||
SigningEmail string
|
||||
SigningKey *GPGKey // FIXME: need to refactor it to a new name like "SigningGPGKey", it is also used in some templates
|
||||
SigningSSHKey *PublicKey
|
||||
TrustStatus string
|
||||
}
|
||||
|
||||
// SignCommit represents a commit with validation of signature.
|
||||
type SignCommit struct {
|
||||
Verification *CommitVerification
|
||||
*user_model.UserCommit
|
||||
}
|
||||
|
||||
const (
|
||||
// BadSignature is used as the reason when the signature has a KeyID that is in the db
|
||||
// but no key that has that ID verifies the signature. This is a suspicious failure.
|
||||
BadSignature = "gpg.error.probable_bad_signature"
|
||||
// BadDefaultSignature is used as the reason when the signature has a KeyID that matches the
|
||||
// default Key but is not verified by the default key. This is a suspicious failure.
|
||||
BadDefaultSignature = "gpg.error.probable_bad_default_signature"
|
||||
// NoKeyFound is used as the reason when no key can be found to verify the signature.
|
||||
NoKeyFound = "gpg.error.no_gpg_keys_found"
|
||||
)
|
||||
|
||||
func verifySign(s *packet.Signature, h hash.Hash, k *GPGKey) error {
|
||||
// Check if key can sign
|
||||
if !k.CanSign {
|
||||
return errors.New("key can not sign")
|
||||
}
|
||||
// Decode key
|
||||
pkey, err := base64DecPubKey(k.Content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pkey.VerifySignature(h, s)
|
||||
}
|
||||
|
||||
func hashAndVerify(sig *packet.Signature, payload string, k *GPGKey) (*GPGKey, error) {
|
||||
// Generating hash of commit
|
||||
hash, err := populateHash(sig.Hash, []byte(payload))
|
||||
if err != nil { // Skipping as failed to generate hash
|
||||
log.Error("PopulateHash: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
// We will ignore errors in verification as they don't need to be propagated up
|
||||
err = verifySign(sig, hash, k)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
func hashAndVerifyWithSubKeys(sig *packet.Signature, payload string, k *GPGKey) (*GPGKey, error) {
|
||||
verified, err := hashAndVerify(sig, payload, k)
|
||||
if err != nil || verified != nil {
|
||||
return verified, err
|
||||
}
|
||||
for _, sk := range k.SubsKey {
|
||||
verified, err := hashAndVerify(sig, payload, sk)
|
||||
if err != nil || verified != nil {
|
||||
return verified, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func HashAndVerifyWithSubKeysCommitVerification(sig *packet.Signature, payload string, k *GPGKey, committer, signer *user_model.User, email string) *CommitVerification {
|
||||
key, err := hashAndVerifyWithSubKeys(sig, payload, k)
|
||||
if err != nil { // Skipping failed to generate hash
|
||||
return &CommitVerification{
|
||||
CommittingUser: committer,
|
||||
Verified: false,
|
||||
Reason: "gpg.error.generate_hash",
|
||||
}
|
||||
}
|
||||
|
||||
if key != nil {
|
||||
return &CommitVerification{ // Everything is ok
|
||||
CommittingUser: committer,
|
||||
Verified: true,
|
||||
Reason: fmt.Sprintf("%s / %s", signer.Name, key.KeyID),
|
||||
SigningUser: signer,
|
||||
SigningKey: key,
|
||||
SigningEmail: email,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateTrustStatus will calculate the TrustStatus for a commit verification within a repository
|
||||
// There are several trust models in Gitea
|
||||
func CalculateTrustStatus(verification *CommitVerification, repoTrustModel repo_model.TrustModelType, isOwnerMemberCollaborator func(*user_model.User) (bool, error), keyMap *map[string]bool) error {
|
||||
if !verification.Verified {
|
||||
return nil
|
||||
}
|
||||
|
||||
// In the Committer trust model a signature is trusted if it matches the committer
|
||||
// - it doesn't matter if they're a collaborator, the owner, Gitea or Github
|
||||
// NB: This model is commit verification only
|
||||
if repoTrustModel == repo_model.CommitterTrustModel {
|
||||
// default to "unmatched"
|
||||
verification.TrustStatus = "unmatched"
|
||||
|
||||
// We can only verify against users in our database but the default key will match
|
||||
// against by email if it is not in the db.
|
||||
if (verification.SigningUser.ID != 0 &&
|
||||
verification.CommittingUser.ID == verification.SigningUser.ID) ||
|
||||
(verification.SigningUser.ID == 0 && verification.CommittingUser.ID == 0 &&
|
||||
verification.SigningUser.Email == verification.CommittingUser.Email) {
|
||||
verification.TrustStatus = "trusted"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now we drop to the more nuanced trust models...
|
||||
verification.TrustStatus = "trusted"
|
||||
|
||||
if verification.SigningUser.ID == 0 {
|
||||
// This commit is signed by the default key - but this key is not assigned to a user in the DB.
|
||||
|
||||
// However in the repo_model.CollaboratorCommitterTrustModel we cannot mark this as trusted
|
||||
// unless the default key matches the email of a non-user.
|
||||
if repoTrustModel == repo_model.CollaboratorCommitterTrustModel && (verification.CommittingUser.ID != 0 ||
|
||||
verification.SigningUser.Email != verification.CommittingUser.Email) {
|
||||
verification.TrustStatus = "untrusted"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check we actually have a GPG SigningKey
|
||||
var err error
|
||||
if verification.SigningKey != nil {
|
||||
var isMember bool
|
||||
if keyMap != nil {
|
||||
var has bool
|
||||
isMember, has = (*keyMap)[verification.SigningKey.KeyID]
|
||||
if !has {
|
||||
isMember, err = isOwnerMemberCollaborator(verification.SigningUser)
|
||||
(*keyMap)[verification.SigningKey.KeyID] = isMember
|
||||
}
|
||||
} else {
|
||||
isMember, err = isOwnerMemberCollaborator(verification.SigningUser)
|
||||
}
|
||||
|
||||
if !isMember {
|
||||
verification.TrustStatus = "untrusted"
|
||||
if verification.CommittingUser.ID != verification.SigningUser.ID {
|
||||
// The committing user and the signing user are not the same
|
||||
// This should be marked as questionable unless the signing user is a collaborator/team member etc.
|
||||
verification.TrustStatus = "unmatched"
|
||||
}
|
||||
} else if repoTrustModel == repo_model.CollaboratorCommitterTrustModel && verification.CommittingUser.ID != verification.SigningUser.ID {
|
||||
// The committing user and the signing user are not the same and our trustmodel states that they must match
|
||||
verification.TrustStatus = "unmatched"
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
147
models/asymkey/gpg_key_common.go
Normal file
147
models/asymkey/gpg_key_common.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/armor"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
)
|
||||
|
||||
// __________________ ________ ____ __.
|
||||
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
|
||||
// / \ ___ | ___/ \ ___ | <_/ __ < | |
|
||||
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
|
||||
// \______ /|____| \______ / |____|__ \___ > ____|
|
||||
// \/ \/ \/ \/\/
|
||||
// _________
|
||||
// \_ ___ \ ____ _____ _____ ____ ____
|
||||
// / \ \/ / _ \ / \ / \ / _ \ / \
|
||||
// \ \___( <_> ) Y Y \ Y Y ( <_> ) | \
|
||||
// \______ /\____/|__|_| /__|_| /\____/|___| /
|
||||
// \/ \/ \/ \/
|
||||
|
||||
// This file provides common functions relating to GPG Keys
|
||||
|
||||
// CheckArmoredGPGKeyString checks if the given key string is a valid GPG armored key.
|
||||
// The function returns the actual public key on success
|
||||
func CheckArmoredGPGKeyString(content string) (openpgp.EntityList, error) {
|
||||
list, err := openpgp.ReadArmoredKeyRing(strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, ErrGPGKeyParsing{err}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// Base64EncPubKey encode public key content to base 64
|
||||
func Base64EncPubKey(pubkey *packet.PublicKey) (string, error) {
|
||||
var w bytes.Buffer
|
||||
err := pubkey.Serialize(&w)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(w.Bytes()), nil
|
||||
}
|
||||
|
||||
func readerFromBase64(s string) (io.Reader, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewBuffer(bs), nil
|
||||
}
|
||||
|
||||
// base64DecPubKey decode public key content from base 64
|
||||
func base64DecPubKey(content string) (*packet.PublicKey, error) {
|
||||
b, err := readerFromBase64(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read key
|
||||
p, err := packet.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Check type
|
||||
pkey, ok := p.(*packet.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("key is not a public key")
|
||||
}
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// getExpiryTime extract the expiry time of primary key based on sig
|
||||
func getExpiryTime(e *openpgp.Entity) time.Time {
|
||||
expiry := time.Time{}
|
||||
// Extract self-sign for expire date based on : https://github.com/golang/crypto/blob/master/openpgp/keys.go#L165
|
||||
var selfSig *packet.Signature
|
||||
for _, ident := range e.Identities {
|
||||
if selfSig == nil {
|
||||
selfSig = ident.SelfSignature
|
||||
} else if ident.SelfSignature != nil && ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
|
||||
selfSig = ident.SelfSignature
|
||||
break
|
||||
}
|
||||
}
|
||||
if selfSig != nil && selfSig.KeyLifetimeSecs != nil {
|
||||
expiry = e.PrimaryKey.CreationTime.Add(time.Duration(*selfSig.KeyLifetimeSecs) * time.Second)
|
||||
}
|
||||
return expiry
|
||||
}
|
||||
|
||||
func populateHash(hashFunc crypto.Hash, msg []byte) (hash.Hash, error) {
|
||||
h := hashFunc.New()
|
||||
if _, err := h.Write(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// readArmoredSign read an armored signature block with the given type. https://sourcegraph.com/github.com/golang/crypto/-/blob/openpgp/read.go#L24:6-24:17
|
||||
func readArmoredSign(r io.Reader) (body io.Reader, err error) {
|
||||
block, err := armor.Decode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block.Type != openpgp.SignatureType {
|
||||
return nil, fmt.Errorf("expected '%s', got: %s", openpgp.SignatureType, block.Type)
|
||||
}
|
||||
return block.Body, nil
|
||||
}
|
||||
|
||||
func ExtractSignature(s string) (*packet.Signature, error) {
|
||||
r, err := readArmoredSign(strings.NewReader(s))
|
||||
if err != nil {
|
||||
return nil, errors.New("Failed to read signature armor")
|
||||
}
|
||||
p, err := packet.Read(r)
|
||||
if err != nil {
|
||||
return nil, errors.New("Failed to read signature packet")
|
||||
}
|
||||
sig, ok := p.(*packet.Signature)
|
||||
if !ok {
|
||||
return nil, errors.New("Packet is not a signature")
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
func TryGetKeyIDFromSignature(sig *packet.Signature) string {
|
||||
if sig.IssuerKeyId != nil && (*sig.IssuerKeyId) != 0 {
|
||||
return fmt.Sprintf("%016X", *sig.IssuerKeyId)
|
||||
}
|
||||
if len(sig.IssuerFingerprint) > 0 {
|
||||
return fmt.Sprintf("%016X", sig.IssuerFingerprint[12:20])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
47
models/asymkey/gpg_key_import.go
Normal file
47
models/asymkey/gpg_key_import.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
)
|
||||
|
||||
// __________________ ________ ____ __.
|
||||
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
|
||||
// / \ ___ | ___/ \ ___ | <_/ __ < | |
|
||||
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
|
||||
// \______ /|____| \______ / |____|__ \___ > ____|
|
||||
// \/ \/ \/ \/\/
|
||||
// .___ __
|
||||
// | | _____ ______ ____________/ |_
|
||||
// | |/ \\____ \ / _ \_ __ \ __\
|
||||
// | | Y Y \ |_> > <_> ) | \/| |
|
||||
// |___|__|_| / __/ \____/|__| |__|
|
||||
// \/|__|
|
||||
|
||||
// This file contains functions related to the original import of a key
|
||||
|
||||
// GPGKeyImport the original import of key
|
||||
type GPGKeyImport struct {
|
||||
KeyID string `xorm:"pk CHAR(16) NOT NULL"`
|
||||
Content string `xorm:"MEDIUMTEXT NOT NULL"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(GPGKeyImport))
|
||||
}
|
||||
|
||||
// GetGPGImportByKeyID returns the import public armored key by given KeyID.
|
||||
func GetGPGImportByKeyID(ctx context.Context, keyID string) (*GPGKeyImport, error) {
|
||||
key := new(GPGKeyImport)
|
||||
has, err := db.GetEngine(ctx).ID(keyID).Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrGPGKeyImportNotExist{keyID}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
38
models/asymkey/gpg_key_list.go
Normal file
38
models/asymkey/gpg_key_list.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
)
|
||||
|
||||
type GPGKeyList []*GPGKey
|
||||
|
||||
func (keys GPGKeyList) keyIDs() []string {
|
||||
ids := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
ids[i] = key.KeyID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (keys GPGKeyList) LoadSubKeys(ctx context.Context) error {
|
||||
subKeys := make([]*GPGKey, 0, len(keys))
|
||||
if err := db.GetEngine(ctx).In("primary_key_id", keys.keyIDs()).Find(&subKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
subKeysMap := make(map[string][]*GPGKey, len(subKeys))
|
||||
for _, key := range subKeys {
|
||||
subKeysMap[key.PrimaryKeyID] = append(subKeysMap[key.PrimaryKeyID], key)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if subKeys, ok := subKeysMap[key.KeyID]; ok {
|
||||
key.SubsKey = subKeys
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
428
models/asymkey/gpg_key_test.go
Normal file
428
models/asymkey/gpg_key_test.go
Normal file
@@ -0,0 +1,428 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"github.com/ProtonMail/go-crypto/openpgp"
|
||||
"github.com/ProtonMail/go-crypto/openpgp/packet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCheckArmoredGPGKeyString(t *testing.T) {
|
||||
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
|
||||
mQENBFh91QoBCADciaDd7aqegYkn4ZIG7J0p1CRwpqMGjxFroJEMg6M1ZiuEVTRv
|
||||
z49P4kcr1+98NvFmcNc+x5uJgvPCwr/N8ZW5nqBUs2yrklbFF4MeQomyZJJegP8m
|
||||
/dsRT3BwIT8YMUtJuCj0iqD9vuKYfjrztcMgC1sYwcE9E9OlA0pWBvUdU2i0TIB1
|
||||
vOq6slWGvHHa5l5gPfm09idlVxfH5+I+L1uIMx5ovbiVVU5x2f1AR1T18f0t2TVN
|
||||
0agFTyuoYE1ATmvJHmMcsfgM1Gpd9hIlr9vlupT2kKTPoNzVzsJsOU6Ku/Lf/bac
|
||||
mF+TfSbRCtmG7dkYZ4metLj7zG/WkW8IvJARABEBAAG0HUFudG9pbmUgR0lSQVJE
|
||||
IDxzYXBrQHNhcGsuZnI+iQFUBBMBCAA+FiEEEIOwJg/1vpF1itJ4roJVuKDYKOQF
|
||||
Alh91QoCGwMFCQPCZwAFCwkIBwIGFQgJCgsCBBYCAwECHgECF4AACgkQroJVuKDY
|
||||
KORreggAlIkC2QjHP5tb7b0+LksB2JMXdY+UzZBcJxtNmvA7gNQaGvWRrhrbePpa
|
||||
MKDP+3A4BPDBsWFbbB7N56vQ5tROpmWbNKuFOVER4S1bj0JZV0E+xkDLqt9QwQtQ
|
||||
ojd7oIZJwDUwdud1PvCza2mjgBqqiFE+twbc3i9xjciCGspMniUul1eQYLxRJ0w+
|
||||
sbvSOUnujnq5ByMSz9ij00O6aiPfNQS5oB5AALfpjYZDvWAAljLVrtmlQJWZ6dZo
|
||||
T/YNwsW2dECPuti8+Nmu5FxPGDTXxdbnRaeJTQ3T6q1oUVAv7yTXBx5NXfXkMa5i
|
||||
iEayQIH8Joq5Ev5ja/lRGQQhArMQ2bkBDQRYfdUKAQgAv7B3coLSrOQbuTZSlgWE
|
||||
QeT+7DWbmqE1LAQA1pQPcUPXLBUVd60amZJxF9nzUYcY83ylDi0gUNJS+DJGOXpT
|
||||
pzX2IOuOMGbtUSeKwg5s9O4SUO7f2yCc3RGaegER5zgESxelmOXG+b/hoNt7JbdU
|
||||
JtxcnLr91Jw2PBO/Xf0ZKJ01CQG2Yzdrrj6jnrHyx94seHy0i6xH1o0OuvfVMLfN
|
||||
/Vbb/ZHh6ym2wHNqRX62b0VAbchcJXX/MEehXGknKTkO6dDUd+mhRgWMf9ZGRFWx
|
||||
ag4qALimkf1FXtAyD0vxFYeyoWUQzrOvUsm2BxIN/986R08fhkBQnp5nz07mrU02
|
||||
cQARAQABiQE8BBgBCAAmFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAlh91QoCGwwF
|
||||
CQPCZwAACgkQroJVuKDYKOT32wf/UZqMdPn5OhyhffFzjQx7wolrf92WkF2JkxtH
|
||||
6c3Htjlt/p5RhtKEeErSrNAxB4pqB7dznHaJXiOdWEZtRVXXjlNHjrokGTesqtKk
|
||||
lHWtK62/MuyLdr+FdCl68F3ewuT2iu/MDv+D4HPqA47zma9xVgZ9ZNwJOpv3fCOo
|
||||
RfY66UjGEnfgYifgtI5S84/mp2jaSc9UNvlZB6RSf8cfbJUL74kS2lq+xzSlf0yP
|
||||
Av844q/BfRuVsJsK1NDNG09LC30B0l3LKBqlrRmRTUMHtgchdX2dY+p7GPOoSzlR
|
||||
MkM/fdpyc2hY7Dl/+qFmN5MG5yGmMpQcX+RNNR222ibNC1D3wg==
|
||||
=i9b7
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
|
||||
key, err := CheckArmoredGPGKeyString(testGPGArmor)
|
||||
assert.NoError(t, err, "Could not parse a valid GPG public armored rsa key", key)
|
||||
// TODO verify value of key
|
||||
}
|
||||
|
||||
func TestCheckArmoredbrainpoolP256r1GPGKeyString(t *testing.T) {
|
||||
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
Version: GnuPG v2
|
||||
|
||||
mFMEV6HwkhMJKyQDAwIIAQEHAgMEUsvJO/j5dFMRRj67qeZC9fSKBsGZdOHRj2+6
|
||||
8wssmbUuLTfT/ZjIbExETyY8hFnURRGpD2Ifyz0cKjXcbXfJtrQTRm9vYmFyIDxm
|
||||
b29AYmFyLmRlPoh/BBMTCAAnBQJZOsDIAhsDBQkJZgGABQsJCAcCBhUICQoLAgQW
|
||||
AgMBAh4BAheAAAoJEGuJTd/DBMzmNVQA/2beUrv1yU4gyvCiPDEm3pK42cSfaL5D
|
||||
muCtPCUg9hlWAP4yq6M78NW8STfsXgn6oeziMYiHSTmV14nOamLuwwDWM7hXBFeh
|
||||
8JISCSskAwMCCAEBBwIDBG3A+XfINAZp1CTse2mRNgeUE5DbUtEpO8ALXKA1UQsQ
|
||||
DLKq27b7zTgawgXIGUGP6mWsJ5oH7MNAJ/uKTsYmX40DAQgHiGcEGBMIAA8FAleh
|
||||
8JICGwwFCQlmAYAACgkQa4lN38MEzOZwKAD/QKyerAgcvzzLaqvtap3XvpYcw9tc
|
||||
OyjLLnFQiVmq7kEA/0z0CQe3ZQiQIq5zrs7Nh1XRkFAo8GlU/SGC9XFFi722
|
||||
=ZiSe
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
|
||||
key, err := CheckArmoredGPGKeyString(testGPGArmor)
|
||||
assert.NoError(t, err, "Could not parse a valid GPG public armored brainpoolP256r1 key", key)
|
||||
// TODO verify value of key
|
||||
}
|
||||
|
||||
func TestExtractSignature(t *testing.T) {
|
||||
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
|
||||
mQENBFh91QoBCADciaDd7aqegYkn4ZIG7J0p1CRwpqMGjxFroJEMg6M1ZiuEVTRv
|
||||
z49P4kcr1+98NvFmcNc+x5uJgvPCwr/N8ZW5nqBUs2yrklbFF4MeQomyZJJegP8m
|
||||
/dsRT3BwIT8YMUtJuCj0iqD9vuKYfjrztcMgC1sYwcE9E9OlA0pWBvUdU2i0TIB1
|
||||
vOq6slWGvHHa5l5gPfm09idlVxfH5+I+L1uIMx5ovbiVVU5x2f1AR1T18f0t2TVN
|
||||
0agFTyuoYE1ATmvJHmMcsfgM1Gpd9hIlr9vlupT2kKTPoNzVzsJsOU6Ku/Lf/bac
|
||||
mF+TfSbRCtmG7dkYZ4metLj7zG/WkW8IvJARABEBAAG0HUFudG9pbmUgR0lSQVJE
|
||||
IDxzYXBrQHNhcGsuZnI+iQFUBBMBCAA+FiEEEIOwJg/1vpF1itJ4roJVuKDYKOQF
|
||||
Alh91QoCGwMFCQPCZwAFCwkIBwIGFQgJCgsCBBYCAwECHgECF4AACgkQroJVuKDY
|
||||
KORreggAlIkC2QjHP5tb7b0+LksB2JMXdY+UzZBcJxtNmvA7gNQaGvWRrhrbePpa
|
||||
MKDP+3A4BPDBsWFbbB7N56vQ5tROpmWbNKuFOVER4S1bj0JZV0E+xkDLqt9QwQtQ
|
||||
ojd7oIZJwDUwdud1PvCza2mjgBqqiFE+twbc3i9xjciCGspMniUul1eQYLxRJ0w+
|
||||
sbvSOUnujnq5ByMSz9ij00O6aiPfNQS5oB5AALfpjYZDvWAAljLVrtmlQJWZ6dZo
|
||||
T/YNwsW2dECPuti8+Nmu5FxPGDTXxdbnRaeJTQ3T6q1oUVAv7yTXBx5NXfXkMa5i
|
||||
iEayQIH8Joq5Ev5ja/lRGQQhArMQ2bkBDQRYfdUKAQgAv7B3coLSrOQbuTZSlgWE
|
||||
QeT+7DWbmqE1LAQA1pQPcUPXLBUVd60amZJxF9nzUYcY83ylDi0gUNJS+DJGOXpT
|
||||
pzX2IOuOMGbtUSeKwg5s9O4SUO7f2yCc3RGaegER5zgESxelmOXG+b/hoNt7JbdU
|
||||
JtxcnLr91Jw2PBO/Xf0ZKJ01CQG2Yzdrrj6jnrHyx94seHy0i6xH1o0OuvfVMLfN
|
||||
/Vbb/ZHh6ym2wHNqRX62b0VAbchcJXX/MEehXGknKTkO6dDUd+mhRgWMf9ZGRFWx
|
||||
ag4qALimkf1FXtAyD0vxFYeyoWUQzrOvUsm2BxIN/986R08fhkBQnp5nz07mrU02
|
||||
cQARAQABiQE8BBgBCAAmFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAlh91QoCGwwF
|
||||
CQPCZwAACgkQroJVuKDYKOT32wf/UZqMdPn5OhyhffFzjQx7wolrf92WkF2JkxtH
|
||||
6c3Htjlt/p5RhtKEeErSrNAxB4pqB7dznHaJXiOdWEZtRVXXjlNHjrokGTesqtKk
|
||||
lHWtK62/MuyLdr+FdCl68F3ewuT2iu/MDv+D4HPqA47zma9xVgZ9ZNwJOpv3fCOo
|
||||
RfY66UjGEnfgYifgtI5S84/mp2jaSc9UNvlZB6RSf8cfbJUL74kS2lq+xzSlf0yP
|
||||
Av844q/BfRuVsJsK1NDNG09LC30B0l3LKBqlrRmRTUMHtgchdX2dY+p7GPOoSzlR
|
||||
MkM/fdpyc2hY7Dl/+qFmN5MG5yGmMpQcX+RNNR222ibNC1D3wg==
|
||||
=i9b7
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
keys, err := CheckArmoredGPGKeyString(testGPGArmor)
|
||||
require.NotEmpty(t, keys)
|
||||
|
||||
ekey := keys[0]
|
||||
assert.NoError(t, err, "Could not parse a valid GPG armored key", ekey)
|
||||
|
||||
pubkey := ekey.PrimaryKey
|
||||
content, err := Base64EncPubKey(pubkey)
|
||||
assert.NoError(t, err, "Could not base64 encode a valid PublicKey content", ekey)
|
||||
|
||||
key := &GPGKey{
|
||||
KeyID: pubkey.KeyIdString(),
|
||||
Content: content,
|
||||
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
|
||||
CanSign: pubkey.CanSign(),
|
||||
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
|
||||
CanCertify: pubkey.PubKeyAlgo.CanSign(),
|
||||
}
|
||||
|
||||
cannotsignkey := &GPGKey{
|
||||
KeyID: pubkey.KeyIdString(),
|
||||
Content: content,
|
||||
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
|
||||
CanSign: false,
|
||||
CanEncryptComms: false,
|
||||
CanEncryptStorage: false,
|
||||
CanCertify: false,
|
||||
}
|
||||
|
||||
testGoodSigArmor := `-----BEGIN PGP SIGNATURE-----
|
||||
|
||||
iQEzBAABCAAdFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAljAiQIACgkQroJVuKDY
|
||||
KORvCgf6A/Ehh0r7QbO2tFEghT+/Ab+bN7jRN3zP9ed6/q/ophYmkrU0NibtbJH9
|
||||
AwFVdHxCmj78SdiRjaTKyevklXw34nvMftmvnOI4lBNUdw6KWl25/n/7wN0l2oZW
|
||||
rW3UawYpZgodXiLTYarfEimkDQmT67ArScjRA6lLbkEYKO0VdwDu+Z6yBUH3GWtm
|
||||
45RkXpnsF6AXUfuD7YxnfyyDE1A7g7zj4vVYUAfWukJjqow/LsCUgETETJOqj9q3
|
||||
52/oQDs04fVkIEtCDulcY+K/fKlukBPJf9WceNDEqiENUzN/Z1y0E+tJ07cSy4bk
|
||||
yIJb+d0OAaG8bxloO7nJq4Res1Qa8Q==
|
||||
=puvG
|
||||
-----END PGP SIGNATURE-----`
|
||||
testGoodPayload := `tree 56ae8d2799882b20381fc11659db06c16c68c61a
|
||||
parent c7870c39e4e6b247235ca005797703ec4254613f
|
||||
author Antoine GIRARD <sapk@sapk.fr> 1489012989 +0100
|
||||
committer Antoine GIRARD <sapk@sapk.fr> 1489012989 +0100
|
||||
|
||||
Goog GPG
|
||||
`
|
||||
|
||||
testBadSigArmor := `-----BEGIN PGP SIGNATURE-----
|
||||
|
||||
iQEzBAABCAAdFiEE5yr4rn9ulbdMxJFiPYI/ySNrtNkFAljAiYkACgkQPYI/ySNr
|
||||
tNmDdQf+NXhVRiOGt0GucpjJCGrOnK/qqVUmQyRUfrqzVUdb/1/Ws84V5/wE547I
|
||||
6z3oxeBKFsJa1CtIlxYaUyVhYnDzQtphJzub+Aw3UG0E2ywiE+N7RCa1Ufl7pPxJ
|
||||
U0SD6gvNaeTDQV/Wctu8v8DkCtEd3N8cMCDWhvy/FQEDztVtzm8hMe0Vdm0ozEH6
|
||||
P0W93sDNkLC5/qpWDN44sFlYDstW5VhMrnF0r/ohfaK2kpYHhkPk7WtOoHSUwQSg
|
||||
c4gfhjvXIQrWFnII1Kr5jFGlmgNSR02qpb31VGkMzSnBhWVf2OaHS/kI49QHJakq
|
||||
AhVDEnoYLCgoDGg9c3p1Ll2452/c6Q==
|
||||
=uoGV
|
||||
-----END PGP SIGNATURE-----`
|
||||
testBadPayload := `tree 3074ff04951956a974e8b02d57733b0766f7cf6c
|
||||
parent fd3577542f7ad1554c7c7c0eb86bb57a1324ad91
|
||||
author Antoine GIRARD <sapk@sapk.fr> 1489013107 +0100
|
||||
committer Antoine GIRARD <sapk@sapk.fr> 1489013107 +0100
|
||||
|
||||
Unknown GPG key with good email
|
||||
`
|
||||
// Reading Sign
|
||||
goodSig, err := ExtractSignature(testGoodSigArmor)
|
||||
assert.NoError(t, err, "Could not parse a valid GPG armored signature", testGoodSigArmor)
|
||||
badSig, err := ExtractSignature(testBadSigArmor)
|
||||
assert.NoError(t, err, "Could not parse a valid GPG armored signature", testBadSigArmor)
|
||||
|
||||
// Generating hash of commit
|
||||
goodHash, err := populateHash(goodSig.Hash, []byte(testGoodPayload))
|
||||
assert.NoError(t, err, "Could not generate a valid hash of payload", testGoodPayload)
|
||||
badHash, err := populateHash(badSig.Hash, []byte(testBadPayload))
|
||||
assert.NoError(t, err, "Could not generate a valid hash of payload", testBadPayload)
|
||||
|
||||
// Verify
|
||||
err = verifySign(goodSig, goodHash, key)
|
||||
assert.NoError(t, err, "Could not validate a good signature")
|
||||
err = verifySign(badSig, badHash, key)
|
||||
assert.Error(t, err, "Validate a bad signature")
|
||||
err = verifySign(goodSig, goodHash, cannotsignkey)
|
||||
assert.Error(t, err, "Validate a bad signature with a kay that can not sign")
|
||||
}
|
||||
|
||||
func TestCheckGPGUserEmail(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
_ = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
|
||||
|
||||
testEmailWithUpperCaseLetters := `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
Version: GnuPG v1
|
||||
|
||||
mQENBFlEBvMBCADe+EQcfv/aKbMFy7YB8e/DE+hY39sfjvdvSgeXtNhfmYvIOUjT
|
||||
ORMCvce2Oxzb3HTI0rjYsJpzo9jEQ53dB3vdr0ne5Juby6N7QPjof3NR+ko50Ki2
|
||||
0ilOjYuA0v6VHLIn70UBa9NEf+XDuE7P+Lbtl2L9B9OMXtcTAZoA3cJySgtNFNIG
|
||||
AVefPi8LeOcekL39wxJEA8OzdCyO5oENEwAG1tzjy9DDNJf74/dBBh2NiXeSeMxZ
|
||||
RYeYzqEa2UTDP1fkUl7d2/hV36cKZWZr+l4SQ5bM7HeLj2SsfabLfqKoVWgkfAzQ
|
||||
VwtkbRpzMiDLMte2ZAyTJUc+77YbFoyAmOcjABEBAAG0HFVzZXIgT25lIDxVc2Vy
|
||||
MUBFeGFtcGxlLmNvbT6JATgEEwECACIFAllEBvMCGwMGCwkIBwMCBhUIAgkKCwQW
|
||||
AgMBAh4BAheAAAoJEFMOzOY274DFw5EIAKc4jiYaMb1HDKrSv0tphgNxPFEY83/J
|
||||
9CZggO7BINxlb7z/lH1i0U2h2Ha9E3VJTJQF80zBCaIvtU2UNrgVmSKoc0BdE/2S
|
||||
rS9MAl29sXxf1BfvXHu12Suvo8O/ZFP45Vm/3kkHuasHyOV1GwUWnynt1qo0zUEn
|
||||
WMIcB8USlmMT1TnSb10YKBd/BpGF3crFDJLfAHRumZUk4knDDWUOWy5RCOG8cedc
|
||||
VTAhfdoKRRO3PchOfz6Rls/hew12mRNayqxuLQl2+BX+BWu+25dR3qyiS+twLbk6
|
||||
Rjpb0S+RQTkYIUoI0SEZpxcTZso11xF5KNpKZ9aAoiLJqkNF5h4oPSe5AQ0EWUQG
|
||||
8wEIALiMMqh3NF3ON/z7hQfeU24bCl/WdfJwCR9CWU/jx4X4gZq2C2aGtytGN5g/
|
||||
qoYQ3poTOPzh/4Dvs+r6CtHqi0CvPiEOfSxzmaK+F+vA0GMn2i3Sx5gq/VB0mr+j
|
||||
RIYMCjf68Tifo2RAT0VDzn6t304l5+VPr4OgbobMRH+wDe7Hhd2pZXl7ty8DooBn
|
||||
vqaqoKgdiccUXGBKe4Oihl/oZ4qrYH6K4ACP1Sco1rs4mNeKDAW8k/Y7zLjg6d59
|
||||
g0YQ1YI+CX/bKB7/cpMHLupyMLqvCcqIpjBXRJNMdjuMHgKckjr89DwnqXqgXz7W
|
||||
u0B39MZQn9nn6vq8BdkoDFgrTQ8AEQEAAYkBHwQYAQIACQUCWUQG8wIbDAAKCRBT
|
||||
DszmNu+Axf4IB/0S9NTc6kpwW+ZPZQNTWR5oKDEaXVCRLccOlkt33txMvk/z2jNM
|
||||
trEke99ss5L1bRyWB5fRA+XVsPmW9kIk8pmGFmxqp2nSxr9m9rlL5oTYH8u6dfSm
|
||||
zwGhqkfITjPI7hyNN52PLANwoS0o4dLzIE65ewigx6cnRlrT2IENObxG/tlxaYg1
|
||||
NHahJX0uFlVk0W0bLBrs3fTDw1lS/N8HpyQb+5ryQmiIb2a48aygCS/h2qeRlX1d
|
||||
Q0KHb+QcycSgbDx0ZAvdIacuKvBBcbxrsmFUI4LR+oIup0G9gUc0roPvr014jYQL
|
||||
7f8r/8fpcN8t+I/41QHCs6L/BEIdTHW3rTQ6
|
||||
=zHo9
|
||||
-----END PGP PUBLIC KEY BLOCK-----`
|
||||
|
||||
keys, err := AddGPGKey(t.Context(), 1, testEmailWithUpperCaseLetters, "", "")
|
||||
assert.NoError(t, err)
|
||||
if assert.NotEmpty(t, keys) {
|
||||
key := keys[0]
|
||||
if assert.Len(t, key.Emails, 1) {
|
||||
assert.Equal(t, "user1@example.com", key.Emails[0].Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckGParseGPGExpire(t *testing.T) {
|
||||
testIssue6599 := `-----BEGIN PGP PUBLIC KEY BLOCK-----
|
||||
|
||||
mQINBFlFJRsBEAClNcRT5El+EaTtQEYs/eNAhr/bqiyt6fPMtabDq2x6a8wFWMX0
|
||||
yhRh4vZuLzhi95DU/pmhZARt0W15eiN0AhWdOKxry1KtZNiZBzMm1f0qZJMuBG8g
|
||||
YJ7aRkCqdWRxy1Q+U/yhr6z7ucD8/yn7u5wke/jsPdF/L8I/HKNHoawI1FcMC9v+
|
||||
QoG3pIX8NVGdzaUYygFG1Gxofc3pb3i4pcpOUxpOP12t6PfwTCoAWZtRLgxTdwWn
|
||||
DGvY6SCIIIxn4AC6u3+tHz9HDXx+4eiB7VxMsiIsEuHW9DVBzen9jFNNjRnNaFkL
|
||||
pTAFOyGsSzGRGhuJpb7j7hByoWkaItqaw+clnzVrDqhfbxS1B8dmgMANh9pzNsv7
|
||||
J/OnNdGsbgDX5RytSKMaXclK2ZGH6Txatgezo167z6EdthNR1daj1QfqWADiqKbR
|
||||
UXp7Xz9b+/CBedUNEXPbIExva9mPsFJo2IEntRGtdhhjuO4a6HLG7k1i0o0dHxqb
|
||||
a9HrOW7fO902L7JHIgnjpDWDGLGGnVGcGWdEEZggfpnvjxADeTgyMb2XkALTQ0GG
|
||||
yRywByxG8/zjXeEkqUng/mxNbBCcHcuIRVsqYwGQLiLubYxnRudqtNst8Tdu+0+q
|
||||
AL0bb8ueQC1M3WHsMUxvTjknFJdJzRicNyLf6AdfRv6yy6Ra+t4SFoSbsQARAQAB
|
||||
tB90YXN0eXRlYSA8dGFzdHl0ZWFAdGFzdHl0ZWEuZGU+iQJXBBMBCABBAhsDBQsJ
|
||||
CAcCBhUICQoLAgQWAgMBAh4BAheAAhkBFiEE1bTEO0ioefY1KTbmWTRuDqNcZ+UF
|
||||
Alyo2K0FCQVE5xIACgkQWTRuDqNcZ+UTFA/+IygU02oz19tRVNgVmKyXv1GhnkaY
|
||||
O/oGxp7cRGJ0gf0bjhbJpFf4+6OHaS0ei47Qp8XTuStfWry6V6rXLSV/ZOOhFaCq
|
||||
VpFvoG2JcPZbSTB+CR/lL5fWwx3w5PAOUwipGRFs7mYLgy8U/E3U7u+ioP4ZqCXS
|
||||
heclyXAGNlrjUwvwOWRLxvcEQr4ztQR0Lk2tv1QYYDzbaXUSdnsM1YK9YpYP7BE2
|
||||
luKtwwXaubdwcXPs96FEmGLGfsWC/dWnAxkYXPo9q7O6c5GKbGiP3xFhBaBCzzm0
|
||||
PAqAJ+NyIWL63yI1aNNz4xC1marU7UPLzBnv5fG1WdscYqAbj8XbZ96mPPM80y0A
|
||||
j5/7YecRXce4yedxRHhi3bD8MEzDMHWfkQPpWCZj/KwjDFiZwSMgpQUqeAllDKQx
|
||||
Ld0CLkLuUe20b+/5h6dGtGpoACkoOPxMl6zi9uihztvR5iYdkwnmcxKmnEtz+WV4
|
||||
1efhS3QRZro3QAHjhCqU1Xjl0hnwSCgP5nUhTq6dJqgeZ7c5D4Uhg55MXwQ68Oe4
|
||||
NrQfhdO8IOSVPDPDEeQ2kuP7/HEZsjKZBMKhKoUcdXM6y9T2tYw3wv5JDuDxT2Q1
|
||||
3IuFVr1uFm/spVyFCpPpPSQM1wfdtoPLRjiJ/KVh777AWUlywP2b7cWyKShYJb4P
|
||||
QzTQ/udx94916cSJAlQEEwEIAD4WIQTVtMQ7SKh59jUpNuZZNG4Oo1xn5QUCWUUl
|
||||
GwIbAwUJA8ORBQULCQgHAgYVCAkKCwIEFgIDAQIeAQIXgAAKCRBZNG4Oo1xn5Uoa
|
||||
D/9tdmXECDZS1th0xmdNIsecxhI9dBGJyaJwfhH7UVkL+e86EsmTSzyJhBAepDDe
|
||||
4wTEaW/NnjVX+ulO7rKFN4/qvSCOaeIdP0MEn7zfZVVKG8gMW4mb/piLvUnsZvsM
|
||||
eWfv9AL/b3H1MRkl9S6XsE0ove72pmbBSZEhh2rNHqf+tIGr/RTtn80efTv3w+75
|
||||
0UJtaFPsAKoAzNRy+ouhf9IHy9pEMJRA/hZ0Ho04QCDAC65mWz7iwI7v9VRDVfng
|
||||
UjJPJahoM4vTpB30vJiFYT2oFTgdxGckfEUezsk8Rx/o6x4u6igKypPbeqM/7SMw
|
||||
H61sCWR7nHJhCK55WeEIbzHEhwCZTf1pgvHj5oGUOjzksp2DmFV3ma3WCh8JyqyA
|
||||
zw2OvOXBlayIaGIoyD5tSHS40rTi9JmOUfhg6WPN3MIrvsSVEV7JNdiZs/Tb07eQ
|
||||
l71O7wv/LXZZCYP5NLV0PJbN2pHMf8cysWulfHN/mNgpEiLJpPBYVVyVbzWLg54X
|
||||
FcNQMrT70kRF4M2GBRahXchkWi6+1pd3jPtvCFfcNiYBnHcrKu2R/UdSpFYdclDi
|
||||
y6u7xMxXt0AVeLLtlXq7+ChOANMH5aPdUjCXeQDNJawLx41KL9fETsjScodmmpKi
|
||||
SNhkC03FNfbkPJzZthoTxCfUBQeHYWgDpN3Gjb/OdSWC34kCVwQTAQgAQQIbAwUJ
|
||||
A8ORBQULCQgHAgYVCAkKCwIEFgIDAQIeAQIXgBYhBNW0xDtIqHn2NSk25lk0bg6j
|
||||
XGflBQJcqNWQAhkBAAoJEFk0bg6jXGfldcEP/iz4UbJPd/kr8D008ky7vI7hnYs8
|
||||
VQIxL6ljQJ75XmVx/Lz1MVo4Vdsu6+qEta5gvqbGwjuEugaHcFVbHCZEBKI0QHSQ
|
||||
UNHfXT8eZP/BwwFWawUokLTbF//Dg5xd5ejo/TeltNleyq1r0AoxcoMv1srrY4yK
|
||||
GvWE5V8SVSi/E71y4VarS58ZH3NZ6sW5slnYvgAHTVgOjkVvMYk5JmrWsFsycYf8
|
||||
Rs5BvCuXQpUV9N8UFfW8pAxYhLvUTqhf34m24syyFn9j1udEO1c+IeX7h7hX2CFL
|
||||
+P6wS9Ok2Z++IKvhIXLy/OoBULxKXjM04aLxDDlRW3qEyeLKvbFiEHGSnlaDz27L
|
||||
LBAGGRxzLLr0g1evV33AHUU2N8pATnzXHJaRiMjExjRi5IkHjbiEaxiqIwr8CSnS
|
||||
4RlZ+owxhJ/4MjnsqBL3ELhkSnN+HGkPBQkbFDhCm0ICm78EK2x4+bWo/YUUfoky
|
||||
Hq92XB6RNbO0RcdGyltFsJ02Ev20Hc4MClF7jT7xm7VJfbeYNmxZ6GNXZ7kEsl87
|
||||
7qzFtr2BcEfw/ieyyoOrwAC9FBJc/9CALex3p3TGWpM43C+IdqZIsr9QHAzvJfY7
|
||||
/n5/wJyCPhIZSSE3b8PZRIAdh6NA2IF877OCzIl2UFUNJE1zaEcTvjxZzCZ1SHGU
|
||||
YzQeSbODHUuPDbhytBJnZW50b29AdGFzdHl0ZWEuZGWJAlQEEwEIAD4CGwMFCwkI
|
||||
BwIGFQoJCAsCBBYCAwECHgECF4AWIQTVtMQ7SKh59jUpNuZZNG4Oo1xn5QUCXKjY
|
||||
rQUJBUTnEgAKCRBZNG4Oo1xn5VhkD/42pGYstRMvrO37wJDnnLDm+ZPb0RGy80Ru
|
||||
Nt3S6OmU3TFuU9mj/FBc8VNs6xr0CCMVVM/CXX1gXCHhADss1YDaOcRsl5wVJ6EF
|
||||
tbpEXT/USMw3dV4Y8OYUSNxyEitzKt25CnOdWGPYaJG3YOtAR0qwopMiAgLrgLy9
|
||||
mugXqnrykF7yN27i6iRi2Jk9K7tSb4owpw1kuToJrNGThAkz+3nvXG5oRiYFTlH3
|
||||
pATx34r+QOg1o3giomP49cP4ohxvQFP90w2/cURhLqEKdR6N1X0bTXRQvy8G+4Wl
|
||||
QMl8WYPzQUrKGMgj/f7Uhb3pFFLCcnCaYFdUj+fvshg5NMLGVztENz9x7Vr5n51o
|
||||
Hj9WuM3s65orKrGhMUk4NJCsQWJUHnSNsEXsuir9ocwCv4unIJuoOukNJigL4d5o
|
||||
i0fKPKuLpdIah1dmcrWLIoid0wPeA8unKQg3h6VL5KXpUudo8CiPw/kk1KTLtYQR
|
||||
7lezb1oldqfWgGHmqnOK+u6sOhxGj2fcrTi4139ULMph+LCIB3JEtgaaw4lTTt0t
|
||||
S8h6db6LalzsQyL2sIHgl/rmLmZ5sqZhmi/DsAjZWfpz+inUP6rgap+OgAmtLCit
|
||||
BwsDAy7ux44mUNtW1KExuY2W/bmSLlV28H+fHJ3fhpHDQMNAFYc5n4NgTe6eT/KY
|
||||
WA4KGfp7KYkCVAQTAQgAPhYhBNW0xDtIqHn2NSk25lk0bg6jXGflBQJcqNTKAhsD
|
||||
BQkDw5EFBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheAAAoJEFk0bg6jXGflazAP/iae
|
||||
7/PIaWhIyDw14NvyJG4D8FMSV9bC1cJ+ICo0qkx0dxcZMsxTp7fD8ODaSWzJEI4X
|
||||
mGDvJp5fJ7ZALFhp7IBIsj9CHRWyVBCzwhnAXgSmGF+qzBFE7WjQORdn5ytTiWAN
|
||||
PqyJV0sAw46jLJNvYv/LaFb2bzR/z6U1wQ2qvqXZj8vh2eLvY2XfQa1HnKaPi8h9
|
||||
OqtLM80/6uai2scdYAI6usB8wxTJY2b2B8flDB7c8DruCDRL1QmrK5o70yIIai2c
|
||||
4fXHHglulT9GnwD01a5DA2dgn5nxb81xgofgofXQjIOYARUKvcuZsF/tsR5S+C5k
|
||||
CJnq8V9xdABbWz/FvwXz7ejf2jPtAnD6gcvuPnLX/dsxFHio2n4HHzXboUrVMKid
|
||||
zcvuIrmlNtvKHYGxC9Dk3vNM+9rTlaY2BRt0zkgakDpMhqFu6A/TCEDZK0ukQLtc
|
||||
h0g806AWding6gr4vQDeX6dSCuJMFKTu/2q85R1w2vGuyWYSm6QR6sM+KumOX3vJ
|
||||
c/zvOodhRWXQBWYHTuSw6QGDCI115lWO8DAK4T6u7SVXfthHKm+38dpDH1tSfcHo
|
||||
KaG7XJKExEPgdcNLvJIN/xCx5lX6fy0ohj7oF1dEpeBpIgqTC0l5I8bLAjcLKZl9
|
||||
4YwJSSS8aTedptCmBTAHWd6y3W/hgFJrdKsqbHVGuQINBFlFJRsBEAC1EFjL9rvn
|
||||
O9UIJ2dfaPdfm2GjH/sKfOInfWp4KKEDWtS59Pssld4gnjcmDNgunYYhHYcok61K
|
||||
9J4x33KvkNAhEbw9y5AGW0tb7p2I6NxiOaWZjmZbg7AJMBFenipdUXBEjbu4LzEd
|
||||
yyIm3/lQiV4bW6GR14cKdQLZm/inVmbEaGSpq2g19WA+X7SwBxzZR9O80Iohm3RL
|
||||
X8Z8lXzUj/fUWCCstfXZwdy4vbZv8ms7kmq+3TUOwOiVavgWYhbal+nO0kLdVFbb
|
||||
i7YRvZh6afxfgMyJ3v1goXvsW1W8jno2ikUmkwZiiPY/cKOPmOwEzj3hl73i6qrx
|
||||
vm9SjEwEzI/gFXlJD8cOKMc6/g8kUeCepDfdKjgo1SYynLUk4NW9QeucJo6BSPEP
|
||||
llamHsTaUGzT4tj9qZqAQ0dwSnWYvyi19EMCGssLoy7bAoNueHOYZtHN5TskKShQ
|
||||
XzEG9IRZvXGmaWAT17sFesqXK0g47jQswmwobDsXyvXJfree36jQRj7SAVVK44Im
|
||||
bqBe6BT9QYIBkfThAWjwTibg0P1CPGk5TPpssAQgM3jxXVEyD6iKCS4LKWrtm+Sk
|
||||
MlGaPNyO8OcwHp6p5QaYAE6vlSfT8fsZ0iGd06ua5miZRbkM2i94/jVKvZLRvWv4
|
||||
S8SMZemAYnVMc0YFWEJCbaKdZp35rb5e4QARAQABiQI8BBgBCAAmAhsMFiEE1bTE
|
||||
O0ioefY1KTbmWTRuDqNcZ+UFAlyo2PAFCQVE51UACgkQWTRuDqNcZ+V+Hg/9HhVI
|
||||
No0ID4o8y0jlhyNg8n/Fy08uDALQ6JlbN6buLw+IYU75GTDIysGjx+9bgt+Mjvtp
|
||||
bbWkeT6okKkyB3H/x7w7v9GTYWlnzMA/KwHF7L7Wqy0afcVjg+fchWXPJQ3H5Jxh
|
||||
bcX3FKkIN9kpfdHN87C8//s4LzDOWeYCxFwkxkbx4tc1K4HhezpvYDKiLmFMVbaU
|
||||
qB0pzP8IM3hU1GJeAC2skfjstuaKJPuF895aFSF6++DYodXBFu3UlSJbJGfDEBYC
|
||||
9PgSrxX1qlNUFw+6Hr2uSdPnmcKgCDFGhxB1d/Z2Xa/QFhvuj7U38eyqla3dzXxu
|
||||
4+/9BOoJwdyRlUxd1Jcy3q7l8V4Hk1vMwKICdXBadAcAgSi0ImXt7UttpTYB7WNV
|
||||
nlFmFFi8eVnmMll08LWV6LygG8GBSzW5NUZnUhxHbFVFcEuHo6W1lIEgJooOnGwd
|
||||
H2rqKXpkcv86q7ODxdt9nb0txUPzgukusHes6Q0cnTMWcd0YT75frKjjK6TK8KZA
|
||||
XMH0zobogpnr/n2ji87cn9sSlL3/2NtxfAwqyDWomECKOtKYfx10OPjrPrScDFG0
|
||||
aF6w50Xg5DH/I38zzBVanEgwzWHosIVKNQHgoSYijErnShbRefA8+zCsyn0q/9Rg
|
||||
cToAM7X3ro+tQQHWDIhiayHvJMeGN/R/u1U4Kv25BK4EW0zMChEMALGnffpA/rz6
|
||||
oRXV++syFI6AaByfiatYgKh+d2LkhyeAAnp93VBV8c2YArsSp7XookhxlRA7XAGw
|
||||
x71VKouHjdcMpZM76OcEJgC2fKCbsLrMhkjKOjux6Lru1mY4bFmXBxex0pssvIoc
|
||||
zefV00qVvQ0e2JkvUmuKKIplyH0GAapDRnF3R8/doNNUXfVufHButKHlmK7yaFkK
|
||||
UBXLFUc3c8mCm/UQcMrFYrlyRNd6Axir2LpD8ya8gIwOM49nH+DDSla4d23zP+4M
|
||||
kTaWZ5QlX4FGN8kfPE4rzVxhCP0jtC5m2oqFp8dIKtxzX836YkHG7wlAPsaoPmhl
|
||||
kJMylGSwvjRvjxNLHWodMJfrQgajnW0UEd1XrfO48i/OD3f1Z22/sHRY2VejD4KJ
|
||||
49QBienKCUlNbZRfpaGOQn2HqbOX6/wUfS/83rhBVNrsU2kNb/+6OKsJV2YtokPK
|
||||
saS88q8225YEcsDLPS/3V5VrFW0CQwXJM4AbVweHhE7486VtSfkQswEAjTMJSbTO
|
||||
4IgjWYDaQ57m77bc4N9z0oCWaChlaAjdzSsL/0JQx5GJXUcxW1GvEGhP/Fx1IFd3
|
||||
oCR8OmY6oZHYmB1fNvFLSmJN0dJcQjm3hebrSQiWg/JvVAlF2S7f+j0pjeki09kM
|
||||
0RqAHOkDpLeY6ifU8+QW5DP5yh8d9ZDc4wjPdz53ycwJzaMqESOIr9eHYtOWN6Hi
|
||||
0rItsMN8FB5A70te1IcKG5UWh3cCRg7fEbKVofIYTSU2V98RLkp+iEHLKfa6wObx
|
||||
Mt60OVU/xbrO28w93cLpWUIH1Csow3k3wSbNmw3d9mWc7cVESct+IM5W4ZSYMcjG
|
||||
cvcMELWCwuT1mPSkR0hv2oz5xFOBlUV1KUViIcxpKzTrjj69JAaBbJ3f5OEfEbj/
|
||||
G+aa30EoddPBhwF7XnQUeC/DLRJQh2MH1ohMnkpBttDipHOuFS1CZh8xoxr/8moW
|
||||
nj5FRG+FAZeCmcqj5PE+du7KF2XRPBlxhc1Nu+kPejlr6qa5qdwo4MzfuzmxWmvc
|
||||
WQuNMtaPqQvYL1A09MH0uMH65MtJNsqbSvHa5AwAlletPw6Wr0qrBLBCmOpNf+Q7
|
||||
7nBQBrK5VPMcto9IkGB4/bwhx7gQ0O2dD4dD4DPpGY9p52KpOG2ECoCWMtbsPD2P
|
||||
bs+WNHN8V+3ZCxZukEj25wDhc5941P01BhKVFevGLHyYNWk34mQk7RdHj9OiEL8n
|
||||
GpQ9l/R58+mvVwarzs898/y5onQieWi0Zu3WfMvjTOG3D3NIKMuthzRytfV5C/tJ
|
||||
+W5ZX/jLVR3bzvzx8Pnpvf602xCST9/7LbgFhljfXQq0bq0d9si9hvyaMOh1PQFU
|
||||
2+PzmWtHcsiVoyXfQp6ztJYFkoYaaD+Mc2jWG2Qy9kAyUGTXj/WfkPn7hr5hvuwk
|
||||
0kNDSan8NY2f1mtG253qr6fMOmCgrUfaumpafd9xIJ65x1G2BGAr8bzjLJufEUaG
|
||||
D2wBYWE6tlRqT4j7u6u9vRjShKH+A1UpLV2pEtaIQ3wfbt6GIwFJHWU506m3RCCn
|
||||
pL46fAOVKS1GSuf79koXsZeECJRSbipXz3TJs0TqiQKzBBgBCAAmAhsCFiEE1bTE
|
||||
O0ioefY1KTbmWTRuDqNcZ+UFAlyo2PAFCQM9QGYAgXYgBBkRCAAdFiEENVUmaGTK
|
||||
bX/0Wqbnz8OUl/GybgcFAltMzAoACgkQz8OUl/Gybgf0OwD/c4hwqsfZ79t7pM9d
|
||||
PPWYQ1jyq2g3ELMKyPp79GmL0qsA/2t2qkaOEX3y7egmhL/iKyqASb4y/JTABGMU
|
||||
hy5GjBhxCRBZNG4Oo1xn5WBvEACbCAQRC00FYoktuRzQQy2LCJe13AUS1/lCWv8B
|
||||
Qu7hTmM8TC/iNmYk71qeYInQMp/12b0HSWcv8IBmOlMy2GTjgnTgiwpqY5nhtb9O
|
||||
uB5H2g6fpu7FFG9ARhtH9PiTMwOUzfZFUz0tDdEEG5sayzWUcY3zjmJFmHSg5A9B
|
||||
/Q/yctqZ1eINtyEECINo/OVEfD7bmyZwK/vrxAg285iF6lB11wVl+5E7sNy9Hvu8
|
||||
4kCKPksqyjFWUd0XoEu9AH6+XVeEPF7CQKHpRfhc4uweT9O5nTb7aaPcqq0B4lUL
|
||||
unG6KSCm88zaZczp2SUCFwENegmBT/YKN5ZoHsPh1nwLxh194EP/qRjW9IvFKTlJ
|
||||
EsB4uCpfDeC233oH5nDkvvphcPYdUuOsVH1uPQ7PyWNTf1ufd9bDSDtK8epIcDPe
|
||||
abOuphxQbrMVP4JJsBXnVW5raZO7s5lmSA8Ovce//+xJSAq9u0GTsGu1hWDe60ro
|
||||
uOZwqjo/cU5G4y7WHRaC3oshH+DO8ajdXDogoDVs8DzYkTfWND2DDNEVhVrn7lGf
|
||||
a4739sFIDagtBq6RzJGL0X82eJZzXPFiYvmy0OVbNDUgH+Drva/wRv/tN8RvBiS6
|
||||
bsn8+GBGaU5RASu67UbqxHiytFnN4OnADA5ZHcwQbMgRHHiiMMIf+tJWH/pFMp00
|
||||
epiDVQ==
|
||||
=VSKJ
|
||||
-----END PGP PUBLIC KEY BLOCK-----
|
||||
`
|
||||
keys, err := CheckArmoredGPGKeyString(testIssue6599)
|
||||
assert.NoError(t, err)
|
||||
if assert.NotEmpty(t, keys) {
|
||||
ekey := keys[0]
|
||||
expire := getExpiryTime(ekey)
|
||||
assert.Equal(t, time.Unix(1586105389, 0), expire)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryGetKeyIDFromSignature(t *testing.T) {
|
||||
assert.Empty(t, TryGetKeyIDFromSignature(&packet.Signature{}))
|
||||
assert.Equal(t, "038D1A3EADDBEA9C", TryGetKeyIDFromSignature(&packet.Signature{
|
||||
IssuerKeyId: util.ToPointer(uint64(0x38D1A3EADDBEA9C)),
|
||||
}))
|
||||
assert.Equal(t, "038D1A3EADDBEA9C", TryGetKeyIDFromSignature(&packet.Signature{
|
||||
IssuerFingerprint: []uint8{0xb, 0x23, 0x24, 0xc7, 0xe6, 0xfe, 0x4f, 0x3a, 0x6, 0x26, 0xc1, 0x21, 0x3, 0x8d, 0x1a, 0x3e, 0xad, 0xdb, 0xea, 0x9c},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestParseGPGKey(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.NoError(t, db.Insert(t.Context(), &user_model.EmailAddress{UID: 1, Email: "email1@example.com", IsActivated: true}))
|
||||
|
||||
// create a key for test email
|
||||
e, err := openpgp.NewEntity("name", "comment", "email1@example.com", nil)
|
||||
require.NoError(t, err)
|
||||
k, err := parseGPGKey(t.Context(), 1, e, true)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, k.KeyID)
|
||||
assert.NotEmpty(t, k.Emails) // the key is valid, matches the email
|
||||
|
||||
// then revoke the key
|
||||
for _, id := range e.Identities {
|
||||
id.Revocations = append(id.Revocations, &packet.Signature{RevocationReason: util.ToPointer(packet.KeyCompromised)})
|
||||
}
|
||||
k, err = parseGPGKey(t.Context(), 1, e, true)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, k.KeyID)
|
||||
assert.Empty(t, k.Emails) // the key is revoked, matches no email
|
||||
}
|
||||
98
models/asymkey/gpg_key_verify.go
Normal file
98
models/asymkey/gpg_key_verify.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/base"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
)
|
||||
|
||||
// This file provides functions relating verifying gpg keys
|
||||
|
||||
// VerifyGPGKey marks a GPG key as verified
|
||||
func VerifyGPGKey(ctx context.Context, ownerID int64, keyID, token, signature string) (string, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (string, error) {
|
||||
key := new(GPGKey)
|
||||
|
||||
has, err := db.GetEngine(ctx).Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if !has {
|
||||
return "", ErrGPGKeyNotExist{}
|
||||
}
|
||||
|
||||
if err := key.LoadSubKeys(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sig, err := ExtractSignature(signature)
|
||||
if err != nil {
|
||||
return "", ErrGPGInvalidTokenSignature{
|
||||
ID: key.KeyID,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
signer, err := hashAndVerifyWithSubKeys(sig, token, key)
|
||||
if err != nil {
|
||||
return "", ErrGPGInvalidTokenSignature{
|
||||
ID: key.KeyID,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
if signer == nil {
|
||||
signer, err = hashAndVerifyWithSubKeys(sig, token+"\n", key)
|
||||
if err != nil {
|
||||
return "", ErrGPGInvalidTokenSignature{
|
||||
ID: key.KeyID,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
if signer == nil {
|
||||
signer, err = hashAndVerifyWithSubKeys(sig, token+"\n\n", key)
|
||||
if err != nil {
|
||||
return "", ErrGPGInvalidTokenSignature{
|
||||
ID: key.KeyID,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if signer == nil {
|
||||
log.Debug("VerifyGPGKey failed: no signer")
|
||||
return "", ErrGPGInvalidTokenSignature{
|
||||
ID: key.KeyID,
|
||||
}
|
||||
}
|
||||
|
||||
if signer.PrimaryKeyID != key.KeyID && signer.KeyID != key.KeyID {
|
||||
return "", ErrGPGKeyNotExist{}
|
||||
}
|
||||
|
||||
key.Verified = true
|
||||
if _, err := db.GetEngine(ctx).ID(key.ID).Cols("verified").Update(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return key.KeyID, nil
|
||||
})
|
||||
}
|
||||
|
||||
// VerificationToken returns token for the user that will be valid in minutes (time)
|
||||
func VerificationToken(user *user_model.User, minutes int) string {
|
||||
return base.EncodeSha256(
|
||||
time.Now().Truncate(1*time.Minute).Add(time.Duration(minutes)*time.Minute).Format(
|
||||
time.RFC1123Z) + ":" +
|
||||
user.CreatedUnix.Format(time.RFC1123Z) + ":" +
|
||||
user.Name + ":" +
|
||||
user.Email + ":" +
|
||||
strconv.FormatInt(user.ID, 10))
|
||||
}
|
||||
37
models/asymkey/key_display.go
Normal file
37
models/asymkey/key_display.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2025 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"code.gitea.io/gitea/modules/git"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
)
|
||||
|
||||
func GetDisplaySigningKey(key *git.SigningKey) string {
|
||||
if key == nil || key.Format == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch key.Format {
|
||||
case git.SigningKeyFormatOpenPGP:
|
||||
return key.KeyID
|
||||
case git.SigningKeyFormatSSH:
|
||||
content, err := os.ReadFile(key.KeyID)
|
||||
if err != nil {
|
||||
log.Error("Unable to read SSH key %s: %v", key.KeyID, err)
|
||||
return "(Unable to read SSH key)"
|
||||
}
|
||||
display, err := CalcFingerprint(string(content))
|
||||
if err != nil {
|
||||
log.Error("Unable to calculate fingerprint for SSH key %s: %v", key.KeyID, err)
|
||||
return "(Unable to calculate fingerprint for SSH key)"
|
||||
}
|
||||
return display
|
||||
}
|
||||
setting.PanicInDevOrTesting("Unknown signing key format: %s", key.Format)
|
||||
return "(Unknown key format)"
|
||||
}
|
||||
23
models/asymkey/main_test.go
Normal file
23
models/asymkey/main_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m, &unittest.TestOptions{
|
||||
FixtureFiles: []string{
|
||||
"gpg_key.yml",
|
||||
"public_key.yml",
|
||||
"deploy_key.yml",
|
||||
"gpg_key_import.yml",
|
||||
"user.yml",
|
||||
"email_address.yml",
|
||||
},
|
||||
})
|
||||
}
|
||||
407
models/asymkey/ssh_key.go
Normal file
407
models/asymkey/ssh_key.go
Normal file
@@ -0,0 +1,407 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/perm"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// KeyType specifies the key type
|
||||
type KeyType int
|
||||
|
||||
const (
|
||||
// KeyTypeUser specifies the user key
|
||||
KeyTypeUser = iota + 1
|
||||
// KeyTypeDeploy specifies the deploy key
|
||||
KeyTypeDeploy
|
||||
// KeyTypePrincipal specifies the authorized principal key
|
||||
KeyTypePrincipal
|
||||
)
|
||||
|
||||
// PublicKey represents a user or deploy SSH public key.
|
||||
type PublicKey struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
OwnerID int64 `xorm:"INDEX NOT NULL"`
|
||||
Name string `xorm:"NOT NULL"`
|
||||
Fingerprint string `xorm:"INDEX NOT NULL"`
|
||||
Content string `xorm:"MEDIUMTEXT NOT NULL"`
|
||||
Mode perm.AccessMode `xorm:"NOT NULL DEFAULT 2"`
|
||||
Type KeyType `xorm:"NOT NULL DEFAULT 1"`
|
||||
LoginSourceID int64 `xorm:"NOT NULL DEFAULT 0"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
HasRecentActivity bool `xorm:"-"`
|
||||
HasUsed bool `xorm:"-"`
|
||||
Verified bool `xorm:"NOT NULL DEFAULT false"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(PublicKey))
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (key *PublicKey) AfterLoad() {
|
||||
key.HasUsed = key.UpdatedUnix > key.CreatedUnix
|
||||
key.HasRecentActivity = key.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
// OmitEmail returns content of public key without email address.
|
||||
func (key *PublicKey) OmitEmail() string {
|
||||
return strings.Join(strings.Split(key.Content, " ")[:2], " ")
|
||||
}
|
||||
|
||||
func addKey(ctx context.Context, key *PublicKey) (err error) {
|
||||
if len(key.Fingerprint) == 0 {
|
||||
key.Fingerprint, err = CalcFingerprint(key.Content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Save SSH key.
|
||||
if err = db.Insert(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return appendAuthorizedKeysToFile(key)
|
||||
}
|
||||
|
||||
// AddPublicKey adds new public key to database and authorized_keys file.
|
||||
func AddPublicKey(ctx context.Context, ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
|
||||
log.Trace(content)
|
||||
|
||||
fingerprint, err := CalcFingerprint(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*PublicKey, error) {
|
||||
if err := checkKeyFingerprint(ctx, fingerprint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Key name of same user cannot be duplicated.
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("owner_id = ? AND name = ?", ownerID, name).
|
||||
Get(new(PublicKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if has {
|
||||
return nil, ErrKeyNameAlreadyUsed{ownerID, name}
|
||||
}
|
||||
|
||||
key := &PublicKey{
|
||||
OwnerID: ownerID,
|
||||
Name: name,
|
||||
Fingerprint: fingerprint,
|
||||
Content: content,
|
||||
Mode: perm.AccessModeWrite,
|
||||
Type: KeyTypeUser,
|
||||
LoginSourceID: authSourceID,
|
||||
}
|
||||
if err = addKey(ctx, key); err != nil {
|
||||
return nil, fmt.Errorf("addKey: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetPublicKeyByID returns public key by given ID.
|
||||
func GetPublicKeyByID(ctx context.Context, keyID int64) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := db.GetEngine(ctx).
|
||||
ID(keyID).
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrKeyNotExist{keyID}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SearchPublicKeyByContent searches content as prefix (leak e-mail part)
|
||||
// and returns public key found.
|
||||
func SearchPublicKeyByContent(ctx context.Context, content string) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("content like ?", content+"%").
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrKeyNotExist{}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SearchPublicKeyByContentExact searches content
|
||||
// and returns public key found.
|
||||
func SearchPublicKeyByContentExact(ctx context.Context, content string) (*PublicKey, error) {
|
||||
key := new(PublicKey)
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("content = ?", content).
|
||||
Get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrKeyNotExist{}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
type FindPublicKeyOptions struct {
|
||||
db.ListOptions
|
||||
OwnerID int64
|
||||
Fingerprint string
|
||||
KeyTypes []KeyType
|
||||
NotKeytype KeyType
|
||||
LoginSourceID int64
|
||||
}
|
||||
|
||||
func (opts FindPublicKeyOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opts.OwnerID > 0 {
|
||||
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
|
||||
}
|
||||
if opts.Fingerprint != "" {
|
||||
cond = cond.And(builder.Eq{"fingerprint": opts.Fingerprint})
|
||||
}
|
||||
if len(opts.KeyTypes) > 0 {
|
||||
cond = cond.And(builder.In("`type`", opts.KeyTypes))
|
||||
}
|
||||
if opts.NotKeytype > 0 {
|
||||
cond = cond.And(builder.Neq{"`type`": opts.NotKeytype})
|
||||
}
|
||||
if opts.LoginSourceID > 0 {
|
||||
cond = cond.And(builder.Eq{"login_source_id": opts.LoginSourceID})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
// UpdatePublicKeyUpdated updates public key use time.
|
||||
func UpdatePublicKeyUpdated(ctx context.Context, id int64) error {
|
||||
// Check if key exists before update as affected rows count is unreliable
|
||||
// and will return 0 affected rows if two updates are made at the same time
|
||||
if cnt, err := db.GetEngine(ctx).ID(id).Count(&PublicKey{}); err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrKeyNotExist{id}
|
||||
}
|
||||
|
||||
_, err := db.GetEngine(ctx).ID(id).Cols("updated_unix").Update(&PublicKey{
|
||||
UpdatedUnix: timeutil.TimeStampNow(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
||||
func PublicKeysAreExternallyManaged(ctx context.Context, keys []*PublicKey) ([]bool, error) {
|
||||
sourceCache := make(map[int64]*auth.Source, len(keys))
|
||||
externals := make([]bool, len(keys))
|
||||
|
||||
for i, key := range keys {
|
||||
if key.LoginSourceID == 0 {
|
||||
externals[i] = false
|
||||
continue
|
||||
}
|
||||
|
||||
source, ok := sourceCache[key.LoginSourceID]
|
||||
if !ok {
|
||||
var err error
|
||||
source, err = auth.GetSourceByID(ctx, key.LoginSourceID)
|
||||
if err != nil {
|
||||
if auth.IsErrSourceNotExist(err) {
|
||||
externals[i] = false
|
||||
sourceCache[key.LoginSourceID] = &auth.Source{
|
||||
ID: key.LoginSourceID,
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if sshKeyProvider, ok := source.Cfg.(auth.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() {
|
||||
// Disable setting SSH keys for this user
|
||||
externals[i] = true
|
||||
}
|
||||
}
|
||||
|
||||
return externals, nil
|
||||
}
|
||||
|
||||
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
|
||||
func PublicKeyIsExternallyManaged(ctx context.Context, id int64) (bool, error) {
|
||||
key, err := GetPublicKeyByID(ctx, id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if key.LoginSourceID == 0 {
|
||||
return false, nil
|
||||
}
|
||||
source, err := auth.GetSourceByID(ctx, key.LoginSourceID)
|
||||
if err != nil {
|
||||
if auth.IsErrSourceNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if sshKeyProvider, ok := source.Cfg.(auth.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() {
|
||||
// Disable setting SSH keys for this user
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
|
||||
func deleteKeysMarkedForDeletion(ctx context.Context, keys []string) (bool, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (bool, error) {
|
||||
// Delete keys marked for deletion
|
||||
var sshKeysNeedUpdate bool
|
||||
for _, KeyToDelete := range keys {
|
||||
key, err := SearchPublicKeyByContent(ctx, KeyToDelete)
|
||||
if err != nil {
|
||||
log.Error("SearchPublicKeyByContent: %v", err)
|
||||
continue
|
||||
}
|
||||
if _, err = db.DeleteByID[PublicKey](ctx, key.ID); err != nil {
|
||||
log.Error("DeleteByID[PublicKey]: %v", err)
|
||||
continue
|
||||
}
|
||||
sshKeysNeedUpdate = true
|
||||
}
|
||||
|
||||
return sshKeysNeedUpdate, nil
|
||||
})
|
||||
}
|
||||
|
||||
// AddPublicKeysBySource add a users public keys. Returns true if there are changes.
|
||||
func AddPublicKeysBySource(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
||||
var sshKeysNeedUpdate bool
|
||||
for _, sshKey := range sshPublicKeys {
|
||||
var err error
|
||||
found := false
|
||||
keys := []byte(sshKey)
|
||||
loop:
|
||||
for len(keys) > 0 && err == nil {
|
||||
var out ssh.PublicKey
|
||||
// We ignore options as they are not relevant to Gitea
|
||||
out, _, _, keys, err = ssh.ParseAuthorizedKey(keys)
|
||||
if err != nil {
|
||||
break loop
|
||||
}
|
||||
found = true
|
||||
marshalled := string(ssh.MarshalAuthorizedKey(out))
|
||||
marshalled = marshalled[:len(marshalled)-1]
|
||||
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
|
||||
|
||||
if _, err := AddPublicKey(ctx, usr.ID, sshKeyName, marshalled, s.ID); err != nil {
|
||||
if IsErrKeyAlreadyExist(err) {
|
||||
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
|
||||
} else {
|
||||
log.Error("AddPublicKeysBySource[%s]: Error adding Public SSH Key for user %s: %v", sshKeyName, usr.Name, err)
|
||||
}
|
||||
} else {
|
||||
log.Trace("AddPublicKeysBySource[%s]: Added Public SSH Key for user %s", sshKeyName, usr.Name)
|
||||
sshKeysNeedUpdate = true
|
||||
}
|
||||
}
|
||||
if !found && err != nil {
|
||||
log.Warn("AddPublicKeysBySource[%s]: Skipping invalid Public SSH Key for user %s: %v", s.Name, usr.Name, sshKey)
|
||||
}
|
||||
}
|
||||
return sshKeysNeedUpdate
|
||||
}
|
||||
|
||||
// SynchronizePublicKeys updates a user's public keys. Returns true if there are changes.
|
||||
func SynchronizePublicKeys(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
|
||||
var sshKeysNeedUpdate bool
|
||||
|
||||
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
|
||||
|
||||
// Get Public Keys from DB with the current auth source
|
||||
var giteaKeys []string
|
||||
keys, err := db.Find[PublicKey](ctx, FindPublicKeyOptions{
|
||||
OwnerID: usr.ID,
|
||||
LoginSourceID: s.ID,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
|
||||
}
|
||||
|
||||
for _, v := range keys {
|
||||
giteaKeys = append(giteaKeys, v.OmitEmail())
|
||||
}
|
||||
|
||||
// Process the provided keys to remove duplicates and name part
|
||||
var providedKeys []string
|
||||
for _, v := range sshPublicKeys {
|
||||
sshKeySplit := strings.Split(v, " ")
|
||||
if len(sshKeySplit) > 1 {
|
||||
key := strings.Join(sshKeySplit[:2], " ")
|
||||
if !util.SliceContainsString(providedKeys, key) {
|
||||
providedKeys = append(providedKeys, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if Public Key sync is needed
|
||||
if util.SliceSortedEqual(giteaKeys, providedKeys) {
|
||||
log.Trace("synchronizePublicKeys[%s]: Public Keys are already in sync for %s (Source:%v/DB:%v)", s.Name, usr.Name, len(providedKeys), len(giteaKeys))
|
||||
return false
|
||||
}
|
||||
log.Trace("synchronizePublicKeys[%s]: Public Key needs update for user %s (Source:%v/DB:%v)", s.Name, usr.Name, len(providedKeys), len(giteaKeys))
|
||||
|
||||
// Add new Public SSH Keys that doesn't already exist in DB
|
||||
var newKeys []string
|
||||
for _, key := range providedKeys {
|
||||
if !util.SliceContainsString(giteaKeys, key) {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
if AddPublicKeysBySource(ctx, usr, s, newKeys) {
|
||||
sshKeysNeedUpdate = true
|
||||
}
|
||||
|
||||
// Mark keys from DB that no longer exist in the source for deletion
|
||||
var giteaKeysToDelete []string
|
||||
for _, giteaKey := range giteaKeys {
|
||||
if !util.SliceContainsString(providedKeys, giteaKey) {
|
||||
log.Trace("synchronizePublicKeys[%s]: Marking Public SSH Key for deletion for user %s: %v", s.Name, usr.Name, giteaKey)
|
||||
giteaKeysToDelete = append(giteaKeysToDelete, giteaKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete keys from DB that no longer exist in the source
|
||||
needUpd, err := deleteKeysMarkedForDeletion(ctx, giteaKeysToDelete)
|
||||
if err != nil {
|
||||
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
|
||||
}
|
||||
if needUpd {
|
||||
sshKeysNeedUpdate = true
|
||||
}
|
||||
|
||||
return sshKeysNeedUpdate
|
||||
}
|
||||
172
models/asymkey/ssh_key_authorized_keys.go
Normal file
172
models/asymkey/ssh_key_authorized_keys.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// AuthorizedStringCommentPrefix is a magic tag
|
||||
// some functions like RegeneratePublicKeys needs this tag to skip the keys generated by Gitea, while keep other keys
|
||||
const AuthorizedStringCommentPrefix = `# gitea public key`
|
||||
|
||||
var sshOpLocker sync.Mutex
|
||||
|
||||
func WithSSHOpLocker(f func() error) error {
|
||||
sshOpLocker.Lock()
|
||||
defer sshOpLocker.Unlock()
|
||||
return f()
|
||||
}
|
||||
|
||||
// AuthorizedStringForKey creates the authorized keys string appropriate for the provided key
|
||||
func AuthorizedStringForKey(key *PublicKey) (string, error) {
|
||||
sb := &strings.Builder{}
|
||||
_, err := writeAuthorizedStringForKey(key, sb)
|
||||
return sb.String(), err
|
||||
}
|
||||
|
||||
// WriteAuthorizedStringForValidKey writes the authorized key for the provided key. If the key is invalid, it does nothing.
|
||||
func WriteAuthorizedStringForValidKey(key *PublicKey, w io.Writer) error {
|
||||
validKey, err := writeAuthorizedStringForKey(key, w)
|
||||
if !validKey {
|
||||
log.Debug("WriteAuthorizedStringForValidKey: key %s is not valid: %v", key, err)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func writeAuthorizedStringForKey(key *PublicKey, w io.Writer) (keyValid bool, err error) {
|
||||
const tpl = AuthorizedStringCommentPrefix + "\n" + `command=%s,no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty,no-user-rc,restrict %s %s` + "\n"
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key.Content))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// now the key is valid, the code below could only return template/IO related errors
|
||||
sbCmd := &strings.Builder{}
|
||||
err = setting.SSH.AuthorizedKeysCommandTemplateTemplate.Execute(sbCmd, map[string]any{
|
||||
"AppPath": util.ShellEscape(setting.AppPath),
|
||||
"AppWorkPath": util.ShellEscape(setting.AppWorkPath),
|
||||
"CustomConf": util.ShellEscape(setting.CustomConf),
|
||||
"CustomPath": util.ShellEscape(setting.CustomPath),
|
||||
"Key": key,
|
||||
})
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
sshCommandEscaped := util.ShellEscape(sbCmd.String())
|
||||
sshKeyMarshalled := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||
sshKeyComment := fmt.Sprintf("user-%d", key.OwnerID)
|
||||
_, err = fmt.Fprintf(w, tpl, sshCommandEscaped, sshKeyMarshalled, sshKeyComment)
|
||||
return true, err
|
||||
}
|
||||
|
||||
// appendAuthorizedKeysToFile appends new SSH keys' content to authorized_keys file.
|
||||
func appendAuthorizedKeysToFile(keys ...*PublicKey) error {
|
||||
// Don't need to rewrite this file if builtin SSH server is enabled.
|
||||
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedKeysFile {
|
||||
return nil
|
||||
}
|
||||
|
||||
sshOpLocker.Lock()
|
||||
defer sshOpLocker.Unlock()
|
||||
|
||||
if setting.SSH.RootPath != "" {
|
||||
// First of ensure that the RootPath is present, and if not make it with 0700 permissions
|
||||
// This of course doesn't guarantee that this is the right directory for authorized_keys
|
||||
// but at least if it's supposed to be this directory and it doesn't exist and we're the
|
||||
// right user it will at least be created properly.
|
||||
err := os.MkdirAll(setting.SSH.RootPath, 0o700)
|
||||
if err != nil {
|
||||
log.Error("Unable to MkdirAll(%s): %v", setting.SSH.RootPath, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
|
||||
f, err := os.OpenFile(fPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Note: chmod command does not support in Windows.
|
||||
if !setting.IsWindows {
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// .ssh directory should have mode 700, and authorized_keys file should have mode 600.
|
||||
if fi.Mode().Perm() > 0o600 {
|
||||
log.Error("authorized_keys file has unusual permission flags: %s - setting to -rw-------", fi.Mode().Perm().String())
|
||||
if err = f.Chmod(0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if key.Type == KeyTypePrincipal {
|
||||
continue
|
||||
}
|
||||
if err = WriteAuthorizedStringForValidKey(key, f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegeneratePublicKeys regenerates the authorized_keys file
|
||||
func RegeneratePublicKeys(ctx context.Context, t io.Writer) error {
|
||||
if err := db.GetEngine(ctx).Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean any) (err error) {
|
||||
return WriteAuthorizedStringForValidKey(bean.(*PublicKey), t)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
|
||||
isExist, err := util.IsExist(fPath)
|
||||
if err != nil {
|
||||
log.Error("Unable to check if %s exists. Error: %v", fPath, err)
|
||||
return err
|
||||
}
|
||||
if isExist {
|
||||
f, err := os.Open(fPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, AuthorizedStringCommentPrefix) {
|
||||
scanner.Scan()
|
||||
continue
|
||||
}
|
||||
_, err = io.WriteString(t, line+"\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return fmt.Errorf("RegeneratePublicKeys scan: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
214
models/asymkey/ssh_key_deploy.go
Normal file
214
models/asymkey/ssh_key_deploy.go
Normal file
@@ -0,0 +1,214 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/perm"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ________ .__ ____ __.
|
||||
// \______ \ ____ ______ | | ____ ___.__.| |/ _|____ ___.__.
|
||||
// | | \_/ __ \\____ \| | / _ < | || <_/ __ < | |
|
||||
// | ` \ ___/| |_> > |_( <_> )___ || | \ ___/\___ |
|
||||
// /_______ /\___ > __/|____/\____// ____||____|__ \___ > ____|
|
||||
// \/ \/|__| \/ \/ \/\/
|
||||
//
|
||||
// This file contains functions specific to DeployKeys
|
||||
|
||||
// DeployKey represents deploy key information and its relation with repository.
|
||||
type DeployKey struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
KeyID int64 `xorm:"UNIQUE(s) INDEX"`
|
||||
RepoID int64 `xorm:"UNIQUE(s) INDEX"`
|
||||
Name string
|
||||
Fingerprint string
|
||||
Content string `xorm:"-"`
|
||||
|
||||
Mode perm.AccessMode `xorm:"NOT NULL DEFAULT 1"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
HasRecentActivity bool `xorm:"-"`
|
||||
HasUsed bool `xorm:"-"`
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (key *DeployKey) AfterLoad() {
|
||||
key.HasUsed = key.UpdatedUnix > key.CreatedUnix
|
||||
key.HasRecentActivity = key.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
// GetContent gets associated public key content.
|
||||
func (key *DeployKey) GetContent(ctx context.Context) error {
|
||||
pkey, err := GetPublicKeyByID(ctx, key.KeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.Content = pkey.Content
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReadOnly checks if the key can only be used for read operations, used by template
|
||||
func (key *DeployKey) IsReadOnly() bool {
|
||||
return key.Mode == perm.AccessModeRead
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(DeployKey))
|
||||
}
|
||||
|
||||
func checkDeployKey(ctx context.Context, keyID, repoID int64, name string) error {
|
||||
// Note: We want error detail, not just true or false here.
|
||||
has, err := db.GetEngine(ctx).
|
||||
Where("key_id = ? AND repo_id = ?", keyID, repoID).
|
||||
Get(new(DeployKey))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrDeployKeyAlreadyExist{keyID, repoID}
|
||||
}
|
||||
|
||||
has, err = db.GetEngine(ctx).
|
||||
Where("repo_id = ? AND name = ?", repoID, name).
|
||||
Get(new(DeployKey))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrDeployKeyNameAlreadyUsed{repoID, name}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addDeployKey adds new key-repo relation.
|
||||
func addDeployKey(ctx context.Context, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
|
||||
if err := checkDeployKey(ctx, keyID, repoID, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := &DeployKey{
|
||||
KeyID: keyID,
|
||||
RepoID: repoID,
|
||||
Name: name,
|
||||
Fingerprint: fingerprint,
|
||||
Mode: mode,
|
||||
}
|
||||
return key, db.Insert(ctx, key)
|
||||
}
|
||||
|
||||
// HasDeployKey returns true if public key is a deploy key of given repository.
|
||||
func HasDeployKey(ctx context.Context, keyID, repoID int64) bool {
|
||||
has, _ := db.GetEngine(ctx).
|
||||
Where("key_id = ? AND repo_id = ?", keyID, repoID).
|
||||
Get(new(DeployKey))
|
||||
return has
|
||||
}
|
||||
|
||||
// AddDeployKey add new deploy key to database and authorized_keys file.
|
||||
func AddDeployKey(ctx context.Context, repoID int64, name, content string, readOnly bool) (*DeployKey, error) {
|
||||
fingerprint, err := CalcFingerprint(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessMode := perm.AccessModeRead
|
||||
if !readOnly {
|
||||
accessMode = perm.AccessModeWrite
|
||||
}
|
||||
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*DeployKey, error) {
|
||||
pkey, exist, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if exist {
|
||||
if pkey.Type != KeyTypeDeploy {
|
||||
return nil, ErrKeyAlreadyExist{0, fingerprint, ""}
|
||||
}
|
||||
} else {
|
||||
// First time use this deploy key.
|
||||
pkey = &PublicKey{
|
||||
Fingerprint: fingerprint,
|
||||
Mode: accessMode,
|
||||
Type: KeyTypeDeploy,
|
||||
Content: content,
|
||||
Name: name,
|
||||
}
|
||||
if err = addKey(ctx, pkey); err != nil {
|
||||
return nil, fmt.Errorf("addKey: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
key, err := addDeployKey(ctx, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetDeployKeyByID returns deploy key by given ID.
|
||||
func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
|
||||
key, exist, err := db.GetByID[DeployKey](ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !exist {
|
||||
return nil, ErrDeployKeyNotExist{id, 0, 0}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
|
||||
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
|
||||
key, exist, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !exist {
|
||||
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// IsDeployKeyExistByKeyID return true if there is at least one deploykey with the key id
|
||||
func IsDeployKeyExistByKeyID(ctx context.Context, keyID int64) (bool, error) {
|
||||
return db.GetEngine(ctx).
|
||||
Where("key_id = ?", keyID).
|
||||
Get(new(DeployKey))
|
||||
}
|
||||
|
||||
// UpdateDeployKeyCols updates deploy key information in the specified columns.
|
||||
func UpdateDeployKeyCols(ctx context.Context, key *DeployKey, cols ...string) error {
|
||||
_, err := db.GetEngine(ctx).ID(key.ID).Cols(cols...).Update(key)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListDeployKeysOptions are options for ListDeployKeys
|
||||
type ListDeployKeysOptions struct {
|
||||
db.ListOptions
|
||||
RepoID int64
|
||||
KeyID int64
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
func (opt ListDeployKeysOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
if opt.RepoID != 0 {
|
||||
cond = cond.And(builder.Eq{"repo_id": opt.RepoID})
|
||||
}
|
||||
if opt.KeyID != 0 {
|
||||
cond = cond.And(builder.Eq{"key_id": opt.KeyID})
|
||||
}
|
||||
if opt.Fingerprint != "" {
|
||||
cond = cond.And(builder.Eq{"fingerprint": opt.Fingerprint})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
49
models/asymkey/ssh_key_fingerprint.go
Normal file
49
models/asymkey/ssh_key_fingerprint.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// The database is used in checkKeyFingerprint. However, most of these functions probably belong in a module
|
||||
|
||||
// checkKeyFingerprint only checks if key fingerprint has been used as a public key,
|
||||
// it is OK to use same key as deploy key for multiple repositories/users.
|
||||
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
|
||||
has, err := db.Exist[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrKeyAlreadyExist{0, fingerprint, ""}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func calcFingerprintNative(publicKeyContent string) (string, error) {
|
||||
// Calculate fingerprint.
|
||||
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKeyContent))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ssh.FingerprintSHA256(pk), nil
|
||||
}
|
||||
|
||||
// CalcFingerprint calculate public key's fingerprint
|
||||
func CalcFingerprint(publicKeyContent string) (string, error) {
|
||||
fp, err := calcFingerprintNative(publicKeyContent)
|
||||
if err != nil {
|
||||
if IsErrKeyUnableVerify(err) {
|
||||
return "", err
|
||||
}
|
||||
return "", fmt.Errorf("CalcFingerprint: %w", err)
|
||||
}
|
||||
return fp, nil
|
||||
}
|
||||
246
models/asymkey/ssh_key_parse.go
Normal file
246
models/asymkey/ssh_key_parse.go
Normal file
@@ -0,0 +1,246 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// ____ __. __________
|
||||
// | |/ _|____ ___.__. \______ \_____ _______ ______ ___________
|
||||
// | <_/ __ < | | | ___/\__ \\_ __ \/ ___// __ \_ __ \
|
||||
// | | \ ___/\___ | | | / __ \| | \/\___ \\ ___/| | \/
|
||||
// |____|__ \___ > ____| |____| (____ /__| /____ >\___ >__|
|
||||
// \/ \/\/ \/ \/ \/
|
||||
//
|
||||
// This file contains functions for parsing ssh-keys
|
||||
//
|
||||
// TODO: Consider if these functions belong in models - no other models function call them or are called by them
|
||||
// They may belong in a service or a module
|
||||
|
||||
const ssh2keyStart = "---- BEGIN SSH2 PUBLIC KEY ----"
|
||||
|
||||
func extractTypeFromBase64Key(key string) (string, error) {
|
||||
b, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil || len(b) < 4 {
|
||||
return "", fmt.Errorf("invalid key format: %w", err)
|
||||
}
|
||||
|
||||
keyLength := int(binary.BigEndian.Uint32(b))
|
||||
if len(b) < 4+keyLength {
|
||||
return "", fmt.Errorf("invalid key format: not enough length %d", keyLength)
|
||||
}
|
||||
|
||||
return string(b[4 : 4+keyLength]), nil
|
||||
}
|
||||
|
||||
// parseKeyString parses any key string in OpenSSH or SSH2 format to clean OpenSSH string (RFC4253).
|
||||
func parseKeyString(content string) (string, error) {
|
||||
// remove whitespace at start and end
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
var keyType, keyContent, keyComment string
|
||||
|
||||
if strings.HasPrefix(content, ssh2keyStart) {
|
||||
// Parse SSH2 file format.
|
||||
|
||||
// Transform all legal line endings to a single "\n".
|
||||
content = strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(content)
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
continuationLine := false
|
||||
|
||||
for _, line := range lines {
|
||||
// Skip lines that:
|
||||
// 1) are a continuation of the previous line,
|
||||
// 2) contain ":" as that are comment lines
|
||||
// 3) contain "-" as that are begin and end tags
|
||||
if continuationLine || strings.ContainsAny(line, ":-") {
|
||||
continuationLine = strings.HasSuffix(line, "\\")
|
||||
} else {
|
||||
keyContent += line
|
||||
}
|
||||
}
|
||||
|
||||
t, err := extractTypeFromBase64Key(keyContent)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extractTypeFromBase64Key: %w", err)
|
||||
}
|
||||
keyType = t
|
||||
} else {
|
||||
if strings.Contains(content, "-----BEGIN") {
|
||||
// Convert PEM Keys to OpenSSH format
|
||||
// Transform all legal line endings to a single "\n".
|
||||
content = strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(content)
|
||||
|
||||
block, _ := pem.Decode([]byte(content))
|
||||
if block == nil {
|
||||
return "", errors.New("failed to parse PEM block containing the public key")
|
||||
}
|
||||
if strings.Contains(block.Type, "PRIVATE") {
|
||||
return "", ErrKeyIsPrivate
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
var pk rsa.PublicKey
|
||||
_, err2 := asn1.Unmarshal(block.Bytes, &pk)
|
||||
if err2 != nil {
|
||||
return "", fmt.Errorf("failed to parse DER encoded public key as either PKIX or PEM RSA Key: %v %w", err, err2)
|
||||
}
|
||||
pub = &pk
|
||||
}
|
||||
|
||||
sshKey, err := ssh.NewPublicKey(pub)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to convert to ssh public key: %w", err)
|
||||
}
|
||||
content = string(ssh.MarshalAuthorizedKey(sshKey))
|
||||
}
|
||||
// Parse OpenSSH format.
|
||||
|
||||
// Remove all newlines
|
||||
content = strings.NewReplacer("\r\n", "", "\n", "").Replace(content)
|
||||
|
||||
parts := strings.SplitN(content, " ", 3)
|
||||
switch len(parts) {
|
||||
case 0:
|
||||
return "", util.NewInvalidArgumentErrorf("empty key")
|
||||
case 1:
|
||||
keyContent = parts[0]
|
||||
case 2:
|
||||
keyType = parts[0]
|
||||
keyContent = parts[1]
|
||||
default:
|
||||
keyType = parts[0]
|
||||
keyContent = parts[1]
|
||||
keyComment = parts[2]
|
||||
}
|
||||
|
||||
// If keyType is not given, extract it from content. If given, validate it.
|
||||
t, err := extractTypeFromBase64Key(keyContent)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extractTypeFromBase64Key: %w", err)
|
||||
}
|
||||
if len(keyType) == 0 {
|
||||
keyType = t
|
||||
} else if keyType != t {
|
||||
return "", fmt.Errorf("key type and content does not match: %s - %s", keyType, t)
|
||||
}
|
||||
}
|
||||
// Finally we need to check whether we can actually read the proposed key:
|
||||
_, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + keyContent + " " + keyComment))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid ssh public key: %w", err)
|
||||
}
|
||||
return keyType + " " + keyContent + " " + keyComment, nil
|
||||
}
|
||||
|
||||
// CheckPublicKeyString checks if the given public key string is recognized by SSH.
|
||||
// It returns the actual public key line on success.
|
||||
func CheckPublicKeyString(content string) (_ string, err error) {
|
||||
content, err = parseKeyString(content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content = strings.TrimRight(content, "\n\r")
|
||||
if strings.ContainsAny(content, "\n\r") {
|
||||
return "", util.NewInvalidArgumentErrorf("only a single line with a single key please")
|
||||
}
|
||||
|
||||
// remove any unnecessary whitespace now
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
if !setting.SSH.MinimumKeySizeCheck {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
keyType, length, err := SSHNativeParsePublicKey(content)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("SSHNativeParsePublicKey: %w", err)
|
||||
}
|
||||
log.Trace("Key info [native: %v]: %s-%d", setting.SSH.StartBuiltinServer, keyType, length)
|
||||
|
||||
if minLen, found := setting.SSH.MinimumKeySizes[keyType]; found && length >= minLen {
|
||||
return content, nil
|
||||
} else if found && length < minLen {
|
||||
return "", fmt.Errorf("key length is not enough: got %d, needs %d", length, minLen)
|
||||
}
|
||||
return "", fmt.Errorf("key type is not allowed: %s", keyType)
|
||||
}
|
||||
|
||||
// SSHNativeParsePublicKey extracts the key type and length using the golang SSH library.
|
||||
func SSHNativeParsePublicKey(keyLine string) (string, int, error) {
|
||||
fields := strings.Fields(keyLine)
|
||||
if len(fields) < 2 {
|
||||
return "", 0, fmt.Errorf("not enough fields in public key line: %s", keyLine)
|
||||
}
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(fields[1])
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
pkey, err := ssh.ParsePublicKey(raw)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "ssh: unknown key algorithm") {
|
||||
return "", 0, ErrKeyUnableVerify{err.Error()}
|
||||
}
|
||||
return "", 0, fmt.Errorf("ParsePublicKey: %w", err)
|
||||
}
|
||||
|
||||
// The ssh library can parse the key, so next we find out what key exactly we have.
|
||||
switch pkey.Type() {
|
||||
case ssh.KeyAlgoDSA: //nolint:staticcheck // it's deprecated
|
||||
rawPub := struct {
|
||||
Name string
|
||||
P, Q, G, Y *big.Int
|
||||
}{}
|
||||
if err := ssh.Unmarshal(pkey.Marshal(), &rawPub); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
// as per https://bugzilla.mindrot.org/show_bug.cgi?id=1647 we should never
|
||||
// see dsa keys != 1024 bit, but as it seems to work, we will not check here
|
||||
return "dsa", rawPub.P.BitLen(), nil // use P as per crypto/dsa/dsa.go (is L)
|
||||
case ssh.KeyAlgoRSA:
|
||||
rawPub := struct {
|
||||
Name string
|
||||
E *big.Int
|
||||
N *big.Int
|
||||
}{}
|
||||
if err := ssh.Unmarshal(pkey.Marshal(), &rawPub); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return "rsa", rawPub.N.BitLen(), nil // use N as per crypto/rsa/rsa.go (is bits)
|
||||
case ssh.KeyAlgoECDSA256:
|
||||
return "ecdsa", 256, nil
|
||||
case ssh.KeyAlgoECDSA384:
|
||||
return "ecdsa", 384, nil
|
||||
case ssh.KeyAlgoECDSA521:
|
||||
return "ecdsa", 521, nil
|
||||
case ssh.KeyAlgoED25519:
|
||||
return "ed25519", 256, nil
|
||||
case ssh.KeyAlgoSKECDSA256:
|
||||
return "ecdsa-sk", 256, nil
|
||||
case ssh.KeyAlgoSKED25519:
|
||||
return "ed25519-sk", 256, nil
|
||||
}
|
||||
return "", 0, fmt.Errorf("unsupported key length detection for type: %s", pkey.Type())
|
||||
}
|
||||
56
models/asymkey/ssh_key_principals.go
Normal file
56
models/asymkey/ssh_key_principals.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
user_model "code.gitea.io/gitea/models/user"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines
|
||||
func CheckPrincipalKeyString(ctx context.Context, user *user_model.User, content string) (_ string, err error) {
|
||||
if setting.SSH.Disabled {
|
||||
return "", db.ErrSSHDisabled{}
|
||||
}
|
||||
|
||||
content = strings.TrimSpace(content)
|
||||
if strings.ContainsAny(content, "\r\n") {
|
||||
return "", util.NewInvalidArgumentErrorf("only a single line with a single principal please")
|
||||
}
|
||||
|
||||
// check all the allowed principals, email, username or anything
|
||||
// if any matches, return ok
|
||||
for _, v := range setting.SSH.AuthorizedPrincipalsAllow {
|
||||
switch v {
|
||||
case "anything":
|
||||
return content, nil
|
||||
case "email":
|
||||
emails, err := user_model.GetEmailAddresses(ctx, user.ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, email := range emails {
|
||||
if !email.IsActivated {
|
||||
continue
|
||||
}
|
||||
if content == email.Email {
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
|
||||
case "username":
|
||||
if content == user.Name {
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("didn't match allowed principals: %s", setting.SSH.AuthorizedPrincipalsAllow)
|
||||
}
|
||||
482
models/asymkey/ssh_key_test.go
Normal file
482
models/asymkey/ssh_key_test.go
Normal file
@@ -0,0 +1,482 @@
|
||||
// Copyright 2016 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"github.com/42wim/sshsig"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_SSHParsePublicKey(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skipSSHKeygen bool
|
||||
keyType string
|
||||
length int
|
||||
content string
|
||||
}{
|
||||
{"rsa-1024", false, "rsa", 1024, "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
|
||||
{"rsa-2048", false, "rsa", 2048, "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDMZXh+1OBUwSH9D45wTaxErQIN9IoC9xl7MKJkqvTvv6O5RR9YW/IK9FbfjXgXsppYGhsCZo1hFOOsXHMnfOORqu/xMDx4yPuyvKpw4LePEcg4TDipaDFuxbWOqc/BUZRZcXu41QAWfDLrInwsltWZHSeG7hjhpacl4FrVv9V1pS6Oc5Q1NxxEzTzuNLS/8diZrTm/YAQQ/+B+mzWI3zEtF4miZjjAljWd1LTBPvU23d29DcBmmFahcZ441XZsTeAwGxG/Q6j8NgNXj9WxMeWwxXV2jeAX/EBSpZrCVlCQ1yJswT6xCp8TuBnTiGWYMBNTbOZvPC4e0WI2/yZW/s5F nocomment"},
|
||||
{"ecdsa-256", false, "ecdsa", 256, "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFQacN3PrOll7PXmN5B/ZNVahiUIqI05nbBlZk1KXsO3d06ktAWqbNflv2vEmA38bTFTfJ2sbn2B5ksT52cDDbA= nocomment"},
|
||||
{"ecdsa-384", false, "ecdsa", 384, "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBINmioV+XRX1Fm9Qk2ehHXJ2tfVxW30ypUWZw670Zyq5GQfBAH6xjygRsJ5wWsHXBsGYgFUXIHvMKVAG1tpw7s6ax9oA+dJOJ7tj+vhn8joFqT+sg3LYHgZkHrfqryRasQ== nocomment"},
|
||||
{"ecdsa-sk", true, "ecdsa-sk", 256, "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
|
||||
{"ed25519-sk", true, "ed25519-sk", 256, "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("Native", func(t *testing.T) {
|
||||
keyTypeN, lengthN, err := SSHNativeParsePublicKey(tc.content)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.keyType, keyTypeN)
|
||||
assert.Equal(t, tc.length, lengthN)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckPublicKeyString(t *testing.T) {
|
||||
oldValue := setting.SSH.MinimumKeySizeCheck
|
||||
setting.SSH.MinimumKeySizeCheck = false
|
||||
for _, test := range []struct {
|
||||
content string
|
||||
}{
|
||||
{"ssh-dss AAAAB3NzaC1kc3MAAACBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn59NriyboW2x/DZuYAz3ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczWOVsaszBZSl90q8UnWlSH6P+/YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQsecdKktISwTakzAAAAFQCzKsO2JhNKlL+wwwLGOcLffoAmkwAAAIBpK7/3xvduajLBD/9vASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb0N6s08NZysLzvj0N+ZC/FnhKTLzIyMtkHf/IrPCwlM+pV/M/96YgAAAIEAqQcGn9CKgzgPaguIZooTAOQdvBLMI5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxcNs4BeVKhy2PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd642982daopE7zQ/NPAnJfag= nocomment"},
|
||||
{"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
|
||||
{"ssh-rsa AAAAB3NzaC1yc2EA\r\nAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+\r\nBZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNx\r\nfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\r\n\r\n"},
|
||||
{"ssh-rsa AAAAB3NzaC1yc2EA\r\nAAADAQABAAAAgQDAu7tvI\nvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+\r\nBZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvW\nqIwC4prx/WVk2wLTJjzBAhyNx\r\nfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\r\n\r\n"},
|
||||
{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf"},
|
||||
{"\r\nssh-ed25519 \r\nAAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf\r\n\r\n"},
|
||||
{"sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
|
||||
{"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
|
||||
{`---- BEGIN SSH2 PUBLIC KEY ----
|
||||
Comment: "1024-bit DSA, converted by andrew@phaedra from OpenSSH"
|
||||
AAAAB3NzaC1kc3MAAACBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn59NriyboW2x/DZuYAz3
|
||||
ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczWOVsaszBZSl90q8UnWlSH6P+/
|
||||
YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQsecdKktISwTakzAAAAFQCzKsO2JhNKlL
|
||||
+wwwLGOcLffoAmkwAAAIBpK7/3xvduajLBD/9vASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8
|
||||
A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb
|
||||
0N6s08NZysLzvj0N+ZC/FnhKTLzIyMtkHf/IrPCwlM+pV/M/96YgAAAIEAqQcGn9CKgzgP
|
||||
aguIZooTAOQdvBLMI5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxc
|
||||
Ns4BeVKhy2PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd6429
|
||||
82daopE7zQ/NPAnJfag=
|
||||
---- END SSH2 PUBLIC KEY ----
|
||||
`},
|
||||
{`---- BEGIN SSH2 PUBLIC KEY ----
|
||||
Comment: "1024-bit RSA, converted by andrew@phaedra from OpenSSH"
|
||||
AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxB
|
||||
cQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIV
|
||||
j0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ==
|
||||
---- END SSH2 PUBLIC KEY ----
|
||||
`},
|
||||
{`-----BEGIN RSA PUBLIC KEY-----
|
||||
MIGJAoGBAMC7u28i9fpketFe5k1+RHdcsdKy4Ir1mfdfnyXEFxDO6jnFmAHq9HDC
|
||||
b9C0m4X7Nk+1jmGxAgsEuYX4FnlakpmnWMF5KMfYbuXF632Rtwf6QhWPS08USjIo
|
||||
j3C9aojALimvH9ZWTbAtMmPMECHI3F8SrsL0J6Jf2lARsSol+QoJAgMBAAE=
|
||||
-----END RSA PUBLIC KEY-----
|
||||
`},
|
||||
{`-----BEGIN PUBLIC KEY-----
|
||||
MIIBtzCCASsGByqGSM44BAEwggEeAoGBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn5
|
||||
9NriyboW2x/DZuYAz3ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczW
|
||||
OVsaszBZSl90q8UnWlSH6P+/YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQse
|
||||
cdKktISwTakzAhUAsyrDtiYTSpS/sMMCxjnC336AJpMCgYBpK7/3xvduajLBD/9v
|
||||
ASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g
|
||||
+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb0N6s08NZysLzvj0N+ZC/FnhKTL
|
||||
zIyMtkHf/IrPCwlM+pV/M/96YgOBhQACgYEAqQcGn9CKgzgPaguIZooTAOQdvBLM
|
||||
I5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxcNs4BeVKhy2
|
||||
PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd642982da
|
||||
opE7zQ/NPAnJfag=
|
||||
-----END PUBLIC KEY-----
|
||||
`},
|
||||
{`-----BEGIN PUBLIC KEY-----
|
||||
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDAu7tvIvX6ZHrRXuZNfkR3XLHS
|
||||
suCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jB
|
||||
eSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C
|
||||
9CeiX9pQEbEqJfkKCQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
`},
|
||||
{`-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzGV4ftTgVMEh/Q+OcE2s
|
||||
RK0CDfSKAvcZezCiZKr077+juUUfWFvyCvRW3414F7KaWBobAmaNYRTjrFxzJ3zj
|
||||
karv8TA8eMj7sryqcOC3jxHIOEw4qWgxbsW1jqnPwVGUWXF7uNUAFnwy6yJ8LJbV
|
||||
mR0nhu4Y4aWnJeBa1b/VdaUujnOUNTccRM087jS0v/HYma05v2AEEP/gfps1iN8x
|
||||
LReJomY4wJY1ndS0wT71Nt3dvQ3AZphWoXGeONV2bE3gMBsRv0Oo/DYDV4/VsTHl
|
||||
sMV1do3gF/xAUqWawlZQkNcibME+sQqfE7gZ04hlmDATU2zmbzwuHtFiNv8mVv7O
|
||||
RQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
`},
|
||||
{`---- BEGIN SSH2 PUBLIC KEY ----
|
||||
Comment: "256-bit ED25519, converted by andrew@phaedra from OpenSSH"
|
||||
AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf
|
||||
---- END SSH2 PUBLIC KEY ----
|
||||
`},
|
||||
} {
|
||||
_, err := CheckPublicKeyString(test.content)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
setting.SSH.MinimumKeySizeCheck = oldValue
|
||||
for _, invalidKeys := range []struct {
|
||||
content string
|
||||
}{
|
||||
{"test"},
|
||||
{"---- NOT A REAL KEY ----"},
|
||||
{"bad\nkey"},
|
||||
{"\t\t:)\t\r\n"},
|
||||
{"\r\ntest \r\ngitea\r\n\r\n"},
|
||||
} {
|
||||
_, err := CheckPublicKeyString(invalidKeys.content)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_calcFingerprint(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skipSSHKeygen bool
|
||||
fp string
|
||||
content string
|
||||
}{
|
||||
{"rsa-1024", false, "SHA256:vSnDkvRh/xM6kMxPidLgrUhq3mCN7CDaronCEm2joyQ", "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
|
||||
{"rsa-2048", false, "SHA256:ZHD//a1b9VuTq9XSunAeYjKeU1xDa2tBFZYrFr2Okkg", "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDMZXh+1OBUwSH9D45wTaxErQIN9IoC9xl7MKJkqvTvv6O5RR9YW/IK9FbfjXgXsppYGhsCZo1hFOOsXHMnfOORqu/xMDx4yPuyvKpw4LePEcg4TDipaDFuxbWOqc/BUZRZcXu41QAWfDLrInwsltWZHSeG7hjhpacl4FrVv9V1pS6Oc5Q1NxxEzTzuNLS/8diZrTm/YAQQ/+B+mzWI3zEtF4miZjjAljWd1LTBPvU23d29DcBmmFahcZ441XZsTeAwGxG/Q6j8NgNXj9WxMeWwxXV2jeAX/EBSpZrCVlCQ1yJswT6xCp8TuBnTiGWYMBNTbOZvPC4e0WI2/yZW/s5F nocomment"},
|
||||
{"ecdsa-256", false, "SHA256:Bqx/xgWqRKLtkZ0Lr4iZpgb+5lYsFpSwXwVZbPwuTRw", "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFQacN3PrOll7PXmN5B/ZNVahiUIqI05nbBlZk1KXsO3d06ktAWqbNflv2vEmA38bTFTfJ2sbn2B5ksT52cDDbA= nocomment"},
|
||||
{"ecdsa-384", false, "SHA256:4qfJOgJDtUd8BrEjyVNdI8IgjiZKouztVde43aDhe1E", "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBINmioV+XRX1Fm9Qk2ehHXJ2tfVxW30ypUWZw670Zyq5GQfBAH6xjygRsJ5wWsHXBsGYgFUXIHvMKVAG1tpw7s6ax9oA+dJOJ7tj+vhn8joFqT+sg3LYHgZkHrfqryRasQ== nocomment"},
|
||||
{"ecdsa-sk", true, "SHA256:4wcIu4z+53gHc+db85OPfy8IydyNzPLCr6kHIs625LQ", "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
|
||||
{"ed25519-sk", true, "SHA256:RB4ku1OeWKN7fLMrjxz38DK0mp1BnOPBx4BItjTvJ0g", "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("Native", func(t *testing.T) {
|
||||
fpN, err := calcFingerprintNative(tc.content)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.fp, fpN)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// Generated with "ssh-keygen -C test@rekor.dev -f id_rsa"
|
||||
sshPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
|
||||
NhAAAAAwEAAQAAAYEA16H5ImoRO7mr41r8Z8JFBdu6jIM+6XU8M0r9F81RuhLYqzr9zw1n
|
||||
LeGCqFxPXNBKm8ZyH2BCsBHsbXbwe85IMHM3SUh8X/9fI0Lpi5/xbqAproFUpNR+UJYv6s
|
||||
8AaWk5zpN1rmpBrqGFJfGQKJCioDiiwNGmSdVkUNmQmYIANxJMDWYmNe8vUOh6nYEHB+lz
|
||||
fGgDAAzVSXTACW994UkSY47AD05swU4rIT/JWA6BkUrEhO//F0QQhFeROCPJiPRhJXGcFf
|
||||
9SicffJqR/ELzM1zNYnRXMD0bbdTUwDrIcIFFNBbtcfJVOUUCGumSlt+qjUC7y8cvwbHAu
|
||||
wf5nS6baA7P6LfTYplF2XIAkdWtkN6O1ouoyIHICXMlddDW2vNaJeEXTeKjx51WSM7qPnQ
|
||||
ZKsBtwjLQeEY/OPkIvu88lNNYSD63qMUA12msohjwVFCIgJVvYLIrkViczZ7t3L7lgy1X0
|
||||
CJI4e1roOfM/r9jTieyDHchEYpZYcw3L1R2qtePlAAAFiHdJQKl3SUCpAAAAB3NzaC1yc2
|
||||
EAAAGBANeh+SJqETu5q+Na/GfCRQXbuoyDPul1PDNK/RfNUboS2Ks6/c8NZy3hgqhcT1zQ
|
||||
SpvGch9gQrAR7G128HvOSDBzN0lIfF//XyNC6Yuf8W6gKa6BVKTUflCWL+rPAGlpOc6Tda
|
||||
5qQa6hhSXxkCiQoqA4osDRpknVZFDZkJmCADcSTA1mJjXvL1Doep2BBwfpc3xoAwAM1Ul0
|
||||
wAlvfeFJEmOOwA9ObMFOKyE/yVgOgZFKxITv/xdEEIRXkTgjyYj0YSVxnBX/UonH3yakfx
|
||||
C8zNczWJ0VzA9G23U1MA6yHCBRTQW7XHyVTlFAhrpkpbfqo1Au8vHL8GxwLsH+Z0um2gOz
|
||||
+i302KZRdlyAJHVrZDejtaLqMiByAlzJXXQ1trzWiXhF03io8edVkjO6j50GSrAbcIy0Hh
|
||||
GPzj5CL7vPJTTWEg+t6jFANdprKIY8FRQiICVb2CyK5FYnM2e7dy+5YMtV9AiSOHta6Dnz
|
||||
P6/Y04nsgx3IRGKWWHMNy9UdqrXj5QAAAAMBAAEAAAGAJyaOcFQnuttUPRxY9ZHNLGofrc
|
||||
Fqm8KgYoO7/iVWMF2Zn0U/rec2E5t9OIpCEozy7uOR9uZoVUV70sgkk6X5b2qL4C9b/aYF
|
||||
JQbSFnq8wCQuTTPIJYE7SfBq1Mwuu/TR/RLC7B74u/cxkJkSXnscO9Dso+ussH0hEJjf6y
|
||||
8yUM1up4Qjbel2gs8i7BPwLdySDkVoPgsWcpbTAyOODGhTAWZ6soy/rD1AEXJeYTGJDtMv
|
||||
aR+WBihig1TO1g2RWt9bqqiG7PIlljd3ZsjSSU5y3t6ZN/8j5keKD032EtxbZB0WFD3Ar4
|
||||
FbFwlW+urb2MQ0JyNKOio3nhdjolXYkJa+C6LXdaaml/8BhMR1eLoMe8nS45w76o8mdJWX
|
||||
wsirB8tvjCLY0QBXgGv/1DTsKu/wEFCW2/Y0e50gF7pHAlYFNmKDcgI9OyORRYhFbV4D82
|
||||
fI8JLQ42ZJkS/0t6xQma8WC88pbHGEuVSB6CE/p25fyYRX+UPTQ79tWFvLV4kNQAaBAAAA
|
||||
wEvyd6H8ePyBXImg8JzGxthufB0eXSfZBrabjf6e6bR2ivpJsHmB64gbMkV6MFV7EWYX1B
|
||||
wYPQxf4gA2Ez7aJvDtfE7uV6pa0WJS3hW1+be8DHEftmLSbTy/TEvDujNb2gqoi7uWQXWJ
|
||||
yYWZlYO65r1a6HucryQ8+78fTuTRbZALO43vNGz0oXH1hPSddkcbNAhZTsD0rQKNwqVTe5
|
||||
wl+6Cduy/CQwjHLYrY73MyWy1Vh1LXhAdGMPnWZwGIu/dnkgAAAMEA9KuaoGnfnLQkrjeR
|
||||
tO4RCRS2quNRvm4L6i4vHgTDsYtoSlR1ujge7SGOOmIPS4XVjZN5zzCOA7+EDVnuz3WWmx
|
||||
hmkjpG1YxzmJGaWoYdeo3a6UgJtisfMp8eUKqjJT1mhsCliCWtaOQNRoQieDQmgwZzSX/v
|
||||
ZiGsOIKa6cR37eKvOJSjVrHsAUzdtYrmi8P2gvAUFWyzXobAtpzHcWrwWkOEIm04G0OGXb
|
||||
J46hfIX3f45E5EKXvFzexGgVOD2I7hAAAAwQDhniYAizfW9YfG7UJWekkl42xMP7Cb8b0W
|
||||
SindSIuE8bFTukV1yxbmNZp/f0pKvn/DWc2n0I0bwSGZpy8BCY46RKKB2DYQavY/tGcC1N
|
||||
AynKuvbtWs11A0mTXmq3WwHVXQDozMwJ2nnHpm0UHspPuHqkYpurlP+xoFsocaQ9QwITyp
|
||||
lL4qHtXBEzaT8okkcGZBHdSx3gk4TzCsEDOP7ZZPLq42lpKMK10zFPTMd0maXtJDYKU/b4
|
||||
gAATvvPoylyYUAAAAOdGVzdEByZWtvci5kZXYBAgMEBQ==
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
`
|
||||
sshPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDXofkiahE7uavjWvxnwkUF27qMgz7pdTwzSv0XzVG6EtirOv3PDWct4YKoXE9c0EqbxnIfYEKwEextdvB7zkgwczdJSHxf/18jQumLn/FuoCmugVSk1H5Qli/qzwBpaTnOk3WuakGuoYUl8ZAokKKgOKLA0aZJ1WRQ2ZCZggA3EkwNZiY17y9Q6HqdgQcH6XN8aAMADNVJdMAJb33hSRJjjsAPTmzBTishP8lYDoGRSsSE7/8XRBCEV5E4I8mI9GElcZwV/1KJx98mpH8QvMzXM1idFcwPRtt1NTAOshwgUU0Fu1x8lU5RQIa6ZKW36qNQLvLxy/BscC7B/mdLptoDs/ot9NimUXZcgCR1a2Q3o7Wi6jIgcgJcyV10Nba81ol4RdN4qPHnVZIzuo+dBkqwG3CMtB4Rj84+Qi+7zyU01hIPreoxQDXaayiGPBUUIiAlW9gsiuRWJzNnu3cvuWDLVfQIkjh7Wug58z+v2NOJ7IMdyERillhzDcvVHaq14+U= test@rekor.dev
|
||||
`
|
||||
// Generated with "ssh-keygen -C other-test@rekor.dev -f id_rsa"
|
||||
otherSSHPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
|
||||
NhAAAAAwEAAQAAAYEAw/WCSWC9TEvCQOwO+T68EvNa3OSIv1Y0+sT8uSvyjPyEO0+p0t8C
|
||||
g/zy67vOxiQpU5jN6MItjXAjMmeCm8GKMt6gk+cDoaAev/ZfjuzSL7RayExpmhBleh2X3G
|
||||
KLkkXF9ABFNchlTqSLOZiEjDoNpbFv16KT1sE6CqW8DjxXQkQk9JK65hLH+BxeWMNCEJVa
|
||||
Cma4X04aJmC7zJAi5yGeeT0SKVqMohavF90O6XiYFCQHuwXPPyHfocqgudmXnozz+6D6ax
|
||||
JKZMwQsNp3WKumOjlzWnxBCCB1l2jN6Rag8aJ2277iMFXRwjTL/8jaEsW4KkysDf0GjV2/
|
||||
iqbr0q5b0arDYbv7CrGBR+uH0wGz/Zog1x5iZANObhZULpDrLVJidEMc27HXBb7PMsNDy7
|
||||
BGYRB1yc0d0y83p8mUqvOlWSArxn1WnAZO04pAgTrclrhEh4ZXOkn2Sn82eu3DpQ8inkol
|
||||
Y4IfnhIfbOIeemoUNq1tOUquhow9GLRM6INieHLBAAAFkPPnA1jz5wNYAAAAB3NzaC1yc2
|
||||
EAAAGBAMP1gklgvUxLwkDsDvk+vBLzWtzkiL9WNPrE/Lkr8oz8hDtPqdLfAoP88uu7zsYk
|
||||
KVOYzejCLY1wIzJngpvBijLeoJPnA6GgHr/2X47s0i+0WshMaZoQZXodl9xii5JFxfQART
|
||||
XIZU6kizmYhIw6DaWxb9eik9bBOgqlvA48V0JEJPSSuuYSx/gcXljDQhCVWgpmuF9OGiZg
|
||||
u8yQIuchnnk9EilajKIWrxfdDul4mBQkB7sFzz8h36HKoLnZl56M8/ug+msSSmTMELDad1
|
||||
irpjo5c1p8QQggdZdozekWoPGidtu+4jBV0cI0y//I2hLFuCpMrA39Bo1dv4qm69KuW9Gq
|
||||
w2G7+wqxgUfrh9MBs/2aINceYmQDTm4WVC6Q6y1SYnRDHNux1wW+zzLDQ8uwRmEQdcnNHd
|
||||
MvN6fJlKrzpVkgK8Z9VpwGTtOKQIE63Ja4RIeGVzpJ9kp/Nnrtw6UPIp5KJWOCH54SH2zi
|
||||
HnpqFDatbTlKroaMPRi0TOiDYnhywQAAAAMBAAEAAAGAYycx4oEhp55Zz1HijblxnsEmQ8
|
||||
kbbH1pV04fdm7HTxFis0Qu8PVIp5JxNFiWWunnQ1Z5MgI23G9WT+XST4+RpwXBCLWGv9xu
|
||||
UsGOPpqUC/FdUiZf9MXBIxYgRjJS3xORA1KzsnAQ2sclb2I+B1pEl4d9yQWJesvQ25xa2H
|
||||
Utzej/LgWkrk/ogSGRl6ZNImj/421wc0DouGyP+gUgtATt0/jT3LrlmAqUVCXVqssLYH2O
|
||||
r9JTuGUibBJEW2W/c0lsM0jaHa5bGAdL3nhDuF1Q6KFB87mZoNw8c2znYoTzQ3FyWtIEZI
|
||||
V/9oWrkS7V6242SKSR9tJoEzK0jtrKC/FZwBiI4hPcwoqY6fZbT1701i/n50xWEfEUOLVm
|
||||
d6VqNKyAbIaZIPN0qfZuD+xdrHuM3V6k/rgFxGl4XTrp/N4AsruiQs0nRQKNTw3fHE0zPq
|
||||
UTxSeMvjywRCepxhBFCNh8NHydapclHtEPEGdTVHohL3krJehstPO/IuRyKLfSVtL1AAAA
|
||||
wQCmGA8k+uW6mway9J3jp8mlMhhp3DCX6DAcvalbA/S5OcqMyiTM3c/HD5OJ6OYFDldcqu
|
||||
MPEgLRL2HfxL29LsbQSzjyOIrfp5PLJlo70P5lXS8u2QPbo4/KQJmQmsIX18LDyU2zRtNA
|
||||
C2WfBiHSZV+guLhmHms9S5gQYKt2T5OnY/W0tmnInx9lmFCMC+XKS1iSQ2o433IrtCPQJp
|
||||
IXZd59OQpO9QjJABgJIDtXxFIXt45qpXduDPJuggrhg81stOwAAADBAPX73u/CY+QUPts+
|
||||
LV185Z4mZ2y+qu2ZMCAU3BnpHktGZZ1vFN1Xq9o8KdnuPZ+QJRdO8eKMWpySqrIdIbTYLm
|
||||
9nXmVH0uNECIEAvdU+wgKeR+BSHxCRVuTF4YSygmNadgH/z+oRWLgOblGo2ywFBoXsIAKQ
|
||||
paNu1MFGRUmhz67+dcpkkBUDRU9loAgBKexMo8D9vkR0YiHLOUjCrtmEZRNm0YRZt0gQhD
|
||||
ZSD1fOH0fZDcCVNpGP2zqAKos4EGLnkwAAAMEAy/AuLtPKA2u9oCA8e18ZnuQRAi27FBVU
|
||||
rU2D7bMg1eS0IakG8v0gE9K6WdYzyArY1RoKB3ZklK5VmJ1cOcWc2x3Ejc5jcJgc8cC6lZ
|
||||
wwjpE8HfWL1kIIYgPdcexqFc+l6MdgH6QMKU3nLg1LsM4v5FEldtk/2dmnw620xnFfstpF
|
||||
VxSZNdKrYfM/v9o6sRaDRqSfH1dG8BvkUxPznTAF+JDxBENcKXYECcq9f6dcl1w5IEnNTD
|
||||
Wry/EKQvgvOUjbAAAAFG90aGVyLXRlc3RAcmVrb3IuZGV2AQIDBAUG
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
`
|
||||
otherSSHPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDD9YJJYL1MS8JA7A75PrwS81rc5Ii/VjT6xPy5K/KM/IQ7T6nS3wKD/PLru87GJClTmM3owi2NcCMyZ4KbwYoy3qCT5wOhoB6/9l+O7NIvtFrITGmaEGV6HZfcYouSRcX0AEU1yGVOpIs5mISMOg2lsW/XopPWwToKpbwOPFdCRCT0krrmEsf4HF5Yw0IQlVoKZrhfThomYLvMkCLnIZ55PRIpWoyiFq8X3Q7peJgUJAe7Bc8/Id+hyqC52ZeejPP7oPprEkpkzBCw2ndYq6Y6OXNafEEIIHWXaM3pFqDxonbbvuIwVdHCNMv/yNoSxbgqTKwN/QaNXb+KpuvSrlvRqsNhu/sKsYFH64fTAbP9miDXHmJkA05uFlQukOstUmJ0QxzbsdcFvs8yw0PLsEZhEHXJzR3TLzenyZSq86VZICvGfVacBk7TikCBOtyWuESHhlc6SfZKfzZ67cOlDyKeSiVjgh+eEh9s4h56ahQ2rW05Sq6GjD0YtEzog2J4csE= other-test@rekor.dev
|
||||
`
|
||||
|
||||
// Generated with ssh-keygen -C test@rekor.dev -t ed25519 -f id_ed25519
|
||||
ed25519PrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
|
||||
QyNTUxOQAAACBB45zRHxPPFtabwS3Vd6Lb9vMe+tIHZj2qN5VQ+bgLfQAAAJgyRa3cMkWt
|
||||
3AAAAAtzc2gtZWQyNTUxOQAAACBB45zRHxPPFtabwS3Vd6Lb9vMe+tIHZj2qN5VQ+bgLfQ
|
||||
AAAED7y4N/DsVnRQiBZNxEWdsJ9RmbranvtQ3X9jnb6gFed0HjnNEfE88W1pvBLdV3otv2
|
||||
8x760gdmPao3lVD5uAt9AAAADnRlc3RAcmVrb3IuZGV2AQIDBAUGBw==
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
`
|
||||
ed25519PublicKey = `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEHjnNEfE88W1pvBLdV3otv28x760gdmPao3lVD5uAt9 test@rekor.dev
|
||||
`
|
||||
)
|
||||
|
||||
func TestFromOpenSSH(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
pub string
|
||||
priv string
|
||||
}{
|
||||
{
|
||||
name: "rsa",
|
||||
pub: sshPublicKey,
|
||||
priv: sshPrivateKey,
|
||||
},
|
||||
{
|
||||
name: "ed25519",
|
||||
pub: ed25519PublicKey,
|
||||
priv: ed25519PrivateKey,
|
||||
},
|
||||
} {
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
t.Skip("skip TestFromOpenSSH: missing ssh-keygen in PATH")
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt := tt
|
||||
|
||||
// Test that a signature from the cli can validate here.
|
||||
td := t.TempDir()
|
||||
|
||||
data := []byte("hello, ssh world")
|
||||
dataPath := write(t, data, td, "data")
|
||||
|
||||
privPath := write(t, []byte(tt.priv), td, "id")
|
||||
write(t, []byte(tt.pub), td, "id.pub")
|
||||
|
||||
sigPath := dataPath + ".sig"
|
||||
run(t, nil, "ssh-keygen", "-Y", "sign", "-n", "file", "-f", privPath, dataPath)
|
||||
|
||||
sigBytes, err := os.ReadFile(sigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := sshsig.Verify(bytes.NewReader(data), sigBytes, []byte(tt.pub), "file"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// It should not verify if we check against another public key
|
||||
if err := sshsig.Verify(bytes.NewReader(data), sigBytes, []byte(otherSSHPublicKey), "file"); err == nil {
|
||||
t.Error("expected error with incorrect key")
|
||||
}
|
||||
|
||||
// It should not verify if the data is tampered
|
||||
if err := sshsig.Verify(strings.NewReader("bad data"), sigBytes, []byte(sshPublicKey), "file"); err == nil {
|
||||
t.Error("expected error with incorrect data")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenSSH(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
pub string
|
||||
priv string
|
||||
}{
|
||||
{
|
||||
name: "rsa",
|
||||
pub: sshPublicKey,
|
||||
priv: sshPrivateKey,
|
||||
},
|
||||
{
|
||||
name: "ed25519",
|
||||
pub: ed25519PublicKey,
|
||||
priv: ed25519PrivateKey,
|
||||
},
|
||||
} {
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
t.Skip("skip TestToOpenSSH: missing ssh-keygen in PATH")
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt := tt
|
||||
// Test that a signature from here can validate in the CLI.
|
||||
td := t.TempDir()
|
||||
|
||||
data := []byte("hello, ssh world")
|
||||
write(t, data, td, "data")
|
||||
|
||||
armored, err := sshsig.Sign([]byte(tt.priv), bytes.NewReader(data), "file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sigPath := write(t, armored, td, "oursig")
|
||||
|
||||
// Create an allowed_signers file with two keys to check against.
|
||||
allowedSigner := "test@rekor.dev " + tt.pub + "\n"
|
||||
allowedSigner += "othertest@rekor.dev " + otherSSHPublicKey + "\n"
|
||||
allowedSigners := write(t, []byte(allowedSigner), td, "allowed_signer")
|
||||
|
||||
// We use the correct principal here so it should work.
|
||||
run(t, data, "ssh-keygen", "-Y", "verify", "-f", allowedSigners,
|
||||
"-I", "test@rekor.dev", "-n", "file", "-s", sigPath)
|
||||
|
||||
// Just to be sure, check against the other public key as well.
|
||||
runErr(t, data, "ssh-keygen", "-Y", "verify", "-f", allowedSigners,
|
||||
"-I", "othertest@rekor.dev", "-n", "file", "-s", sigPath)
|
||||
|
||||
// It should error if we run it against other data
|
||||
data = []byte("other data!")
|
||||
runErr(t, data, "ssh-keygen", "-Y", "check-novalidate", "-n", "file", "-s", sigPath)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
data := []byte("my good data to be signed!")
|
||||
|
||||
// Create one extra signature for all the tests.
|
||||
otherSig, err := sshsig.Sign([]byte(otherSSHPrivateKey), bytes.NewReader(data), "file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
pub string
|
||||
priv string
|
||||
}{
|
||||
{
|
||||
name: "rsa",
|
||||
pub: sshPublicKey,
|
||||
priv: sshPrivateKey,
|
||||
},
|
||||
{
|
||||
name: "ed25519",
|
||||
pub: ed25519PublicKey,
|
||||
priv: ed25519PrivateKey,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt := tt
|
||||
sig, err := sshsig.Sign([]byte(tt.priv), bytes.NewReader(data), "file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check the signature against that data and public key
|
||||
if err := sshsig.Verify(bytes.NewReader(data), sig, []byte(tt.pub), "file"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Now check it against invalid data.
|
||||
if err := sshsig.Verify(strings.NewReader("invalid data!"), sig, []byte(tt.pub), "file"); err == nil {
|
||||
t.Error("expected error!")
|
||||
}
|
||||
|
||||
// Now check it against the wrong key.
|
||||
if err := sshsig.Verify(bytes.NewReader(data), sig, []byte(otherSSHPublicKey), "file"); err == nil {
|
||||
t.Error("expected error!")
|
||||
}
|
||||
|
||||
// Now check it against an invalid signature data.
|
||||
if err := sshsig.Verify(bytes.NewReader(data), []byte("invalid signature!"), []byte(tt.pub), "file"); err == nil {
|
||||
t.Error("expected error!")
|
||||
}
|
||||
|
||||
// Once more, use the wrong signature and check it against the original (wrong public key)
|
||||
if err := sshsig.Verify(bytes.NewReader(data), otherSig, []byte(tt.pub), "file"); err == nil {
|
||||
t.Error("expected error!")
|
||||
}
|
||||
// It should work against the correct public key.
|
||||
if err := sshsig.Verify(bytes.NewReader(data), otherSig, []byte(otherSSHPublicKey), "file"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func write(t *testing.T, d []byte, fp ...string) string {
|
||||
p := filepath.Join(fp...)
|
||||
if err := os.WriteFile(p, d, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func run(t *testing.T, stdin []byte, args ...string) {
|
||||
t.Helper()
|
||||
/* #nosec */
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = bytes.NewReader(stdin)
|
||||
out, err := cmd.CombinedOutput()
|
||||
t.Logf("cmd %v: %s", cmd, string(out))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func runErr(t *testing.T, stdin []byte, args ...string) {
|
||||
t.Helper()
|
||||
/* #nosec */
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = bytes.NewReader(stdin)
|
||||
out, err := cmd.CombinedOutput()
|
||||
t.Logf("cmd %v: %s", cmd, string(out))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PublicKeysAreExternallyManaged(t *testing.T) {
|
||||
key1 := unittest.AssertExistsAndLoadBean(t, &PublicKey{ID: 1})
|
||||
externals, err := PublicKeysAreExternallyManaged(t.Context(), []*PublicKey{key1})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, externals, 1)
|
||||
assert.False(t, externals[0])
|
||||
}
|
||||
47
models/asymkey/ssh_key_verify.go
Normal file
47
models/asymkey/ssh_key_verify.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package asymkey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
|
||||
"github.com/42wim/sshsig"
|
||||
)
|
||||
|
||||
// VerifySSHKey marks a SSH key as verified
|
||||
func VerifySSHKey(ctx context.Context, ownerID int64, fingerprint, token, signature string) (string, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (string, error) {
|
||||
key := new(PublicKey)
|
||||
|
||||
has, err := db.GetEngine(ctx).Where("owner_id = ? AND fingerprint = ?", ownerID, fingerprint).Get(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if !has {
|
||||
return "", ErrKeyNotExist{}
|
||||
}
|
||||
|
||||
err = sshsig.Verify(strings.NewReader(token), []byte(signature), []byte(key.Content), "gitea")
|
||||
if err != nil {
|
||||
// edge case for Windows based shells that will add CR LF if piped to ssh-keygen command
|
||||
// see https://github.com/PowerShell/PowerShell/issues/5974
|
||||
if sshsig.Verify(strings.NewReader(token+"\r\n"), []byte(signature), []byte(key.Content), "gitea") != nil {
|
||||
log.Debug("VerifySSHKey sshsig.Verify failed: %v", err)
|
||||
return "", ErrSSHInvalidTokenSignature{
|
||||
Fingerprint: key.Fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
key.Verified = true
|
||||
if _, err := db.GetEngine(ctx).ID(key.ID).Cols("verified").Update(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return key.Fingerprint, nil
|
||||
})
|
||||
}
|
||||
236
models/auth/access_token.go
Normal file
236
models/auth/access_token.go
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ErrAccessTokenNotExist represents a "AccessTokenNotExist" kind of error.
|
||||
type ErrAccessTokenNotExist struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// IsErrAccessTokenNotExist checks if an error is a ErrAccessTokenNotExist.
|
||||
func IsErrAccessTokenNotExist(err error) bool {
|
||||
_, ok := err.(ErrAccessTokenNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenNotExist) Error() string {
|
||||
return fmt.Sprintf("access token does not exist [sha: %s]", err.Token)
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrAccessTokenEmpty represents a "AccessTokenEmpty" kind of error.
|
||||
type ErrAccessTokenEmpty struct{}
|
||||
|
||||
// IsErrAccessTokenEmpty checks if an error is a ErrAccessTokenEmpty.
|
||||
func IsErrAccessTokenEmpty(err error) bool {
|
||||
_, ok := err.(ErrAccessTokenEmpty)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenEmpty) Error() string {
|
||||
return "access token is empty"
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenEmpty) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
var successfulAccessTokenCache *lru.Cache[string, any]
|
||||
|
||||
// AccessToken represents a personal access token.
|
||||
type AccessToken struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"INDEX"`
|
||||
Name string
|
||||
Token string `xorm:"-"`
|
||||
TokenHash string `xorm:"UNIQUE"` // sha256 of token
|
||||
TokenSalt string
|
||||
TokenLastEight string `xorm:"INDEX token_last_eight"`
|
||||
Scope AccessTokenScope
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
HasRecentActivity bool `xorm:"-"`
|
||||
HasUsed bool `xorm:"-"`
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (t *AccessToken) AfterLoad() {
|
||||
t.HasUsed = t.UpdatedUnix > t.CreatedUnix
|
||||
t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(AccessToken), func() error {
|
||||
if setting.SuccessfulTokensCacheSize > 0 {
|
||||
var err error
|
||||
successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to allocate AccessToken cache: %w", err)
|
||||
}
|
||||
} else {
|
||||
successfulAccessTokenCache = nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// NewAccessToken creates new access token.
|
||||
func NewAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
salt, err := util.CryptoRandomString(10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token, err := util.CryptoRandomBytes(20)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.TokenSalt = salt
|
||||
t.Token = hex.EncodeToString(token)
|
||||
t.TokenHash = HashToken(t.Token, t.TokenSalt)
|
||||
t.TokenLastEight = t.Token[len(t.Token)-8:]
|
||||
_, err = db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// DisplayPublicOnly whether to display this as a public-only token.
|
||||
func (t *AccessToken) DisplayPublicOnly() bool {
|
||||
publicOnly, err := t.Scope.PublicOnly()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return publicOnly
|
||||
}
|
||||
|
||||
func getAccessTokenIDFromCache(token string) int64 {
|
||||
if successfulAccessTokenCache == nil {
|
||||
return 0
|
||||
}
|
||||
tInterface, ok := successfulAccessTokenCache.Get(token)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
t, ok := tInterface.(int64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// GetAccessTokenBySHA returns access token by given token value
|
||||
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
|
||||
if token == "" {
|
||||
return nil, ErrAccessTokenEmpty{}
|
||||
}
|
||||
// A token is defined as being SHA1 sum these are 40 hexadecimal bytes long
|
||||
if len(token) != 40 {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
for _, x := range []byte(token) {
|
||||
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
}
|
||||
|
||||
lastEight := token[len(token)-8:]
|
||||
|
||||
if id := getAccessTokenIDFromCache(token); id > 0 {
|
||||
accessToken := &AccessToken{
|
||||
TokenLastEight: lastEight,
|
||||
}
|
||||
// Re-get the token from the db in case it has been deleted in the intervening period
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if has {
|
||||
return accessToken, nil
|
||||
}
|
||||
successfulAccessTokenCache.Remove(token)
|
||||
}
|
||||
|
||||
var tokens []AccessToken
|
||||
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(tokens) == 0 {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
tempHash := HashToken(token, t.TokenSalt)
|
||||
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
|
||||
if successfulAccessTokenCache != nil {
|
||||
successfulAccessTokenCache.Add(token, t.ID)
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
|
||||
// AccessTokenByNameExists checks if a token name has been used already by a user.
|
||||
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
|
||||
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
|
||||
}
|
||||
|
||||
// ListAccessTokensOptions contain filter options
|
||||
type ListAccessTokensOptions struct {
|
||||
db.ListOptions
|
||||
Name string
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (opts ListAccessTokensOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
// user id is required, otherwise it will return all result which maybe a possible bug
|
||||
cond = cond.And(builder.Eq{"uid": opts.UserID})
|
||||
if len(opts.Name) > 0 {
|
||||
cond = cond.And(builder.Eq{"name": opts.Name})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts ListAccessTokensOptions) ToOrders() string {
|
||||
return "created_unix DESC"
|
||||
}
|
||||
|
||||
// UpdateAccessToken updates information of access token.
|
||||
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteAccessTokenByID deletes access token by given ID.
|
||||
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrAccessTokenNotExist{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
377
models/auth/access_token_scope.go
Normal file
377
models/auth/access_token_scope.go
Normal file
@@ -0,0 +1,377 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/perm"
|
||||
)
|
||||
|
||||
// AccessTokenScopeCategory represents the scope category for an access token
|
||||
type AccessTokenScopeCategory int
|
||||
|
||||
const (
|
||||
AccessTokenScopeCategoryActivityPub AccessTokenScopeCategory = iota
|
||||
AccessTokenScopeCategoryAdmin
|
||||
AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder, don't remove it which will change the following values
|
||||
AccessTokenScopeCategoryNotification
|
||||
AccessTokenScopeCategoryOrganization
|
||||
AccessTokenScopeCategoryPackage
|
||||
AccessTokenScopeCategoryIssue
|
||||
AccessTokenScopeCategoryRepository
|
||||
AccessTokenScopeCategoryUser
|
||||
)
|
||||
|
||||
// AllAccessTokenScopeCategories contains all access token scope categories
|
||||
var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{
|
||||
AccessTokenScopeCategoryActivityPub,
|
||||
AccessTokenScopeCategoryAdmin,
|
||||
AccessTokenScopeCategoryMisc,
|
||||
AccessTokenScopeCategoryNotification,
|
||||
AccessTokenScopeCategoryOrganization,
|
||||
AccessTokenScopeCategoryPackage,
|
||||
AccessTokenScopeCategoryIssue,
|
||||
AccessTokenScopeCategoryRepository,
|
||||
AccessTokenScopeCategoryUser,
|
||||
}
|
||||
|
||||
// AccessTokenScopeLevel represents the access levels without a given scope category
|
||||
type AccessTokenScopeLevel int
|
||||
|
||||
const (
|
||||
NoAccess AccessTokenScopeLevel = iota
|
||||
Read
|
||||
Write
|
||||
)
|
||||
|
||||
// AccessTokenScope represents the scope for an access token.
|
||||
type AccessTokenScope string
|
||||
|
||||
// for all categories, write implies read
|
||||
const (
|
||||
AccessTokenScopeAll AccessTokenScope = "all"
|
||||
AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos
|
||||
|
||||
AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub"
|
||||
AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub"
|
||||
|
||||
AccessTokenScopeReadAdmin AccessTokenScope = "read:admin"
|
||||
AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin"
|
||||
|
||||
AccessTokenScopeReadMisc AccessTokenScope = "read:misc"
|
||||
AccessTokenScopeWriteMisc AccessTokenScope = "write:misc"
|
||||
|
||||
AccessTokenScopeReadNotification AccessTokenScope = "read:notification"
|
||||
AccessTokenScopeWriteNotification AccessTokenScope = "write:notification"
|
||||
|
||||
AccessTokenScopeReadOrganization AccessTokenScope = "read:organization"
|
||||
AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization"
|
||||
|
||||
AccessTokenScopeReadPackage AccessTokenScope = "read:package"
|
||||
AccessTokenScopeWritePackage AccessTokenScope = "write:package"
|
||||
|
||||
AccessTokenScopeReadIssue AccessTokenScope = "read:issue"
|
||||
AccessTokenScopeWriteIssue AccessTokenScope = "write:issue"
|
||||
|
||||
AccessTokenScopeReadRepository AccessTokenScope = "read:repository"
|
||||
AccessTokenScopeWriteRepository AccessTokenScope = "write:repository"
|
||||
|
||||
AccessTokenScopeReadUser AccessTokenScope = "read:user"
|
||||
AccessTokenScopeWriteUser AccessTokenScope = "write:user"
|
||||
)
|
||||
|
||||
// accessTokenScopeBitmap represents a bitmap of access token scopes.
|
||||
type accessTokenScopeBitmap uint64
|
||||
|
||||
// Bitmap of each scope, including the child scopes.
|
||||
const (
|
||||
// AccessTokenScopeAllBits is the bitmap of all access token scopes
|
||||
accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits |
|
||||
accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits |
|
||||
accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits |
|
||||
accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits
|
||||
|
||||
accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota
|
||||
|
||||
accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadActivityPubBits
|
||||
|
||||
accessTokenScopeReadAdminBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteAdminBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadAdminBits
|
||||
|
||||
accessTokenScopeReadMiscBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteMiscBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadMiscBits
|
||||
|
||||
accessTokenScopeReadNotificationBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteNotificationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadNotificationBits
|
||||
|
||||
accessTokenScopeReadOrganizationBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteOrganizationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadOrganizationBits
|
||||
|
||||
accessTokenScopeReadPackageBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWritePackageBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadPackageBits
|
||||
|
||||
accessTokenScopeReadIssueBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteIssueBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadIssueBits
|
||||
|
||||
accessTokenScopeReadRepositoryBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteRepositoryBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadRepositoryBits
|
||||
|
||||
accessTokenScopeReadUserBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteUserBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadUserBits
|
||||
|
||||
// The current implementation only supports up to 64 token scopes.
|
||||
// If we need to support > 64 scopes,
|
||||
// refactoring the whole implementation in this file (and only this file) is needed.
|
||||
)
|
||||
|
||||
// allAccessTokenScopes contains all access token scopes.
|
||||
// The order is important: parent scope must precede child scopes.
|
||||
var allAccessTokenScopes = []AccessTokenScope{
|
||||
AccessTokenScopePublicOnly,
|
||||
AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub,
|
||||
AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin,
|
||||
AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc,
|
||||
AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification,
|
||||
AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization,
|
||||
AccessTokenScopeWritePackage, AccessTokenScopeReadPackage,
|
||||
AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue,
|
||||
AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository,
|
||||
AccessTokenScopeWriteUser, AccessTokenScopeReadUser,
|
||||
}
|
||||
|
||||
// allAccessTokenScopeBits contains all access token scopes.
|
||||
var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{
|
||||
AccessTokenScopeAll: accessTokenScopeAllBits,
|
||||
AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits,
|
||||
AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits,
|
||||
AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits,
|
||||
AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits,
|
||||
AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits,
|
||||
AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits,
|
||||
AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits,
|
||||
AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits,
|
||||
AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits,
|
||||
AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits,
|
||||
AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits,
|
||||
AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits,
|
||||
AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits,
|
||||
AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits,
|
||||
AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits,
|
||||
AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits,
|
||||
AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits,
|
||||
AccessTokenScopeReadUser: accessTokenScopeReadUserBits,
|
||||
AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits,
|
||||
}
|
||||
|
||||
// readAccessTokenScopes maps a scope category to the read permission scope
|
||||
var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{
|
||||
Read: {
|
||||
AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub,
|
||||
AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin,
|
||||
AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc,
|
||||
AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification,
|
||||
AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization,
|
||||
AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage,
|
||||
AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue,
|
||||
AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository,
|
||||
AccessTokenScopeCategoryUser: AccessTokenScopeReadUser,
|
||||
},
|
||||
Write: {
|
||||
AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub,
|
||||
AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin,
|
||||
AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc,
|
||||
AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification,
|
||||
AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization,
|
||||
AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage,
|
||||
AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue,
|
||||
AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository,
|
||||
AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser,
|
||||
},
|
||||
}
|
||||
|
||||
func GetAccessTokenCategories() (res []string) {
|
||||
for _, cat := range accessTokenScopes[Read] {
|
||||
res = append(res, strings.TrimPrefix(string(cat), "read:"))
|
||||
}
|
||||
slices.Sort(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// GetRequiredScopes gets the specific scopes for a given level and categories
|
||||
func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope {
|
||||
scopes := make([]AccessTokenScope, 0, len(scopeCategories))
|
||||
for _, cat := range scopeCategories {
|
||||
scopes = append(scopes, accessTokenScopes[level][cat])
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
// ContainsCategory checks if a list of categories contains a specific category
|
||||
func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool {
|
||||
return slices.Contains(categories, category)
|
||||
}
|
||||
|
||||
// GetScopeLevelFromAccessMode converts permission access mode to scope level
|
||||
func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel {
|
||||
switch mode {
|
||||
case perm.AccessModeNone:
|
||||
return NoAccess
|
||||
case perm.AccessModeRead:
|
||||
return Read
|
||||
case perm.AccessModeWrite:
|
||||
return Write
|
||||
case perm.AccessModeAdmin:
|
||||
return Write
|
||||
case perm.AccessModeOwner:
|
||||
return Write
|
||||
default:
|
||||
return NoAccess
|
||||
}
|
||||
}
|
||||
|
||||
// parse the scope string into a bitmap, thus removing possible duplicates.
|
||||
func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) {
|
||||
var bitmap accessTokenScopeBitmap
|
||||
|
||||
// The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScope, ",")' as this is hot code
|
||||
remainingScopes := string(s)
|
||||
for len(remainingScopes) > 0 {
|
||||
i := strings.IndexByte(remainingScopes, ',')
|
||||
var v string
|
||||
if i < 0 {
|
||||
v = remainingScopes
|
||||
remainingScopes = ""
|
||||
} else if i+1 >= len(remainingScopes) {
|
||||
v = remainingScopes[:i]
|
||||
remainingScopes = ""
|
||||
} else {
|
||||
v = remainingScopes[:i]
|
||||
remainingScopes = remainingScopes[i+1:]
|
||||
}
|
||||
singleScope := AccessTokenScope(v)
|
||||
if singleScope == "" {
|
||||
continue
|
||||
}
|
||||
if singleScope == AccessTokenScopeAll {
|
||||
bitmap |= accessTokenScopeAllBits
|
||||
continue
|
||||
}
|
||||
|
||||
bits, ok := allAccessTokenScopeBits[singleScope]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid access token scope: %s", singleScope)
|
||||
}
|
||||
bitmap |= bits
|
||||
}
|
||||
|
||||
return bitmap, nil
|
||||
}
|
||||
|
||||
// StringSlice returns the AccessTokenScope as a []string
|
||||
func (s AccessTokenScope) StringSlice() []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(string(s), ",")
|
||||
}
|
||||
|
||||
// Normalize returns a normalized scope string without any duplicates.
|
||||
func (s AccessTokenScope) Normalize() (AccessTokenScope, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return bitmap.toScope(), nil
|
||||
}
|
||||
|
||||
func (s AccessTokenScope) HasPermissionScope() bool {
|
||||
return s != "" && s != AccessTokenScopePublicOnly
|
||||
}
|
||||
|
||||
// PublicOnly checks if this token scope is limited to public resources
|
||||
func (s AccessTokenScope) PublicOnly() (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return bitmap.hasScope(AccessTokenScopePublicOnly)
|
||||
}
|
||||
|
||||
// HasScope returns true if the string has the given scope
|
||||
func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, s := range scopes {
|
||||
if has, err := bitmap.hasScope(s); !has || err != nil {
|
||||
return has, err
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// HasAnyScope returns true if any of the scopes is contained in the string
|
||||
func (s AccessTokenScope) HasAnyScope(scopes ...AccessTokenScope) (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, s := range scopes {
|
||||
if has, err := bitmap.hasScope(s); has || err != nil {
|
||||
return has, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// hasScope returns true if the string has the given scope
|
||||
func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) {
|
||||
expectedBits, ok := allAccessTokenScopeBits[scope]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid access token scope: %s", scope)
|
||||
}
|
||||
|
||||
return bitmap&expectedBits == expectedBits, nil
|
||||
}
|
||||
|
||||
// toScope returns a normalized scope string without any duplicates.
|
||||
func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope {
|
||||
var scopes []string
|
||||
|
||||
// iterate over all scopes, and reconstruct the bitmap
|
||||
// if the reconstructed bitmap doesn't change, then the scope is already included
|
||||
var reconstruct accessTokenScopeBitmap
|
||||
|
||||
for _, singleScope := range allAccessTokenScopes {
|
||||
// no need for error checking here, since we know the scope is valid
|
||||
if ok, _ := bitmap.hasScope(singleScope); ok {
|
||||
current := reconstruct | allAccessTokenScopeBits[singleScope]
|
||||
if current == reconstruct {
|
||||
continue
|
||||
}
|
||||
|
||||
reconstruct = current
|
||||
scopes = append(scopes, string(singleScope))
|
||||
}
|
||||
}
|
||||
|
||||
scope := AccessTokenScope(strings.Join(scopes, ","))
|
||||
scope = AccessTokenScope(strings.ReplaceAll(
|
||||
string(scope),
|
||||
"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user",
|
||||
"all",
|
||||
))
|
||||
return scope
|
||||
}
|
||||
91
models/auth/access_token_scope_test.go
Normal file
91
models/auth/access_token_scope_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type scopeTestNormalize struct {
|
||||
in AccessTokenScope
|
||||
out AccessTokenScope
|
||||
err error
|
||||
}
|
||||
|
||||
func TestAccessTokenScope_Normalize(t *testing.T) {
|
||||
assert.Equal(t, []string{"activitypub", "admin", "issue", "misc", "notification", "organization", "package", "repository", "user"}, GetAccessTokenCategories())
|
||||
tests := []scopeTestNormalize{
|
||||
{"", "", nil},
|
||||
{"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil},
|
||||
{"all", "all", nil},
|
||||
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil},
|
||||
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil},
|
||||
}
|
||||
|
||||
for _, scope := range GetAccessTokenCategories() {
|
||||
tests = append(tests,
|
||||
scopeTestNormalize{AccessTokenScope("read:" + scope), AccessTokenScope("read:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope("write:" + scope), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(string(test.in), func(t *testing.T) {
|
||||
scope, err := test.in.Normalize()
|
||||
assert.Equal(t, test.out, scope)
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type scopeTestHasScope struct {
|
||||
in AccessTokenScope
|
||||
scope AccessTokenScope
|
||||
out bool
|
||||
err error
|
||||
}
|
||||
|
||||
func TestAccessTokenScope_HasScope(t *testing.T) {
|
||||
tests := []scopeTestHasScope{
|
||||
{"read:admin", "write:package", false, nil},
|
||||
{"all", "write:package", true, nil},
|
||||
{"write:package", "all", false, nil},
|
||||
{"public-only", "read:issue", false, nil},
|
||||
}
|
||||
|
||||
for _, scope := range GetAccessTokenCategories() {
|
||||
tests = append(tests,
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("read:" + scope),
|
||||
AccessTokenScope("read:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("write:" + scope),
|
||||
AccessTokenScope("write:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("write:" + scope),
|
||||
AccessTokenScope("read:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("read:" + scope),
|
||||
AccessTokenScope("write:" + scope), false, nil,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(string(test.in), func(t *testing.T) {
|
||||
hasScope, err := test.in.HasScope(test.scope)
|
||||
assert.Equal(t, test.out, hasScope)
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
132
models/auth/access_token_test.go
Normal file
132
models/auth/access_token_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2016 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAccessToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := &auth_model.AccessToken{
|
||||
UID: 3,
|
||||
Name: "Token C",
|
||||
}
|
||||
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
invalidToken := &auth_model.AccessToken{
|
||||
ID: token.ID, // duplicate
|
||||
UID: 2,
|
||||
Name: "Token F",
|
||||
}
|
||||
assert.Error(t, auth_model.NewAccessToken(t.Context(), invalidToken))
|
||||
}
|
||||
|
||||
func TestAccessTokenByNameExists(t *testing.T) {
|
||||
name := "Token Gitea"
|
||||
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := &auth_model.AccessToken{
|
||||
UID: 3,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
// Check to make sure it doesn't exists already
|
||||
exist, err := auth_model.AccessTokenByNameExists(t.Context(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
|
||||
// Save it to the database
|
||||
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
// This token must be found by name in the DB now
|
||||
exist, err = auth_model.AccessTokenByNameExists(t.Context(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exist)
|
||||
|
||||
user4Token := &auth_model.AccessToken{
|
||||
UID: 4,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
// Name matches but different user ID, this shouldn't exists in the
|
||||
// database
|
||||
exist, err = auth_model.AccessTokenByNameExists(t.Context(), user4Token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
}
|
||||
|
||||
func TestGetAccessTokenBySHA(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
assert.Equal(t, "Token A", token.Name)
|
||||
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
|
||||
assert.Equal(t, "e4efbf36", token.TokenLastEight)
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "notahash")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
|
||||
}
|
||||
|
||||
func TestListAccessTokens(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
tokens, err := db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 2) {
|
||||
assert.Equal(t, int64(1), tokens[0].UID)
|
||||
assert.Equal(t, int64(1), tokens[1].UID)
|
||||
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A")
|
||||
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
|
||||
}
|
||||
|
||||
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 2})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 1) {
|
||||
assert.Equal(t, int64(2), tokens[0].UID)
|
||||
assert.Equal(t, "Token A", tokens[0].Name)
|
||||
}
|
||||
|
||||
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 100})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, tokens)
|
||||
}
|
||||
|
||||
func TestUpdateAccessToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
token.Name = "Token Z"
|
||||
|
||||
assert.NoError(t, auth_model.UpdateAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
}
|
||||
|
||||
func TestDeleteAccessTokenByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
|
||||
assert.NoError(t, auth_model.DeleteAccessTokenByID(t.Context(), token.ID, 1))
|
||||
unittest.AssertNotExistsBean(t, token)
|
||||
|
||||
err = auth_model.DeleteAccessTokenByID(t.Context(), 100, 100)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
}
|
||||
65
models/auth/auth_token.go
Normal file
65
models/auth/auth_token.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
var ErrAuthTokenNotExist = util.NewNotExistErrorf("auth token does not exist")
|
||||
|
||||
type AuthToken struct { //nolint:revive // export stutter
|
||||
ID string `xorm:"pk"`
|
||||
TokenHash string
|
||||
UserID int64 `xorm:"INDEX"`
|
||||
ExpiresUnix timeutil.TimeStamp `xorm:"INDEX"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(AuthToken))
|
||||
}
|
||||
|
||||
func InsertAuthToken(ctx context.Context, t *AuthToken) error {
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetAuthTokenByID(ctx context.Context, id string) (*AuthToken, error) {
|
||||
at := &AuthToken{}
|
||||
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(at)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrAuthTokenNotExist
|
||||
}
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func UpdateAuthTokenByID(ctx context.Context, t *AuthToken) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).Cols("token_hash", "expires_unix").Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAuthTokenByID(ctx context.Context, id string) error {
|
||||
_, err := db.GetEngine(ctx).ID(id).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAuthTokensByUserID(ctx context.Context, uid int64) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Eq{"user_id": uid}).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteExpiredAuthTokens(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Lt{"expires_unix": timeutil.TimeStampNow()}).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
20
models/auth/main_test.go
Normal file
20
models/auth/main_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
_ "code.gitea.io/gitea/models"
|
||||
_ "code.gitea.io/gitea/models/actions"
|
||||
_ "code.gitea.io/gitea/models/activities"
|
||||
_ "code.gitea.io/gitea/models/auth"
|
||||
_ "code.gitea.io/gitea/models/perm/access"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
636
models/auth/oauth2.go
Normal file
636
models/auth/oauth2.go
Normal file
@@ -0,0 +1,636 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
uuid "github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// OAuth2Application represents an OAuth2 client (RFC 6749)
|
||||
type OAuth2Application struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"INDEX"`
|
||||
Name string
|
||||
ClientID string `xorm:"unique"`
|
||||
ClientSecret string
|
||||
// OAuth defines both Confidential and Public client types
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-2.1
|
||||
// "Authorization servers MUST record the client type in the client registration details"
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-8.4
|
||||
ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"`
|
||||
SkipSecondaryAuthorization bool `xorm:"NOT NULL DEFAULT FALSE"`
|
||||
RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(OAuth2Application))
|
||||
db.RegisterModel(new(OAuth2AuthorizationCode))
|
||||
db.RegisterModel(new(OAuth2Grant))
|
||||
}
|
||||
|
||||
type BuiltinOAuth2Application struct {
|
||||
ConfigName string
|
||||
DisplayName string
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
func BuiltinApplications() map[string]*BuiltinOAuth2Application {
|
||||
m := make(map[string]*BuiltinOAuth2Application)
|
||||
m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "git-credential-oauth",
|
||||
DisplayName: "git-credential-oauth",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "git-credential-manager",
|
||||
DisplayName: "Git Credential Manager",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "tea",
|
||||
DisplayName: "tea",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
builtinApps := BuiltinApplications()
|
||||
var builtinAllClientIDs []string
|
||||
for clientID := range builtinApps {
|
||||
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
|
||||
}
|
||||
|
||||
var registeredApps []*OAuth2Application
|
||||
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(®isteredApps); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientIDsToAdd := container.Set[string]{}
|
||||
for _, configName := range setting.OAuth2.DefaultApplications {
|
||||
found := false
|
||||
for clientID, builtinApp := range builtinApps {
|
||||
if builtinApp.ConfigName == configName {
|
||||
clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("unknown oauth2 application: %q", configName)
|
||||
}
|
||||
}
|
||||
clientIDsToDelete := container.Set[string]{}
|
||||
for _, app := range registeredApps {
|
||||
if !clientIDsToAdd.Contains(app.ClientID) {
|
||||
clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted
|
||||
}
|
||||
}
|
||||
for _, app := range registeredApps {
|
||||
clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set
|
||||
}
|
||||
|
||||
for _, app := range registeredApps {
|
||||
if clientIDsToDelete.Contains(app.ClientID) {
|
||||
if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for clientID := range clientIDsToAdd {
|
||||
builtinApp := builtinApps[clientID]
|
||||
if err := db.Insert(ctx, &OAuth2Application{
|
||||
Name: builtinApp.DisplayName,
|
||||
ClientID: clientID,
|
||||
RedirectURIs: builtinApp.RedirectURIs,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_application`
|
||||
func (app *OAuth2Application) TableName() string {
|
||||
return "oauth2_application"
|
||||
}
|
||||
|
||||
// ContainsRedirectURI checks if redirectURI is allowed for app
|
||||
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
|
||||
// OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed.
|
||||
// https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2
|
||||
// https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1
|
||||
contains := func(s string) bool {
|
||||
s = strings.TrimSuffix(strings.ToLower(s), "/")
|
||||
for _, u := range app.RedirectURIs {
|
||||
if strings.TrimSuffix(strings.ToLower(u), "/") == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !app.ConfidentialClient {
|
||||
uri, err := url.Parse(redirectURI)
|
||||
// ignore port for http loopback uris following https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
if err == nil && uri.Scheme == "http" && uri.Port() != "" {
|
||||
ip := net.ParseIP(uri.Hostname())
|
||||
if ip != nil && ip.IsLoopback() {
|
||||
// strip port
|
||||
uri.Host = uri.Hostname()
|
||||
if contains(uri.String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return contains(redirectURI)
|
||||
}
|
||||
|
||||
// Base32 characters, but lowercased.
|
||||
const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567"
|
||||
|
||||
// base32 encoder that uses lowered characters without padding.
|
||||
var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding)
|
||||
|
||||
// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
|
||||
func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) {
|
||||
rBytes, err := util.CryptoRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Add a prefix to the base32, this is in order to make it easier
|
||||
// for code scanners to grab sensitive tokens.
|
||||
clientSecret := "gto_" + base32Lower.EncodeToString(rBytes)
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
app.ClientSecret = string(hashedSecret)
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates the given secret by the hash saved in database
|
||||
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
|
||||
}
|
||||
|
||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
|
||||
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// CreateGrant generates a grant for an user
|
||||
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
grant := &OAuth2Grant{
|
||||
ApplicationID: app.ID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
err := db.Insert(ctx, grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
|
||||
if !has {
|
||||
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
|
||||
}
|
||||
return app, err
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(app)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
|
||||
type CreateOAuth2ApplicationOptions struct {
|
||||
Name string
|
||||
UserID int64
|
||||
ConfidentialClient bool
|
||||
SkipSecondaryAuthorization bool
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// CreateOAuth2Application inserts a new oauth2 application
|
||||
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
clientID := uuid.New().String()
|
||||
app := &OAuth2Application{
|
||||
UID: opts.UserID,
|
||||
Name: opts.Name,
|
||||
ClientID: clientID,
|
||||
RedirectURIs: opts.RedirectURIs,
|
||||
ConfidentialClient: opts.ConfidentialClient,
|
||||
SkipSecondaryAuthorization: opts.SkipSecondaryAuthorization,
|
||||
}
|
||||
if err := db.Insert(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
|
||||
type UpdateOAuth2ApplicationOptions struct {
|
||||
ID int64
|
||||
Name string
|
||||
UserID int64
|
||||
ConfidentialClient bool
|
||||
SkipSecondaryAuthorization bool
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// UpdateOAuth2Application updates an oauth2 application
|
||||
func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*OAuth2Application, error) {
|
||||
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app.UID != opts.UserID {
|
||||
return nil, errors.New("UID mismatch")
|
||||
}
|
||||
builtinApps := BuiltinApplications()
|
||||
if _, builtin := builtinApps[app.ClientID]; builtin {
|
||||
return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
|
||||
}
|
||||
|
||||
app.Name = opts.Name
|
||||
app.RedirectURIs = opts.RedirectURIs
|
||||
app.ConfidentialClient = opts.ConfidentialClient
|
||||
app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
|
||||
|
||||
if err = updateOAuth2Application(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app.ClientSecret = ""
|
||||
|
||||
return app, nil
|
||||
})
|
||||
}
|
||||
|
||||
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client", "skip_secondary_authorization").Update(app); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
sess := db.GetEngine(ctx)
|
||||
// the userid could be 0 if the app is instance-wide
|
||||
if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil {
|
||||
return err
|
||||
} else if deleted == 0 {
|
||||
return ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
codes := make([]*OAuth2AuthorizationCode, 0)
|
||||
// delete correlating auth codes
|
||||
if err := sess.Join("INNER", "oauth2_grant",
|
||||
"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
|
||||
return err
|
||||
}
|
||||
codeIDs := make([]int64, 0, len(codes))
|
||||
for _, grant := range codes {
|
||||
codeIDs = append(codeIDs, grant.ID)
|
||||
}
|
||||
|
||||
if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
|
||||
func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
app, err := GetOAuth2ApplicationByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
builtinApps := BuiltinApplications()
|
||||
if _, builtin := builtinApps[app.ClientID]; builtin {
|
||||
return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
|
||||
}
|
||||
return deleteOAuth2Application(ctx, id, userid)
|
||||
})
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
|
||||
type OAuth2AuthorizationCode struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Grant *OAuth2Grant `xorm:"-"`
|
||||
GrantID int64
|
||||
Code string `xorm:"INDEX unique"`
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
RedirectURI string
|
||||
ValidUntil timeutil.TimeStamp `xorm:"index"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_authorization_code`
|
||||
func (code *OAuth2AuthorizationCode) TableName() string {
|
||||
return "oauth2_authorization_code"
|
||||
}
|
||||
|
||||
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
|
||||
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) {
|
||||
redirect, err := url.Parse(code.RedirectURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := redirect.Query()
|
||||
if state != "" {
|
||||
q.Set("state", state)
|
||||
}
|
||||
q.Set("code", code.Code)
|
||||
redirect.RawQuery = q.Encode()
|
||||
return redirect, err
|
||||
}
|
||||
|
||||
// Invalidate deletes the auth code from the database to invalidate this code
|
||||
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
|
||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
|
||||
switch code.CodeChallengeMethod {
|
||||
case "S256":
|
||||
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return hashedVerifier == code.CodeChallenge
|
||||
case "plain":
|
||||
return verifier == code.CodeChallenge
|
||||
case "":
|
||||
return true
|
||||
default:
|
||||
// unsupported method -> return false
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
|
||||
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
auth = new(OAuth2AuthorizationCode)
|
||||
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
auth.Grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2Grant represents the permission of an user for a specific application to access resources
|
||||
type OAuth2Grant struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UserID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Application *OAuth2Application `xorm:"-"`
|
||||
ApplicationID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Counter int64 `xorm:"NOT NULL DEFAULT 1"`
|
||||
Scope string `xorm:"TEXT"`
|
||||
Nonce string `xorm:"TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_grant`
|
||||
func (grant *OAuth2Grant) TableName() string {
|
||||
return "oauth2_grant"
|
||||
}
|
||||
|
||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
rBytes, err := util.CryptoRandomBytes(32)
|
||||
if err != nil {
|
||||
return &OAuth2AuthorizationCode{}, err
|
||||
}
|
||||
// Add a prefix to the base32, this is in order to make it easier
|
||||
// for code scanners to grab sensitive tokens.
|
||||
codeSecret := "gta_" + base32Lower.EncodeToString(rBytes)
|
||||
|
||||
code = &OAuth2AuthorizationCode{
|
||||
Grant: grant,
|
||||
GrantID: grant.ID,
|
||||
RedirectURI: redirectURI,
|
||||
Code: codeSecret,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
}
|
||||
if err := db.Insert(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// IncreaseCounter increases the counter and updates the grant
|
||||
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
grant.Counter = updatedGrant.Counter
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScopeContains returns true if the grant scope contains the specified scope
|
||||
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
|
||||
return slices.Contains(strings.Split(grant.Scope, " "), scope)
|
||||
}
|
||||
|
||||
// SetNonce updates the current nonce value of a grant
|
||||
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
|
||||
grant.Nonce = nonce
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOAuth2GrantByID returns the grant with the given ID
|
||||
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil
|
||||
}
|
||||
return grant, err
|
||||
}
|
||||
|
||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
|
||||
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
|
||||
type joinedOAuth2Grant struct {
|
||||
Grant *OAuth2Grant `xorm:"extends"`
|
||||
Application *OAuth2Application `xorm:"extends"`
|
||||
}
|
||||
var results *xorm.Rows
|
||||
var err error
|
||||
if results, err = db.GetEngine(ctx).
|
||||
Table("oauth2_grant").
|
||||
Where("user_id = ?", uid).
|
||||
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
|
||||
Rows(new(joinedOAuth2Grant)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer results.Close()
|
||||
grants := make([]*OAuth2Grant, 0)
|
||||
for results.Next() {
|
||||
joinedGrant := new(joinedOAuth2Grant)
|
||||
if err := results.Scan(joinedGrant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedGrant.Grant.Application = joinedGrant.Application
|
||||
grants = append(grants, joinedGrant.Grant)
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
|
||||
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{})
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
|
||||
type ErrOAuthClientIDInvalid struct {
|
||||
ClientID string
|
||||
}
|
||||
|
||||
// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid.
|
||||
func IsErrOauthClientIDInvalid(err error) bool {
|
||||
_, ok := err.(ErrOAuthClientIDInvalid)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthClientIDInvalid) Error() string {
|
||||
return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrOAuthClientIDInvalid) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrOAuthApplicationNotFound will be thrown if id cannot be found
|
||||
type ErrOAuthApplicationNotFound struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
|
||||
func IsErrOAuthApplicationNotFound(err error) bool {
|
||||
_, ok := err.(ErrOAuthApplicationNotFound)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthApplicationNotFound) Error() string {
|
||||
return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrOAuthApplicationNotFound) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// GetActiveOAuth2SourceByAuthName returns a OAuth2 AuthSource based on the given name
|
||||
func GetActiveOAuth2SourceByAuthName(ctx context.Context, name string) (*Source, error) {
|
||||
authSource := new(Source)
|
||||
has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !has {
|
||||
return nil, fmt.Errorf("oauth2 source not found, name: %q", name)
|
||||
}
|
||||
|
||||
return authSource, nil
|
||||
}
|
||||
|
||||
func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error {
|
||||
deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID})
|
||||
|
||||
if _, err := db.GetEngine(ctx).In("grant_id", deleteCond).
|
||||
Delete(&OAuth2AuthorizationCode{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DeleteBeans(ctx,
|
||||
&OAuth2Application{UID: userID},
|
||||
&OAuth2Grant{UserID: userID},
|
||||
); err != nil {
|
||||
return fmt.Errorf("DeleteBeans: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
32
models/auth/oauth2_list.go
Normal file
32
models/auth/oauth2_list.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"code.gitea.io/gitea/models/db"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type FindOAuth2ApplicationsOptions struct {
|
||||
db.ListOptions
|
||||
// OwnerID is the user id or org id of the owner of the application
|
||||
OwnerID int64
|
||||
// find global applications, if true, then OwnerID will be igonred
|
||||
IsGlobal bool
|
||||
}
|
||||
|
||||
func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond {
|
||||
conds := builder.NewCond()
|
||||
if opts.IsGlobal {
|
||||
conds = conds.And(builder.Eq{"uid": 0})
|
||||
} else if opts.OwnerID != 0 {
|
||||
conds = conds.And(builder.Eq{"uid": opts.OwnerID})
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
func (opts FindOAuth2ApplicationsOptions) ToOrders() string {
|
||||
return "id DESC"
|
||||
}
|
||||
264
models/auth/oauth2_test.go
Normal file
264
models/auth/oauth2_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
secret, err := app.GenerateClientSecret(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, secret)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
|
||||
}
|
||||
|
||||
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
|
||||
assert.NoError(b, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1})
|
||||
for b.Loop() {
|
||||
_, _ = app.GenerateClientSecret(b.Context())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{
|
||||
RedirectURIs: []string{"a", "b", "c"},
|
||||
}
|
||||
assert.True(t, app.ContainsRedirectURI("a"))
|
||||
assert.True(t, app.ContainsRedirectURI("b"))
|
||||
assert.True(t, app.ContainsRedirectURI("c"))
|
||||
assert.False(t, app.ContainsRedirectURI("d"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{
|
||||
RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"},
|
||||
ConfidentialClient: false,
|
||||
}
|
||||
|
||||
// http loopback uris should ignore port
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/"))
|
||||
|
||||
// not http
|
||||
assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/"))
|
||||
// not loopback
|
||||
assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://intranet:3456/"))
|
||||
// unparseable
|
||||
assert.False(t, app.ContainsRedirectURI(":"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}}
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
|
||||
|
||||
app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}}
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
secret, err := app.GenerateClientSecret(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, app.ValidateClientSecret([]byte(secret)))
|
||||
assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
|
||||
}
|
||||
|
||||
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := auth_model.GetOAuth2ApplicationByClientID(t.Context(), "da7da3ba-9a13-4167-856f-3899de0b0138")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
|
||||
|
||||
app, err = auth_model.GetOAuth2ApplicationByClientID(t.Context(), "invalid client id")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, app)
|
||||
}
|
||||
|
||||
func TestCreateOAuth2Application(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := auth_model.CreateOAuth2Application(t.Context(), auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newapp", app.Name)
|
||||
assert.Len(t, app.ClientID, 36)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"})
|
||||
}
|
||||
|
||||
func TestOAuth2Application_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName())
|
||||
}
|
||||
|
||||
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
grant, err := app.GetGrantByUserID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.UserID)
|
||||
|
||||
grant, err = app.GetGrantByUserID(t.Context(), 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Application_CreateGrant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
grant, err := app.CreateGrant(t.Context(), 2, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, grant)
|
||||
assert.Equal(t, int64(2), grant.UserID)
|
||||
assert.Equal(t, int64(1), grant.ApplicationID)
|
||||
assert.Empty(t, grant.Scope)
|
||||
}
|
||||
|
||||
//////////////////// Grant
|
||||
|
||||
func TestGetOAuth2GrantByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant, err := auth_model.GetOAuth2GrantByID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.ID)
|
||||
|
||||
grant, err = auth_model.GetOAuth2GrantByID(t.Context(), 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
|
||||
assert.NoError(t, grant.IncreaseCounter(t.Context()))
|
||||
assert.Equal(t, int64(2), grant.Counter)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2})
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_ScopeContains(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"})
|
||||
assert.True(t, grant.ScopeContains("openid"))
|
||||
assert.True(t, grant.ScopeContains("profile"))
|
||||
assert.False(t, grant.ScopeContains("profil"))
|
||||
assert.False(t, grant.ScopeContains("profile2"))
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
|
||||
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Greater(t, len(code.Code), 32) // secret length > 32
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName())
|
||||
}
|
||||
|
||||
func TestGetOAuth2GrantsByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
result, err := auth_model.GetOAuth2GrantsByUserID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, int64(1), result[0].ID)
|
||||
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
|
||||
|
||||
result, err = auth_model.GetOAuth2GrantsByUserID(t.Context(), 34134)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestRevokeOAuth2Grant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.NoError(t, auth_model.RevokeOAuth2Grant(t.Context(), 1, 1))
|
||||
unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1})
|
||||
}
|
||||
|
||||
//////////////////// Authorization Code
|
||||
|
||||
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), "authcode")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Equal(t, "authcode", code.Code)
|
||||
assert.Equal(t, int64(1), code.ID)
|
||||
|
||||
code, err = auth_model.GetOAuth2AuthorizationByCode(t.Context(), "does not exist")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, code)
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
|
||||
// test plain
|
||||
code := &auth_model.OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "plain",
|
||||
CodeChallenge: "test123",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge("test123"))
|
||||
assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
|
||||
|
||||
// test S256
|
||||
code = &auth_model.OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "S256",
|
||||
CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
|
||||
assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
|
||||
|
||||
// test unknown
|
||||
code = &auth_model.OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "monkey",
|
||||
CodeChallenge: "foiwgjioriogeiogjerger",
|
||||
}
|
||||
assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
|
||||
|
||||
// test no code challenge
|
||||
code = &auth_model.OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: "",
|
||||
CodeChallenge: "foierjiogerogerg",
|
||||
}
|
||||
assert.True(t, code.ValidateCodeChallenge(""))
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
|
||||
code := &auth_model.OAuth2AuthorizationCode{
|
||||
RedirectURI: "https://example.com/callback",
|
||||
Code: "thecode",
|
||||
}
|
||||
|
||||
redirect, err := code.GenerateRedirectURI("thestate")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
|
||||
|
||||
redirect, err = code.GenerateRedirectURI("")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
code := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
|
||||
assert.NoError(t, code.Invalidate(t.Context()))
|
||||
unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName())
|
||||
}
|
||||
112
models/auth/session.go
Normal file
112
models/auth/session.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// Session represents a session compatible for go-chi session
|
||||
type Session struct {
|
||||
Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
|
||||
Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
|
||||
Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Session))
|
||||
}
|
||||
|
||||
// UpdateSession updates the session with provided id
|
||||
func UpdateSession(ctx context.Context, key string, data []byte) error {
|
||||
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
|
||||
Data: data,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadSession reads the data for the provided session
|
||||
func ReadSession(ctx context.Context, key string) (*Session, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
|
||||
session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !exist {
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
}
|
||||
if err := db.Insert(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return session, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ExistSession checks if a session exists
|
||||
func ExistSession(ctx context.Context, key string) (bool, error) {
|
||||
return db.Exist[Session](ctx, builder.Eq{"`key`": key})
|
||||
}
|
||||
|
||||
// DestroySession destroys a session
|
||||
func DestroySession(ctx context.Context, key string) error {
|
||||
_, err := db.GetEngine(ctx).Delete(&Session{
|
||||
Key: key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RegenerateSession regenerates a session from the old id
|
||||
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
|
||||
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
|
||||
return nil, err
|
||||
} else if has {
|
||||
return nil, fmt.Errorf("session Key: %s already exists", newKey)
|
||||
}
|
||||
|
||||
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
if err := db.Insert(ctx, &Session{
|
||||
Key: oldKey,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "UPDATE `session` SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
|
||||
if err != nil {
|
||||
// is not exist, it should be impossible
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// CountSessions returns the number of sessions
|
||||
func CountSessions(ctx context.Context) (int64, error) {
|
||||
return db.GetEngine(ctx).Count(&Session{})
|
||||
}
|
||||
|
||||
// CleanupSessions cleans up expired sessions
|
||||
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
|
||||
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
|
||||
return err
|
||||
}
|
||||
398
models/auth/source.go
Normal file
398
models/auth/source.go
Normal file
@@ -0,0 +1,398 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/optional"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/convert"
|
||||
)
|
||||
|
||||
// Type represents an login type.
|
||||
type Type int
|
||||
|
||||
// Note: new type must append to the end of list to maintain compatibility.
|
||||
const (
|
||||
NoType Type = iota
|
||||
Plain // 1
|
||||
LDAP // 2
|
||||
SMTP // 3
|
||||
PAM // 4
|
||||
DLDAP // 5
|
||||
OAuth2 // 6
|
||||
SSPI // 7
|
||||
)
|
||||
|
||||
// String returns the string name of the LoginType
|
||||
func (typ Type) String() string {
|
||||
return Names[typ]
|
||||
}
|
||||
|
||||
// Int returns the int value of the LoginType
|
||||
func (typ Type) Int() int {
|
||||
return int(typ)
|
||||
}
|
||||
|
||||
// Names contains the name of LoginType values.
|
||||
var Names = map[Type]string{
|
||||
LDAP: "LDAP (via BindDN)",
|
||||
DLDAP: "LDAP (simple auth)", // Via direct bind
|
||||
SMTP: "SMTP",
|
||||
PAM: "PAM",
|
||||
OAuth2: "OAuth2",
|
||||
SSPI: "SPNEGO with SSPI",
|
||||
}
|
||||
|
||||
// Config represents login config as far as the db is concerned
|
||||
type Config interface {
|
||||
convert.Conversion
|
||||
SetAuthSource(*Source)
|
||||
}
|
||||
|
||||
type ConfigBase struct {
|
||||
AuthSource *Source
|
||||
}
|
||||
|
||||
func (p *ConfigBase) SetAuthSource(s *Source) {
|
||||
p.AuthSource = s
|
||||
}
|
||||
|
||||
// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
|
||||
type SkipVerifiable interface {
|
||||
IsSkipVerify() bool
|
||||
}
|
||||
|
||||
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
|
||||
type HasTLSer interface {
|
||||
HasTLS() bool
|
||||
}
|
||||
|
||||
// UseTLSer configurations provide a HasTLS to check if TLS is enabled
|
||||
type UseTLSer interface {
|
||||
UseTLS() bool
|
||||
}
|
||||
|
||||
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
|
||||
type SSHKeyProvider interface {
|
||||
ProvidesSSHKeys() bool
|
||||
}
|
||||
|
||||
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
|
||||
type RegisterableSource interface {
|
||||
RegisterSource() error
|
||||
UnregisterSource() error
|
||||
}
|
||||
|
||||
var registeredConfigs = map[Type]func() Config{}
|
||||
|
||||
// RegisterTypeConfig register a config for a provided type
|
||||
func RegisterTypeConfig(typ Type, exemplar Config) {
|
||||
if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
|
||||
// Pointer:
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not a Pointer
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
|
||||
}
|
||||
}
|
||||
|
||||
// Source represents an external way for authorizing users.
|
||||
type Source struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Type Type
|
||||
Name string `xorm:"UNIQUE"`
|
||||
IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
TwoFactorPolicy string `xorm:"two_factor_policy NOT NULL DEFAULT ''"`
|
||||
Cfg Config `xorm:"TEXT"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
// TableName xorm will read the table name from this method
|
||||
func (Source) TableName() string {
|
||||
return "login_source"
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Source))
|
||||
}
|
||||
|
||||
// BeforeSet is invoked from XORM before setting the value of a field of this object.
|
||||
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
|
||||
if colName == "type" {
|
||||
typ := Type(db.Cell2Int64(val))
|
||||
constructor, ok := registeredConfigs[typ]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
source.Cfg = constructor()
|
||||
source.Cfg.SetAuthSource(source)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeName return name of this login source type.
|
||||
func (source *Source) TypeName() string {
|
||||
return Names[source.Type]
|
||||
}
|
||||
|
||||
// IsLDAP returns true of this source is of the LDAP type.
|
||||
func (source *Source) IsLDAP() bool {
|
||||
return source.Type == LDAP
|
||||
}
|
||||
|
||||
// IsDLDAP returns true of this source is of the DLDAP type.
|
||||
func (source *Source) IsDLDAP() bool {
|
||||
return source.Type == DLDAP
|
||||
}
|
||||
|
||||
// IsSMTP returns true of this source is of the SMTP type.
|
||||
func (source *Source) IsSMTP() bool {
|
||||
return source.Type == SMTP
|
||||
}
|
||||
|
||||
// IsPAM returns true of this source is of the PAM type.
|
||||
func (source *Source) IsPAM() bool {
|
||||
return source.Type == PAM
|
||||
}
|
||||
|
||||
// IsOAuth2 returns true of this source is of the OAuth2 type.
|
||||
func (source *Source) IsOAuth2() bool {
|
||||
return source.Type == OAuth2
|
||||
}
|
||||
|
||||
// IsSSPI returns true of this source is of the SSPI type.
|
||||
func (source *Source) IsSSPI() bool {
|
||||
return source.Type == SSPI
|
||||
}
|
||||
|
||||
// HasTLS returns true of this source supports TLS.
|
||||
func (source *Source) HasTLS() bool {
|
||||
hasTLSer, ok := source.Cfg.(HasTLSer)
|
||||
return ok && hasTLSer.HasTLS()
|
||||
}
|
||||
|
||||
// UseTLS returns true of this source is configured to use TLS.
|
||||
func (source *Source) UseTLS() bool {
|
||||
useTLSer, ok := source.Cfg.(UseTLSer)
|
||||
return ok && useTLSer.UseTLS()
|
||||
}
|
||||
|
||||
// SkipVerify returns true if this source is configured to skip SSL
|
||||
// verification.
|
||||
func (source *Source) SkipVerify() bool {
|
||||
skipVerifiable, ok := source.Cfg.(SkipVerifiable)
|
||||
return ok && skipVerifiable.IsSkipVerify()
|
||||
}
|
||||
|
||||
func (source *Source) TwoFactorShouldSkip() bool {
|
||||
return source.TwoFactorPolicy == "skip"
|
||||
}
|
||||
|
||||
// CreateSource inserts a AuthSource in the DB if not already
|
||||
// existing with the given name.
|
||||
func CreateSource(ctx context.Context, source *Source) error {
|
||||
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrSourceAlreadyExist{source.Name}
|
||||
}
|
||||
// Synchronization is only available with LDAP for now
|
||||
if !source.IsLDAP() && !source.IsOAuth2() {
|
||||
source.IsSyncEnabled = false
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).Insert(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
source.Cfg.SetAuthSource(source)
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// remove the AuthSource in case of errors while registering configuration
|
||||
if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil {
|
||||
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type FindSourcesOptions struct {
|
||||
db.ListOptions
|
||||
IsActive optional.Option[bool]
|
||||
LoginType Type
|
||||
}
|
||||
|
||||
func (opts FindSourcesOptions) ToConds() builder.Cond {
|
||||
conds := builder.NewCond()
|
||||
if opts.IsActive.Has() {
|
||||
conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()})
|
||||
}
|
||||
if opts.LoginType != NoType {
|
||||
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
// IsSSPIEnabled returns true if there is at least one activated login
|
||||
// source of type LoginSSPI
|
||||
func IsSSPIEnabled(ctx context.Context) bool {
|
||||
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
|
||||
IsActive: optional.Some(true),
|
||||
LoginType: SSPI,
|
||||
}.ToConds())
|
||||
if err != nil {
|
||||
log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err)
|
||||
return false
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// GetSourceByID returns login source by given ID.
|
||||
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
|
||||
source := new(Source)
|
||||
if id == 0 {
|
||||
source.Cfg = registeredConfigs[NoType]()
|
||||
// Set this source to active
|
||||
// FIXME: allow disabling of db based password authentication in future
|
||||
source.IsActive = true
|
||||
return source, nil
|
||||
}
|
||||
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrSourceNotExist{id}
|
||||
}
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// UpdateSource updates a Source record in DB.
|
||||
func UpdateSource(ctx context.Context, source *Source) error {
|
||||
var originalSource *Source
|
||||
if source.IsOAuth2() {
|
||||
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
|
||||
var err error
|
||||
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrSourceAlreadyExist{source.Name}
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
source.Cfg.SetAuthSource(source)
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// restore original values since we cannot update the provider itself
|
||||
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
|
||||
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
|
||||
type ErrSourceNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
|
||||
func IsErrSourceNotExist(err error) bool {
|
||||
_, ok := err.(ErrSourceNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceNotExist) Error() string {
|
||||
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrSourceNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
|
||||
type ErrSourceAlreadyExist struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
|
||||
func IsErrSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrSourceAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrExist err
|
||||
func (err ErrSourceAlreadyExist) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrSourceInUse represents a "SourceInUse" kind of error.
|
||||
type ErrSourceInUse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
|
||||
func IsErrSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrSourceInUse)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
|
||||
}
|
||||
63
models/auth/source_test.go
Normal file
63
models/auth/source_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
"code.gitea.io/gitea/modules/json"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type TestSource struct {
|
||||
auth_model.ConfigBase
|
||||
|
||||
Provider string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
OpenIDConnectAutoDiscoveryURL string
|
||||
IconURL string
|
||||
}
|
||||
|
||||
// FromDB fills up a LDAPConfig from serialized format.
|
||||
func (source *TestSource) FromDB(bs []byte) error {
|
||||
return json.Unmarshal(bs, &source)
|
||||
}
|
||||
|
||||
// ToDB exports a LDAPConfig to a serialized format.
|
||||
func (source *TestSource) ToDB() ([]byte, error) {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
|
||||
func TestDumpAuthSource(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
authSourceSchema, err := unittest.GetXORMEngine().TableInfo(new(auth_model.Source))
|
||||
assert.NoError(t, err)
|
||||
|
||||
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
|
||||
|
||||
auth_model.CreateSource(t.Context(), &auth_model.Source{
|
||||
Type: auth_model.OAuth2,
|
||||
Name: "TestSource",
|
||||
IsActive: false,
|
||||
Cfg: &TestSource{
|
||||
Provider: "ConvertibleSourceName",
|
||||
ClientID: "42",
|
||||
},
|
||||
})
|
||||
|
||||
sb := new(strings.Builder)
|
||||
|
||||
// TODO: this test is quite hacky, it should use a low-level "select" (without model processors) but not a database dump
|
||||
engine := unittest.GetXORMEngine()
|
||||
require.NoError(t, engine.DumpTables([]*schemas.Table{authSourceSchema}, sb))
|
||||
assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
|
||||
}
|
||||
176
models/auth/twofactor.go
Normal file
176
models/auth/twofactor.go
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/secret"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
//
|
||||
// Two-factor authentication
|
||||
//
|
||||
|
||||
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
|
||||
type ErrTwoFactorNotEnrolled struct {
|
||||
UID int64
|
||||
}
|
||||
|
||||
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
|
||||
func IsErrTwoFactorNotEnrolled(err error) bool {
|
||||
_, ok := err.(ErrTwoFactorNotEnrolled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrTwoFactorNotEnrolled) Error() string {
|
||||
return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrTwoFactorNotEnrolled) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// TwoFactor represents a two-factor authentication token.
|
||||
type TwoFactor struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"UNIQUE"`
|
||||
Secret string
|
||||
ScratchSalt string
|
||||
ScratchHash string
|
||||
LastUsedPasscode string `xorm:"VARCHAR(10)"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(TwoFactor))
|
||||
}
|
||||
|
||||
// GenerateScratchToken recreates the scratch token the user is using.
|
||||
func (t *TwoFactor) GenerateScratchToken() (string, error) {
|
||||
tokenBytes, err := util.CryptoRandomBytes(6)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// these chars are specially chosen, avoid ambiguous chars like `0`, `O`, `1`, `I`.
|
||||
const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes)
|
||||
t.ScratchSalt, _ = util.CryptoRandomString(10)
|
||||
t.ScratchHash = HashToken(token, t.ScratchSalt)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// HashToken return the hashable salt
|
||||
func HashToken(token, salt string) string {
|
||||
tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
|
||||
return hex.EncodeToString(tempHash)
|
||||
}
|
||||
|
||||
// VerifyScratchToken verifies if the specified scratch token is valid.
|
||||
func (t *TwoFactor) VerifyScratchToken(token string) bool {
|
||||
if len(token) == 0 {
|
||||
return false
|
||||
}
|
||||
tempHash := HashToken(token, t.ScratchSalt)
|
||||
return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
|
||||
}
|
||||
|
||||
func (t *TwoFactor) getEncryptionKey() []byte {
|
||||
k := md5.Sum([]byte(setting.SecretKey))
|
||||
return k[:]
|
||||
}
|
||||
|
||||
// SetSecret sets the 2FA secret.
|
||||
func (t *TwoFactor) SetSecret(secretString string) error {
|
||||
secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTOTP validates the provided passcode.
|
||||
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
|
||||
decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ValidateTOTP invalid base64: %w", err)
|
||||
}
|
||||
secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ValidateTOTP unable to decrypt (maybe SECRET_KEY is wrong): %w", err)
|
||||
}
|
||||
secretStr := string(secretBytes)
|
||||
return totp.Validate(passcode, secretStr), nil
|
||||
}
|
||||
|
||||
// NewTwoFactor creates a new two-factor authentication token.
|
||||
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTwoFactor updates a two-factor authentication token.
|
||||
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
|
||||
twofa := &TwoFactor{}
|
||||
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrTwoFactorNotEnrolled{uid}
|
||||
}
|
||||
return twofa, nil
|
||||
}
|
||||
|
||||
// HasTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
|
||||
}
|
||||
|
||||
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
|
||||
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrTwoFactorNotEnrolled{userID}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func HasTwoFactorOrWebAuthn(ctx context.Context, id int64) (bool, error) {
|
||||
has, err := HasTwoFactorByUID(ctx, id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if has {
|
||||
return true, nil
|
||||
}
|
||||
return HasWebAuthnRegistrationsByUID(ctx, id)
|
||||
}
|
||||
212
models/auth/webauthn.go
Normal file
212
models/auth/webauthn.go
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/timeutil"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
)
|
||||
|
||||
// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error.
|
||||
type ErrWebAuthnCredentialNotExist struct {
|
||||
ID int64
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
func (err ErrWebAuthnCredentialNotExist) Error() string {
|
||||
if len(err.CredentialID) == 0 {
|
||||
return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID)
|
||||
}
|
||||
return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrWebAuthnCredentialNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist.
|
||||
func IsErrWebAuthnCredentialNotExist(err error) bool {
|
||||
_, ok := err.(ErrWebAuthnCredentialNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
// WebAuthnCredential represents the WebAuthn credential data for a public-key
|
||||
// credential conformant to WebAuthn Level 1
|
||||
type WebAuthnCredential struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Name string
|
||||
LowerName string `xorm:"unique(s)"`
|
||||
UserID int64 `xorm:"INDEX unique(s)"`
|
||||
CredentialID []byte `xorm:"INDEX VARBINARY(1024)"`
|
||||
PublicKey []byte
|
||||
AttestationType string
|
||||
AAGUID []byte
|
||||
SignCount uint32 `xorm:"BIGINT"`
|
||||
CloneWarning bool
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(WebAuthnCredential))
|
||||
}
|
||||
|
||||
// TableName returns a better table name for WebAuthnCredential
|
||||
func (cred WebAuthnCredential) TableName() string {
|
||||
return "webauthn_credential"
|
||||
}
|
||||
|
||||
// UpdateSignCount will update the database value of SignCount
|
||||
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
|
||||
return err
|
||||
}
|
||||
|
||||
// BeforeInsert will be invoked by XORM before updating a record
|
||||
func (cred *WebAuthnCredential) BeforeInsert() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// BeforeUpdate will be invoked by XORM before updating a record
|
||||
func (cred *WebAuthnCredential) BeforeUpdate() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (cred *WebAuthnCredential) AfterLoad() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// WebAuthnCredentialList is a list of *WebAuthnCredential
|
||||
type WebAuthnCredentialList []*WebAuthnCredential
|
||||
|
||||
// newCredentialFlagsFromAuthenticatorFlags is copied from https://github.com/go-webauthn/webauthn/pull/337
|
||||
// to convert protocol.AuthenticatorFlags to webauthn.CredentialFlags
|
||||
func newCredentialFlagsFromAuthenticatorFlags(flags protocol.AuthenticatorFlags) webauthn.CredentialFlags {
|
||||
return webauthn.CredentialFlags{
|
||||
UserPresent: flags.HasUserPresent(),
|
||||
UserVerified: flags.HasUserVerified(),
|
||||
BackupEligible: flags.HasBackupEligible(),
|
||||
BackupState: flags.HasBackupState(),
|
||||
}
|
||||
}
|
||||
|
||||
// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials
|
||||
func (list WebAuthnCredentialList) ToCredentials(defaultAuthFlags ...protocol.AuthenticatorFlags) []webauthn.Credential {
|
||||
// TODO: at the moment, Gitea doesn't store or check the flags
|
||||
// so we need to use the default flags from the authenticator to make the login validation pass
|
||||
// In the future, we should:
|
||||
// 1. store the flags when registering the credential
|
||||
// 2. provide the stored flags when converting the credentials (for login)
|
||||
// 3. for old users, still use this fallback to the default flags
|
||||
defAuthFlags := util.OptionalArg(defaultAuthFlags)
|
||||
creds := make([]webauthn.Credential, 0, len(list))
|
||||
for _, cred := range list {
|
||||
creds = append(creds, webauthn.Credential{
|
||||
ID: cred.CredentialID,
|
||||
PublicKey: cred.PublicKey,
|
||||
AttestationType: cred.AttestationType,
|
||||
Flags: newCredentialFlagsFromAuthenticatorFlags(defAuthFlags),
|
||||
Authenticator: webauthn.Authenticator{
|
||||
AAGUID: cred.AAGUID,
|
||||
SignCount: cred.SignCount,
|
||||
CloneWarning: cred.CloneWarning,
|
||||
},
|
||||
})
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
|
||||
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
|
||||
creds := make(WebAuthnCredentialList, 0)
|
||||
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
|
||||
}
|
||||
|
||||
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
|
||||
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByName returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByID returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{ID: id}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
|
||||
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
|
||||
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// CreateCredential will create a new WebAuthnCredential from the given Credential
|
||||
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
|
||||
c := &WebAuthnCredential{
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
CredentialID: cred.ID,
|
||||
PublicKey: cred.PublicKey,
|
||||
AttestationType: cred.AttestationType,
|
||||
AAGUID: cred.Authenticator.AAGUID,
|
||||
SignCount: cred.Authenticator.SignCount,
|
||||
CloneWarning: false,
|
||||
}
|
||||
|
||||
if err := db.Insert(ctx, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DeleteCredential will delete WebAuthnCredential
|
||||
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
|
||||
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
|
||||
return had > 0, err
|
||||
}
|
||||
|
||||
// WebAuthnCredentials implements the webauthn.User interface
|
||||
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
|
||||
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbCreds.ToCredentials(), nil
|
||||
}
|
||||
66
models/auth/webauthn_test.go
Normal file
66
models/auth/webauthn_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetWebAuthnCredentialByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialByID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn credential", res.Name)
|
||||
|
||||
_, err = auth_model.GetWebAuthnCredentialByID(t.Context(), 342432)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
|
||||
}
|
||||
|
||||
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialsByUID(t.Context(), 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "WebAuthn credential", res[0].Name)
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_TableName(t *testing.T) {
|
||||
assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName())
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 1
|
||||
assert.NoError(t, cred.UpdateSignCount(t.Context()))
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 0xffffffff
|
||||
assert.NoError(t, cred.UpdateSignCount(t.Context()))
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
|
||||
}
|
||||
|
||||
func TestCreateCredential(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.CreateCredential(t.Context(), 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn Created Credential", res.Name)
|
||||
assert.Equal(t, []byte("Test"), res.CredentialID)
|
||||
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1})
|
||||
}
|
||||
238
models/avatars/avatar.go
Normal file
238
models/avatars/avatar.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package avatars
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/cache"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"strk.kbt.io/projects/go/libravatar"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultAvatarClass is the default class of a rendered avatar
|
||||
DefaultAvatarClass = "ui avatar tw-align-middle"
|
||||
// DefaultAvatarPixelSize is the default size in pixels of a rendered avatar
|
||||
DefaultAvatarPixelSize = 28
|
||||
)
|
||||
|
||||
// EmailHash represents a pre-generated hash map (mainly used by LibravatarURL, it queries email server's DNS records)
|
||||
type EmailHash struct {
|
||||
Hash string `xorm:"pk varchar(32)"`
|
||||
Email string `xorm:"UNIQUE NOT NULL"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(EmailHash))
|
||||
}
|
||||
|
||||
type avatarSettingStruct struct {
|
||||
defaultAvatarLink string
|
||||
gravatarSource string
|
||||
gravatarSourceURL *url.URL
|
||||
libravatar *libravatar.Libravatar
|
||||
}
|
||||
|
||||
var avatarSettingAtomic atomic.Pointer[avatarSettingStruct]
|
||||
|
||||
func loadAvatarSetting() (*avatarSettingStruct, error) {
|
||||
s := avatarSettingAtomic.Load()
|
||||
if s == nil || s.gravatarSource != setting.GravatarSource {
|
||||
s = &avatarSettingStruct{}
|
||||
u, err := url.Parse(setting.AppSubURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse AppSubURL: %w", err)
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, "/assets/img/avatar_default.png")
|
||||
s.defaultAvatarLink = u.String()
|
||||
|
||||
s.gravatarSourceURL, err = url.Parse(setting.GravatarSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse GravatarSource %q: %w", setting.GravatarSource, err)
|
||||
}
|
||||
|
||||
s.libravatar = libravatar.New()
|
||||
if s.gravatarSourceURL.Scheme == "https" {
|
||||
s.libravatar.SetUseHTTPS(true)
|
||||
s.libravatar.SetSecureFallbackHost(s.gravatarSourceURL.Host)
|
||||
} else {
|
||||
s.libravatar.SetUseHTTPS(false)
|
||||
s.libravatar.SetFallbackHost(s.gravatarSourceURL.Host)
|
||||
}
|
||||
|
||||
avatarSettingAtomic.Store(s)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// DefaultAvatarLink the default avatar link
|
||||
func DefaultAvatarLink() string {
|
||||
a, err := loadAvatarSetting()
|
||||
if err != nil {
|
||||
log.Error("Failed to loadAvatarSetting: %v", err)
|
||||
return ""
|
||||
}
|
||||
return a.defaultAvatarLink
|
||||
}
|
||||
|
||||
// HashEmail hashes email address to MD5 string. https://en.gravatar.com/site/implement/hash/
|
||||
func HashEmail(email string) string {
|
||||
m := md5.New()
|
||||
_, _ = m.Write([]byte(strings.ToLower(strings.TrimSpace(email))))
|
||||
return hex.EncodeToString(m.Sum(nil))
|
||||
}
|
||||
|
||||
// GetEmailForHash converts a provided md5sum to the email
|
||||
func GetEmailForHash(ctx context.Context, md5Sum string) (string, error) {
|
||||
return cache.GetString("Avatar:"+md5Sum, func() (string, error) {
|
||||
emailHash := EmailHash{
|
||||
Hash: strings.ToLower(strings.TrimSpace(md5Sum)),
|
||||
}
|
||||
|
||||
_, err := db.GetEngine(ctx).Get(&emailHash)
|
||||
return emailHash.Email, err
|
||||
})
|
||||
}
|
||||
|
||||
// LibravatarURL returns the URL for the given email. Slow due to the DNS lookup.
|
||||
// This function should only be called if a federated avatar service is enabled.
|
||||
func LibravatarURL(email string) (*url.URL, error) {
|
||||
a, err := loadAvatarSetting()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlStr, err := a.libravatar.FromEmail(email)
|
||||
if err != nil {
|
||||
log.Error("LibravatarService.FromEmail(email=%s): error %v", email, err)
|
||||
return nil, err
|
||||
}
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse libravatar url(%s): error %v", urlStr, err)
|
||||
return nil, err
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// saveEmailHash returns an avatar link for a provided email,
|
||||
// the email and hash are saved into database, which will be used by GetEmailForHash later
|
||||
func saveEmailHash(ctx context.Context, email string) string {
|
||||
lowerEmail := strings.ToLower(strings.TrimSpace(email))
|
||||
emailHash := HashEmail(lowerEmail)
|
||||
_, _ = cache.GetString("Avatar:"+emailHash, func() (string, error) {
|
||||
emailHash := &EmailHash{
|
||||
Email: lowerEmail,
|
||||
Hash: emailHash,
|
||||
}
|
||||
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
|
||||
if err := db.WithTx(ctx, func(ctx context.Context) error {
|
||||
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
|
||||
if has || err != nil {
|
||||
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
|
||||
return nil
|
||||
}
|
||||
_, _ = db.GetEngine(ctx).Insert(emailHash)
|
||||
return nil
|
||||
}); err != nil {
|
||||
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
|
||||
return lowerEmail, nil
|
||||
}
|
||||
return lowerEmail, nil
|
||||
})
|
||||
return emailHash
|
||||
}
|
||||
|
||||
// GenerateUserAvatarFastLink returns a fast link (302) to the user's avatar: "/user/avatar/${User.Name}/${size}"
|
||||
func GenerateUserAvatarFastLink(userName string, size int) string {
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
return setting.AppSubURL + "/user/avatar/" + url.PathEscape(userName) + "/" + strconv.Itoa(size)
|
||||
}
|
||||
|
||||
// GenerateUserAvatarImageLink returns a link for `User.Avatar` image file: "/avatars/${User.Avatar}"
|
||||
func GenerateUserAvatarImageLink(userAvatar string, size int) string {
|
||||
if size > 0 {
|
||||
return setting.AppSubURL + "/avatars/" + url.PathEscape(userAvatar) + "?size=" + strconv.Itoa(size)
|
||||
}
|
||||
return setting.AppSubURL + "/avatars/" + url.PathEscape(userAvatar)
|
||||
}
|
||||
|
||||
// generateRecognizedAvatarURL generate a recognized avatar (Gravatar/Libravatar) URL, it modifies the URL so the parameter is passed by a copy
|
||||
func generateRecognizedAvatarURL(u url.URL, size int) string {
|
||||
urlQuery := u.Query()
|
||||
urlQuery.Set("d", "identicon")
|
||||
if size > 0 {
|
||||
urlQuery.Set("s", strconv.Itoa(size))
|
||||
}
|
||||
u.RawQuery = urlQuery.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// generateEmailAvatarLink returns a email avatar link.
|
||||
// if final is true, it may use a slow path (eg: query DNS).
|
||||
// if final is false, it always uses a fast path.
|
||||
func generateEmailAvatarLink(ctx context.Context, email string, size int, final bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return DefaultAvatarLink()
|
||||
}
|
||||
|
||||
avatarSetting, err := loadAvatarSetting()
|
||||
if err != nil {
|
||||
return DefaultAvatarLink()
|
||||
}
|
||||
|
||||
enableFederatedAvatar := setting.Config().Picture.EnableFederatedAvatar.Value(ctx)
|
||||
if enableFederatedAvatar {
|
||||
emailHash := saveEmailHash(ctx, email)
|
||||
if final {
|
||||
// for final link, we can spend more time on slow external query
|
||||
var avatarURL *url.URL
|
||||
if avatarURL, err = LibravatarURL(email); err != nil {
|
||||
return DefaultAvatarLink()
|
||||
}
|
||||
return generateRecognizedAvatarURL(*avatarURL, size)
|
||||
}
|
||||
// for non-final link, we should return fast (use a 302 redirection link)
|
||||
urlStr := setting.AppSubURL + "/avatar/" + url.PathEscape(emailHash)
|
||||
if size > 0 {
|
||||
urlStr += "?size=" + strconv.Itoa(size)
|
||||
}
|
||||
return urlStr
|
||||
}
|
||||
|
||||
disableGravatar := setting.Config().Picture.DisableGravatar.Value(ctx)
|
||||
if !disableGravatar {
|
||||
// copy GravatarSourceURL, because we will modify its Path.
|
||||
avatarURLCopy := *avatarSetting.gravatarSourceURL
|
||||
avatarURLCopy.Path = path.Join(avatarURLCopy.Path, HashEmail(email))
|
||||
return generateRecognizedAvatarURL(avatarURLCopy, size)
|
||||
}
|
||||
|
||||
return DefaultAvatarLink()
|
||||
}
|
||||
|
||||
// GenerateEmailAvatarFastLink returns a avatar link (fast, the link may be a delegated one: "/avatar/${hash}")
|
||||
func GenerateEmailAvatarFastLink(ctx context.Context, email string, size int) string {
|
||||
return generateEmailAvatarLink(ctx, email, size, false)
|
||||
}
|
||||
|
||||
// GenerateEmailAvatarFinalLink returns a avatar final link (maybe slow)
|
||||
func GenerateEmailAvatarFinalLink(ctx context.Context, email string, size int) string {
|
||||
return generateEmailAvatarLink(ctx, email, size, true)
|
||||
}
|
||||
57
models/avatars/avatar_test.go
Normal file
57
models/avatars/avatar_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package avatars_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
avatars_model "code.gitea.io/gitea/models/avatars"
|
||||
system_model "code.gitea.io/gitea/models/system"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/setting/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const gravatarSource = "https://secure.gravatar.com/avatar/"
|
||||
|
||||
func disableGravatar(t *testing.T) {
|
||||
err := system_model.SetSettings(t.Context(), map[string]string{setting.Config().Picture.EnableFederatedAvatar.DynKey(): "false"})
|
||||
assert.NoError(t, err)
|
||||
err = system_model.SetSettings(t.Context(), map[string]string{setting.Config().Picture.DisableGravatar.DynKey(): "true"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func enableGravatar(t *testing.T) {
|
||||
err := system_model.SetSettings(t.Context(), map[string]string{setting.Config().Picture.DisableGravatar.DynKey(): "false"})
|
||||
assert.NoError(t, err)
|
||||
setting.GravatarSource = gravatarSource
|
||||
}
|
||||
|
||||
func TestHashEmail(t *testing.T) {
|
||||
assert.Equal(t,
|
||||
"d41d8cd98f00b204e9800998ecf8427e",
|
||||
avatars_model.HashEmail(""),
|
||||
)
|
||||
assert.Equal(t,
|
||||
"353cbad9b58e69c96154ad99f92bedc7",
|
||||
avatars_model.HashEmail("gitea@example.com"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestSizedAvatarLink(t *testing.T) {
|
||||
setting.AppSubURL = "/testsuburl"
|
||||
|
||||
disableGravatar(t)
|
||||
config.GetDynGetter().InvalidateCache()
|
||||
assert.Equal(t, "/testsuburl/assets/img/avatar_default.png",
|
||||
avatars_model.GenerateEmailAvatarFastLink(t.Context(), "gitea@example.com", 100))
|
||||
|
||||
enableGravatar(t)
|
||||
config.GetDynGetter().InvalidateCache()
|
||||
assert.Equal(t,
|
||||
"https://secure.gravatar.com/avatar/353cbad9b58e69c96154ad99f92bedc7?d=identicon&s=100",
|
||||
avatars_model.GenerateEmailAvatarFastLink(t.Context(), "gitea@example.com", 100),
|
||||
)
|
||||
}
|
||||
18
models/avatars/main_test.go
Normal file
18
models/avatars/main_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package avatars_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
_ "code.gitea.io/gitea/models"
|
||||
_ "code.gitea.io/gitea/models/activities"
|
||||
_ "code.gitea.io/gitea/models/perm/access"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
191
models/db/collation.go
Normal file
191
models/db/collation.go
Normal file
@@ -0,0 +1,191 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/modules/container"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type CheckCollationsResult struct {
|
||||
ExpectedCollation string
|
||||
AvailableCollation container.Set[string]
|
||||
DatabaseCollation string
|
||||
IsCollationCaseSensitive func(s string) bool
|
||||
CollationEquals func(a, b string) bool
|
||||
ExistingTableNumber int
|
||||
|
||||
InconsistentCollationColumns []string
|
||||
}
|
||||
|
||||
func findAvailableCollationsMySQL(x *xorm.Engine) (ret container.Set[string], err error) {
|
||||
var res []struct {
|
||||
Collation string
|
||||
}
|
||||
if err = x.SQL("SHOW COLLATION WHERE (Collation = 'utf8mb4_bin') OR (Collation LIKE '%\\_as\\_cs%')").Find(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = make(container.Set[string], len(res))
|
||||
for _, r := range res {
|
||||
ret.Add(r.Collation)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func findAvailableCollationsMSSQL(x *xorm.Engine) (ret container.Set[string], err error) {
|
||||
var res []struct {
|
||||
Name string
|
||||
}
|
||||
if err = x.SQL("SELECT * FROM sys.fn_helpcollations() WHERE name LIKE '%[_]CS[_]AS%'").Find(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = make(container.Set[string], len(res))
|
||||
for _, r := range res {
|
||||
ret.Add(r.Name)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func CheckCollations(x *xorm.Engine) (*CheckCollationsResult, error) {
|
||||
dbTables, err := x.DBMetas()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &CheckCollationsResult{
|
||||
ExistingTableNumber: len(dbTables),
|
||||
CollationEquals: func(a, b string) bool { return a == b },
|
||||
}
|
||||
|
||||
var candidateCollations []string
|
||||
if x.Dialect().URI().DBType == schemas.MYSQL {
|
||||
_, err = x.SQL("SELECT DEFAULT_COLLATION_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?", setting.Database.Name).Get(&res.DatabaseCollation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.IsCollationCaseSensitive = func(s string) bool {
|
||||
return s == "utf8mb4_bin" || strings.HasSuffix(s, "_as_cs")
|
||||
}
|
||||
candidateCollations = []string{"utf8mb4_0900_as_cs", "uca1400_as_cs", "utf8mb4_bin"}
|
||||
res.AvailableCollation, err = findAvailableCollationsMySQL(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.CollationEquals = func(a, b string) bool {
|
||||
// MariaDB adds the "utf8mb4_" prefix, eg: "utf8mb4_uca1400_as_cs", but not the name "uca1400_as_cs" in "SHOW COLLATION"
|
||||
// At the moment, it's safe to ignore the database difference, just trim the prefix and compare. It could be fixed easily if there is any problem in the future.
|
||||
return a == b || strings.TrimPrefix(a, "utf8mb4_") == strings.TrimPrefix(b, "utf8mb4_")
|
||||
}
|
||||
} else if x.Dialect().URI().DBType == schemas.MSSQL {
|
||||
if _, err = x.SQL("SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation')").Get(&res.DatabaseCollation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.IsCollationCaseSensitive = func(s string) bool {
|
||||
return strings.HasSuffix(s, "_CS_AS")
|
||||
}
|
||||
candidateCollations = []string{"Latin1_General_CS_AS"}
|
||||
res.AvailableCollation, err = findAvailableCollationsMSSQL(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if res.DatabaseCollation == "" {
|
||||
return nil, errors.New("unable to get collation for current database")
|
||||
}
|
||||
|
||||
res.ExpectedCollation = setting.Database.CharsetCollation
|
||||
if res.ExpectedCollation == "" {
|
||||
for _, collation := range candidateCollations {
|
||||
if res.AvailableCollation.Contains(collation) {
|
||||
res.ExpectedCollation = collation
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if res.ExpectedCollation == "" {
|
||||
return nil, errors.New("unable to find a suitable collation for current database")
|
||||
}
|
||||
|
||||
allColumnsMatchExpected := true
|
||||
allColumnsMatchDatabase := true
|
||||
for _, table := range dbTables {
|
||||
for _, col := range table.Columns() {
|
||||
if col.Collation != "" {
|
||||
allColumnsMatchExpected = allColumnsMatchExpected && res.CollationEquals(col.Collation, res.ExpectedCollation)
|
||||
allColumnsMatchDatabase = allColumnsMatchDatabase && res.CollationEquals(col.Collation, res.DatabaseCollation)
|
||||
if !res.IsCollationCaseSensitive(col.Collation) || !res.CollationEquals(col.Collation, res.DatabaseCollation) {
|
||||
res.InconsistentCollationColumns = append(res.InconsistentCollationColumns, fmt.Sprintf("%s.%s", table.Name, col.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if all columns match expected collation or all match database collation, then it could also be considered as "consistent"
|
||||
if allColumnsMatchExpected || allColumnsMatchDatabase {
|
||||
res.InconsistentCollationColumns = nil
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func CheckCollationsDefaultEngine() (*CheckCollationsResult, error) {
|
||||
return CheckCollations(xormEngine)
|
||||
}
|
||||
|
||||
func alterDatabaseCollation(x *xorm.Engine, collation string) error {
|
||||
if x.Dialect().URI().DBType == schemas.MYSQL {
|
||||
_, err := x.Exec("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE " + collation)
|
||||
return err
|
||||
} else if x.Dialect().URI().DBType == schemas.MSSQL {
|
||||
// TODO: MSSQL has many limitations on changing database collation, it could fail in many cases.
|
||||
_, err := x.Exec("ALTER DATABASE CURRENT COLLATE " + collation)
|
||||
return err
|
||||
}
|
||||
return errors.New("unsupported database type")
|
||||
}
|
||||
|
||||
// preprocessDatabaseCollation checks database & table column collation, and alter the database collation if needed
|
||||
func preprocessDatabaseCollation(x *xorm.Engine) {
|
||||
r, err := CheckCollations(x)
|
||||
if err != nil {
|
||||
log.Error("Failed to check database collation: %v", err)
|
||||
}
|
||||
if r == nil {
|
||||
return // no check result means the database doesn't need to do such check/process (at the moment ....)
|
||||
}
|
||||
|
||||
// try to alter database collation to expected if the database is empty, it might fail in some cases (and it isn't necessary to succeed)
|
||||
// at the moment, there is no "altering" solution for MSSQL, site admin should manually change the database collation
|
||||
if !r.CollationEquals(r.DatabaseCollation, r.ExpectedCollation) && r.ExistingTableNumber == 0 {
|
||||
if err = alterDatabaseCollation(x, r.ExpectedCollation); err != nil {
|
||||
log.Error("Failed to change database collation to %q: %v", r.ExpectedCollation, err)
|
||||
} else {
|
||||
_, _ = x.Exec("SELECT 1") // after "altering", MSSQL's session becomes invalid, so make a simple query to "refresh" the session
|
||||
if r, err = CheckCollations(x); err != nil {
|
||||
log.Error("Failed to check database collation again after altering: %v", err) // impossible case
|
||||
return
|
||||
}
|
||||
log.Warn("Current database has been altered to use collation %q", r.DatabaseCollation)
|
||||
}
|
||||
}
|
||||
|
||||
// check column collation, and show warning/error to end users -- no need to fatal, do not block the startup
|
||||
if !r.IsCollationCaseSensitive(r.DatabaseCollation) {
|
||||
log.Warn("Current database is using a case-insensitive collation %q, although Gitea could work with it, there might be some rare cases which don't work as expected.", r.DatabaseCollation)
|
||||
}
|
||||
|
||||
if len(r.InconsistentCollationColumns) > 0 {
|
||||
log.Error("There are %d table columns using inconsistent collation, they should use %q. Please go to admin panel Self Check page", len(r.InconsistentCollationColumns), r.DatabaseCollation)
|
||||
}
|
||||
}
|
||||
55
models/db/common.go
Normal file
55
models/db/common.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively.
|
||||
// Handles especially SQLite correctly as UPPER there only transforms ASCII letters.
|
||||
func BuildCaseInsensitiveLike(key, value string) builder.Cond {
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
return builder.Like{"UPPER(" + key + ")", util.ToUpperASCII(value)}
|
||||
}
|
||||
return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)}
|
||||
}
|
||||
|
||||
// BuildCaseInsensitiveIn returns a condition to check if the given value is in the given values case-insensitively.
|
||||
// Handles especially SQLite correctly as UPPER there only transforms ASCII letters.
|
||||
func BuildCaseInsensitiveIn(key string, values []string) builder.Cond {
|
||||
uppers := make([]string, 0, len(values))
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
for _, value := range values {
|
||||
uppers = append(uppers, util.ToUpperASCII(value))
|
||||
}
|
||||
} else {
|
||||
for _, value := range values {
|
||||
uppers = append(uppers, strings.ToUpper(value))
|
||||
}
|
||||
}
|
||||
|
||||
return builder.In("UPPER("+key+")", uppers)
|
||||
}
|
||||
|
||||
// BuilderDialect returns the xorm.Builder dialect of the engine
|
||||
func BuilderDialect() string {
|
||||
switch {
|
||||
case setting.Database.Type.IsMySQL():
|
||||
return builder.MYSQL
|
||||
case setting.Database.Type.IsSQLite3():
|
||||
return builder.SQLITE
|
||||
case setting.Database.Type.IsPostgreSQL():
|
||||
return builder.POSTGRES
|
||||
case setting.Database.Type.IsMSSQL():
|
||||
return builder.MSSQL
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
31
models/db/consistency.go
Normal file
31
models/db/consistency.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// CountOrphanedObjects count subjects with have no existing refobject anymore
|
||||
func CountOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) (int64, error) {
|
||||
return GetEngine(ctx).
|
||||
Table("`"+subject+"`").
|
||||
Join("LEFT", "`"+refObject+"`", joinCond).
|
||||
Where(builder.IsNull{"`" + refObject + "`.id"}).
|
||||
Select("COUNT(`" + subject + "`.`id`)").
|
||||
Count()
|
||||
}
|
||||
|
||||
// DeleteOrphanedObjects delete subjects with have no existing refobject anymore
|
||||
func DeleteOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) error {
|
||||
subQuery := builder.Select("`"+subject+"`.id").
|
||||
From("`"+subject+"`").
|
||||
Join("LEFT", "`"+refObject+"`", joinCond).
|
||||
Where(builder.IsNull{"`" + refObject + "`.id"})
|
||||
b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`")
|
||||
_, err := GetEngine(ctx).Exec(b)
|
||||
return err
|
||||
}
|
||||
317
models/db/context.go
Normal file
317
models/db/context.go
Normal file
@@ -0,0 +1,317 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
type engineContextKeyType struct{}
|
||||
|
||||
var engineContextKey = engineContextKeyType{}
|
||||
|
||||
func withContextEngine(ctx context.Context, e Engine) context.Context {
|
||||
return context.WithValue(ctx, engineContextKey, e)
|
||||
}
|
||||
|
||||
var (
|
||||
contextSafetyOnce sync.Once
|
||||
contextSafetyDeniedFuncPCs []uintptr
|
||||
)
|
||||
|
||||
func contextSafetyCheck(e Engine) {
|
||||
if setting.IsProd && !setting.IsInTesting {
|
||||
return
|
||||
}
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
|
||||
contextSafetyOnce.Do(func() {
|
||||
// try to figure out the bad functions to deny
|
||||
type m struct{}
|
||||
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
|
||||
callers := make([]uintptr, 32)
|
||||
callerNum := runtime.Callers(1, callers)
|
||||
for i := range callerNum {
|
||||
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
|
||||
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if len(contextSafetyDeniedFuncPCs) != 1 {
|
||||
panic(errors.New("unable to determine the functions to deny"))
|
||||
}
|
||||
})
|
||||
|
||||
// it should be very fast: xxxx ns/op
|
||||
callers := make([]uintptr, 32)
|
||||
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
|
||||
for i := range callerNum {
|
||||
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
|
||||
panic(errors.New("using session context in an iterator would cause corrupted results"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetEngine gets an existing db Engine/Statement or creates a new Session
|
||||
func GetEngine(ctx context.Context) Engine {
|
||||
if engine, ok := ctx.Value(engineContextKey).(Engine); ok {
|
||||
// if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session
|
||||
contextSafetyCheck(engine)
|
||||
return engine
|
||||
}
|
||||
// no need to do "contextSafetyCheck" because it's a new Session
|
||||
return xormEngine.Context(ctx)
|
||||
}
|
||||
|
||||
func GetXORMEngineForTesting() *xorm.Engine {
|
||||
return xormEngine
|
||||
}
|
||||
|
||||
// Committer represents an interface to Commit or Close the Context
|
||||
type Committer interface {
|
||||
Commit() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// halfCommitter is a wrapper of Committer.
|
||||
// It can be closed early, but can't be committed early, it is useful for reusing a transaction.
|
||||
type halfCommitter struct {
|
||||
committer Committer
|
||||
committed bool
|
||||
}
|
||||
|
||||
func (c *halfCommitter) Commit() error {
|
||||
c.committed = true
|
||||
// should do nothing, and the parent committer will commit later
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfCommitter) Close() error {
|
||||
if c.committed {
|
||||
// it's "commit and close", should do nothing, and the parent committer will commit later
|
||||
return nil
|
||||
}
|
||||
|
||||
// it's "rollback and close", let the parent committer rollback right now
|
||||
return c.committer.Close()
|
||||
}
|
||||
|
||||
// TxContext represents a transaction Context,
|
||||
// it will reuse the existing transaction in the parent context or create a new one.
|
||||
// Some tips to use:
|
||||
//
|
||||
// 1 It's always recommended to use `WithTx` in new code instead of `TxContext`, since `WithTx` will handle the transaction automatically.
|
||||
// 2. To maintain the old code which uses `TxContext`:
|
||||
// a. Always call `Close()` before returning regardless of whether `Commit()` has been called.
|
||||
// b. Always call `Commit()` before returning if there are no errors, even if the code did not change any data.
|
||||
// c. Remember the `Committer` will be a halfCommitter when a transaction is being reused.
|
||||
// So calling `Commit()` will do nothing, but calling `Close()` without calling `Commit()` will rollback the transaction.
|
||||
// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
|
||||
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
|
||||
func TxContext(parentCtx context.Context) (context.Context, Committer, error) {
|
||||
if sess := getTransactionSession(parentCtx); sess != nil {
|
||||
return withContextEngine(parentCtx, sess), &halfCommitter{committer: sess}, nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
if err := sess.Begin(); err != nil {
|
||||
_ = sess.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return withContextEngine(parentCtx, sess), sess, nil
|
||||
}
|
||||
|
||||
// WithTx represents executing database operations on a transaction, if the transaction exist,
|
||||
// this function will reuse it otherwise will create a new one and close it when finished.
|
||||
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
if sess := getTransactionSession(parentCtx); sess != nil {
|
||||
err := f(withContextEngine(parentCtx, sess))
|
||||
if err != nil {
|
||||
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
|
||||
_ = sess.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
return txWithNoCheck(parentCtx, f)
|
||||
}
|
||||
|
||||
// WithTx2 is similar to WithTx, but it has two return values: result and error.
|
||||
func WithTx2[T any](parentCtx context.Context, f func(ctx context.Context) (T, error)) (ret T, errRet error) {
|
||||
errRet = WithTx(parentCtx, func(ctx context.Context) (errInner error) {
|
||||
ret, errInner = f(ctx)
|
||||
return errInner
|
||||
})
|
||||
return ret, errRet
|
||||
}
|
||||
|
||||
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
if err := sess.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f(withContextEngine(parentCtx, sess)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sess.Commit()
|
||||
}
|
||||
|
||||
// Insert inserts records into database
|
||||
func Insert(ctx context.Context, beans ...any) error {
|
||||
_, err := GetEngine(ctx).Insert(beans...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Exec executes a sql with args
|
||||
func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
|
||||
return GetEngine(ctx).Exec(sqlAndArgs...)
|
||||
}
|
||||
|
||||
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
|
||||
if !cond.IsValid() {
|
||||
panic("cond is invalid in db.Get(ctx, cond). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !has {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &bean, true, nil
|
||||
}
|
||||
|
||||
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
|
||||
var bean T
|
||||
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !has {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &bean, true, nil
|
||||
}
|
||||
|
||||
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
|
||||
if !cond.IsValid() {
|
||||
panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
|
||||
}
|
||||
|
||||
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
|
||||
var bean T
|
||||
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
|
||||
}
|
||||
|
||||
// DeleteByID deletes the given bean with the given ID
|
||||
func DeleteByID[T any](ctx context.Context, id int64) (int64, error) {
|
||||
var bean T
|
||||
return GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
}
|
||||
|
||||
func DeleteByIDs[T any](ctx context.Context, ids ...int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var bean T
|
||||
_, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
return err
|
||||
}
|
||||
|
||||
func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) {
|
||||
if opts == nil || !opts.ToConds().IsValid() {
|
||||
panic("opts are empty or invalid in db.Delete(ctx, opts). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
return GetEngine(ctx).Where(opts.ToConds()).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
}
|
||||
|
||||
// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
|
||||
func DeleteByBean(ctx context.Context, bean any) (int64, error) {
|
||||
return GetEngine(ctx).Delete(bean)
|
||||
}
|
||||
|
||||
// FindIDs finds the IDs for the given table name satisfying the given condition
|
||||
// By passing a different value than "id" for "idCol", you can query for foreign IDs, i.e. the repo IDs which satisfy the condition
|
||||
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
|
||||
ids := make([]int64, 0, 10)
|
||||
if err := GetEngine(ctx).Table(tableName).
|
||||
Cols(idCol).
|
||||
Where(cond).
|
||||
Find(&ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// DecrByIDs decreases the given column for entities of the "bean" type with one of the given ids by one
|
||||
// Timestamps of the entities won't be updated
|
||||
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteBeans deletes all given beans, beans must contain delete conditions.
|
||||
func DeleteBeans(ctx context.Context, beans ...any) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
for i := range beans {
|
||||
if _, err = e.Delete(beans[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TruncateBeans deletes all given beans, beans may contain delete conditions.
|
||||
func TruncateBeans(ctx context.Context, beans ...any) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
for i := range beans {
|
||||
if _, err = e.Truncate(beans[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountByBean counts the number of database records according non-empty fields of the bean as conditions.
|
||||
func CountByBean(ctx context.Context, bean any) (int64, error) {
|
||||
return GetEngine(ctx).Count(bean)
|
||||
}
|
||||
|
||||
// InTransaction returns true if the engine is in a transaction otherwise return false
|
||||
func InTransaction(ctx context.Context) bool {
|
||||
return getTransactionSession(ctx) != nil
|
||||
}
|
||||
|
||||
func getTransactionSession(ctx context.Context) *xorm.Session {
|
||||
e, _ := ctx.Value(engineContextKey).(Engine)
|
||||
if sess, ok := e.(*xorm.Session); ok && sess.IsInTx() {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
102
models/db/context_committer_test.go
Normal file
102
models/db/context_committer_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db // it's not db_test, because this file is for testing the private type halfCommitter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MockCommitter struct {
|
||||
wants []string
|
||||
gots []string
|
||||
}
|
||||
|
||||
func NewMockCommitter(wants ...string) *MockCommitter {
|
||||
return &MockCommitter{
|
||||
wants: wants,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Commit() error {
|
||||
c.gots = append(c.gots, "commit")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Close() error {
|
||||
c.gots = append(c.gots, "close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Assert(t *testing.T) {
|
||||
assert.Equal(t, c.wants, c.gots, "want operations %v, but got %v", c.wants, c.gots)
|
||||
}
|
||||
|
||||
func Test_halfCommitter(t *testing.T) {
|
||||
/*
|
||||
Do something like:
|
||||
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
// ...
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ...
|
||||
|
||||
return committer.Commit()
|
||||
*/
|
||||
|
||||
testWithCommitter := func(committer Committer, f func(committer Committer) error) {
|
||||
if err := f(&halfCommitter{committer: committer}); err == nil {
|
||||
committer.Commit()
|
||||
}
|
||||
committer.Close()
|
||||
}
|
||||
|
||||
t.Run("commit and close", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("commit", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
defer committer.Close()
|
||||
return committer.Commit()
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
|
||||
t.Run("rollback and close", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("close", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
defer committer.Close()
|
||||
if true {
|
||||
return errors.New("error")
|
||||
}
|
||||
return committer.Commit()
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
|
||||
t.Run("close and commit", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("close", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
committer.Close()
|
||||
committer.Commit()
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
}
|
||||
135
models/db/context_test.go
Normal file
135
models/db/context_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInTransaction(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.False(t, db.InTransaction(t.Context()))
|
||||
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
assert.NoError(t, err)
|
||||
defer committer.Close()
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, db.WithTx(ctx, func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestTxContext(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
{ // create new transaction
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by TxContext and commit it
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
engine := db.GetEngine(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.Equal(t, engine, db.GetEngine(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by TxContext and close it
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
engine := db.GetEngine(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.Equal(t, engine, db.GetEngine(ctx))
|
||||
assert.NoError(t, committer.Close())
|
||||
}
|
||||
assert.NoError(t, committer.Close())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by WithTx
|
||||
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextSafety(t *testing.T) {
|
||||
type TestModel1 struct {
|
||||
ID int64
|
||||
}
|
||||
type TestModel2 struct {
|
||||
ID int64
|
||||
}
|
||||
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
|
||||
assert.NoError(t, db.TruncateBeans(t.Context(), &TestModel1{}, &TestModel2{}))
|
||||
testCount := 10
|
||||
for i := 1; i <= testCount; i++ {
|
||||
assert.NoError(t, db.Insert(t.Context(), &TestModel1{ID: int64(i)}))
|
||||
assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)}))
|
||||
}
|
||||
|
||||
t.Run("Show-XORM-Bug", func(t *testing.T) {
|
||||
actualCount := 0
|
||||
// here: db.GetEngine(t.Context()) is a new *Session created from *Engine
|
||||
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
|
||||
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
|
||||
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
|
||||
m1 := bean.(*TestModel1)
|
||||
assert.EqualValues(t, i+1, m1.ID)
|
||||
|
||||
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
|
||||
// and it conflicts with the "Iterate"'s internal states.
|
||||
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
|
||||
|
||||
actualCount++
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, testCount, actualCount)
|
||||
})
|
||||
|
||||
t.Run("DenyBadUsage", func(t *testing.T) {
|
||||
assert.PanicsWithError(t, "using session context in an iterator would cause corrupted results", func() {
|
||||
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
return db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
|
||||
_ = db.GetEngine(ctx)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
88
models/db/convert.go
Normal file
88
models/db/convert.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// ConvertDatabaseTable converts database and tables from utf8 to utf8mb4 if it's mysql and set ROW_FORMAT=dynamic
|
||||
func ConvertDatabaseTable() error {
|
||||
if xormEngine.Dialect().URI().DBType != schemas.MYSQL {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, err := CheckCollations(xormEngine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = xormEngine.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE %s", setting.Database.Name, r.ExpectedCollation))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tables, err := xormEngine.DBMetas()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, table := range tables {
|
||||
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic", table.Name)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE %s", table.Name, r.ExpectedCollation)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertVarcharToNVarchar converts database and tables from varchar to nvarchar if it's mssql
|
||||
func ConvertVarcharToNVarchar() error {
|
||||
if xormEngine.Dialect().URI().DBType != schemas.MSSQL {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
res, err := sess.QuerySliceString(`SELECT 'ALTER TABLE ' + OBJECT_NAME(SC.object_id) + ' MODIFY SC.name NVARCHAR(' + CONVERT(VARCHAR(5),SC.max_length) + ')'
|
||||
FROM SYS.columns SC
|
||||
JOIN SYS.types ST
|
||||
ON SC.system_type_id = ST.system_type_id
|
||||
AND SC.user_type_id = ST.user_type_id
|
||||
WHERE ST.name ='varchar'`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range res {
|
||||
if len(row) == 1 {
|
||||
if _, err = sess.Exec(row[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Cell2Int64 converts a xorm.Cell type to int64,
|
||||
// and handles possible irregular cases.
|
||||
func Cell2Int64(val xorm.Cell) int64 {
|
||||
switch (*val).(type) {
|
||||
case []uint8:
|
||||
log.Trace("Cell2Int64 ([]uint8): %v", *val)
|
||||
|
||||
v, _ := strconv.ParseInt(string((*val).([]uint8)), 10, 64)
|
||||
return v
|
||||
}
|
||||
return (*val).(int64)
|
||||
}
|
||||
147
models/db/engine.go
Executable file
147
models/db/engine.go
Executable file
@@ -0,0 +1,147 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"xorm.io/xorm"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver
|
||||
_ "github.com/lib/pq" // Needed for the Postgresql driver
|
||||
_ "github.com/microsoft/go-mssqldb" // Needed for the MSSQL driver
|
||||
)
|
||||
|
||||
var (
|
||||
xormEngine *xorm.Engine
|
||||
registeredModels []any
|
||||
registeredInitFuncs []func() error
|
||||
)
|
||||
|
||||
// Engine represents a xorm engine or session.
|
||||
type Engine interface {
|
||||
Table(tableNameOrBean any) *xorm.Session
|
||||
Count(...any) (int64, error)
|
||||
Decr(column string, arg ...any) *xorm.Session
|
||||
Delete(...any) (int64, error)
|
||||
Truncate(...any) (int64, error)
|
||||
Exec(...any) (sql.Result, error)
|
||||
Find(any, ...any) error
|
||||
Get(beans ...any) (bool, error)
|
||||
ID(any) *xorm.Session
|
||||
In(string, ...any) *xorm.Session
|
||||
Incr(column string, arg ...any) *xorm.Session
|
||||
Insert(...any) (int64, error)
|
||||
Iterate(any, xorm.IterFunc) error
|
||||
Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session
|
||||
SQL(any, ...any) *xorm.Session
|
||||
Where(any, ...any) *xorm.Session
|
||||
Asc(colNames ...string) *xorm.Session
|
||||
Desc(colNames ...string) *xorm.Session
|
||||
Limit(limit int, start ...int) *xorm.Session
|
||||
NoAutoTime() *xorm.Session
|
||||
SumInt(bean any, columnName string) (res int64, err error)
|
||||
Sync(...any) error
|
||||
Select(string) *xorm.Session
|
||||
SetExpr(string, any) *xorm.Session
|
||||
NotIn(string, ...any) *xorm.Session
|
||||
OrderBy(any, ...any) *xorm.Session
|
||||
Exist(...any) (bool, error)
|
||||
Distinct(...string) *xorm.Session
|
||||
Query(...any) ([]map[string][]byte, error)
|
||||
Cols(...string) *xorm.Session
|
||||
Context(ctx context.Context) *xorm.Session
|
||||
Ping() error
|
||||
IsTableExist(tableNameOrBean any) (bool, error)
|
||||
}
|
||||
|
||||
var (
|
||||
_ Engine = (*xorm.Engine)(nil)
|
||||
_ Engine = (*xorm.Session)(nil)
|
||||
)
|
||||
|
||||
// RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync
|
||||
func RegisterModel(bean any, initFunc ...func() error) {
|
||||
registeredModels = append(registeredModels, bean)
|
||||
if len(registeredInitFuncs) > 0 && initFunc[0] != nil {
|
||||
registeredInitFuncs = append(registeredInitFuncs, initFunc[0])
|
||||
}
|
||||
}
|
||||
|
||||
// SyncAllTables sync the schemas of all tables, is required by unit test code
|
||||
func SyncAllTables() error {
|
||||
_, err := xormEngine.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
|
||||
WarnIfDatabaseColumnMissed: true,
|
||||
}, registeredModels...)
|
||||
return err
|
||||
}
|
||||
|
||||
// NamesToBean return a list of beans or an error
|
||||
func NamesToBean(names ...string) ([]any, error) {
|
||||
beans := []any{}
|
||||
if len(names) == 0 {
|
||||
beans = append(beans, registeredModels...)
|
||||
return beans, nil
|
||||
}
|
||||
// Need to map provided names to beans...
|
||||
beanMap := make(map[string]any)
|
||||
for _, bean := range registeredModels {
|
||||
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
|
||||
beanMap[strings.ToLower(xormEngine.TableName(bean))] = bean
|
||||
beanMap[strings.ToLower(xormEngine.TableName(bean, true))] = bean
|
||||
}
|
||||
|
||||
gotBean := make(map[any]bool)
|
||||
for _, name := range names {
|
||||
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no table found that matches: %s", name)
|
||||
}
|
||||
if !gotBean[bean] {
|
||||
beans = append(beans, bean)
|
||||
gotBean[bean] = true
|
||||
}
|
||||
}
|
||||
return beans, nil
|
||||
}
|
||||
|
||||
// MaxBatchInsertSize returns the table's max batch insert size
|
||||
func MaxBatchInsertSize(bean any) int {
|
||||
t, err := xormEngine.TableInfo(bean)
|
||||
if err != nil {
|
||||
return 50
|
||||
}
|
||||
return 999 / len(t.ColumnsSeq())
|
||||
}
|
||||
|
||||
// IsTableNotEmpty returns true if table has at least one record
|
||||
func IsTableNotEmpty(beanOrTableName any) (bool, error) {
|
||||
return xormEngine.Table(beanOrTableName).Exist()
|
||||
}
|
||||
|
||||
// DeleteAllRecords will delete all the records of this table
|
||||
func DeleteAllRecords(tableName string) error {
|
||||
_, err := xormEngine.Exec("DELETE FROM " + tableName)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMaxID will return max id of the table
|
||||
func GetMaxID(beanOrTableName any) (maxID int64, err error) {
|
||||
_, err = xormEngine.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
|
||||
return maxID, err
|
||||
}
|
||||
|
||||
func SetLogSQL(ctx context.Context, on bool) {
|
||||
e := GetEngine(ctx)
|
||||
if x, ok := e.(*xorm.Engine); ok {
|
||||
x.ShowSQL(on)
|
||||
} else if sess, ok := e.(*xorm.Session); ok {
|
||||
sess.Engine().ShowSQL(on)
|
||||
}
|
||||
}
|
||||
33
models/db/engine_dump.go
Normal file
33
models/db/engine_dump.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import "xorm.io/xorm/schemas"
|
||||
|
||||
// DumpDatabase dumps all data from database according the special database SQL syntax to file system.
|
||||
func DumpDatabase(filePath, dbType string) error {
|
||||
var tbs []*schemas.Table
|
||||
for _, t := range registeredModels {
|
||||
t, err := xormEngine.TableInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tbs = append(tbs, t)
|
||||
}
|
||||
|
||||
type Version struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Version int64
|
||||
}
|
||||
t, err := xormEngine.TableInfo(&Version{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tbs = append(tbs, t)
|
||||
|
||||
if dbType != "" {
|
||||
return xormEngine.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
|
||||
}
|
||||
return xormEngine.DumpTablesToFile(tbs, filePath)
|
||||
}
|
||||
47
models/db/engine_hook.go
Normal file
47
models/db/engine_hook.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"code.gitea.io/gitea/modules/gtprof"
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/xorm/contexts"
|
||||
)
|
||||
|
||||
type EngineHook struct {
|
||||
Threshold time.Duration
|
||||
Logger log.Logger
|
||||
}
|
||||
|
||||
var _ contexts.Hook = (*EngineHook)(nil)
|
||||
|
||||
func (*EngineHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) {
|
||||
ctx, _ := gtprof.GetTracer().Start(c.Ctx, gtprof.TraceSpanDatabase)
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (h *EngineHook) AfterProcess(c *contexts.ContextHook) error {
|
||||
span := gtprof.GetContextSpan(c.Ctx)
|
||||
if span != nil {
|
||||
// Do not record SQL parameters here:
|
||||
// * It shouldn't expose the parameters because they contain sensitive information, end users need to report the trace details safely.
|
||||
// * Some parameters contain quite long texts, waste memory and are difficult to display.
|
||||
span.SetAttributeString(gtprof.TraceAttrDbSQL, c.SQL)
|
||||
span.End()
|
||||
} else {
|
||||
setting.PanicInDevOrTesting("span in database engine hook is nil")
|
||||
}
|
||||
if c.ExecuteTime >= h.Threshold {
|
||||
// 8 is the amount of skips passed to runtime.Caller, so that in the log the correct function
|
||||
// is being displayed (the function that ultimately wants to execute the query in the code)
|
||||
// instead of the function of the slow query hook being called.
|
||||
h.Logger.Log(8, &log.Event{Level: log.WARN}, "[Slow SQL Query] %s %v - %v", c.SQL, c.Args, c.ExecuteTime)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
139
models/db/engine_init.go
Normal file
139
models/db/engine_init.go
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/names"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gonicNames := []string{"SSL", "UID"}
|
||||
for _, name := range gonicNames {
|
||||
names.LintGonicMapper[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// newXORMEngine returns a new XORM engine from the configuration
|
||||
func newXORMEngine() (*xorm.Engine, error) {
|
||||
connStr, err := setting.DBConnStr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var engine *xorm.Engine
|
||||
|
||||
if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 {
|
||||
// OK whilst we sort out our schema issues - create a schema aware postgres
|
||||
registerPostgresSchemaDriver()
|
||||
engine, err = xorm.NewEngine("postgresschema", connStr)
|
||||
} else {
|
||||
engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch setting.Database.Type {
|
||||
case "mysql":
|
||||
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
|
||||
case "mssql":
|
||||
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
|
||||
}
|
||||
engine.SetSchema(setting.Database.Schema)
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// InitEngine initializes the xorm.Engine and sets it as XORM's default context
|
||||
func InitEngine(ctx context.Context) error {
|
||||
xe, err := newXORMEngine()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "SQLite3 support") {
|
||||
return fmt.Errorf(`sqlite3 requires: -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err)
|
||||
}
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
xe.SetMapper(names.GonicMapper{})
|
||||
// WARNING: for serv command, MUST remove the output to os.stdout,
|
||||
// so use log file to instead print to stdout.
|
||||
xe.SetLogger(NewXORMLogger(setting.Database.LogSQL))
|
||||
xe.ShowSQL(setting.Database.LogSQL)
|
||||
xe.SetMaxOpenConns(setting.Database.MaxOpenConns)
|
||||
xe.SetMaxIdleConns(setting.Database.MaxIdleConns)
|
||||
xe.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
|
||||
|
||||
if setting.Database.SlowQueryThreshold > 0 {
|
||||
xe.AddHook(&EngineHook{
|
||||
Threshold: setting.Database.SlowQueryThreshold,
|
||||
Logger: log.GetLogger("xorm"),
|
||||
})
|
||||
}
|
||||
|
||||
SetDefaultEngine(ctx, xe)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultEngine sets the default engine for db
|
||||
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
|
||||
xormEngine = eng
|
||||
xormEngine.SetDefaultContext(ctx)
|
||||
}
|
||||
|
||||
// UnsetDefaultEngine closes and unsets the default engine
|
||||
// We hope the SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now,
|
||||
// there are many calls to InitEngine -> SetDefaultEngine directly to overwrite the `xormEngine` and `xormContext` without close
|
||||
// Global database engine related functions are all racy and there is no graceful close right now.
|
||||
func UnsetDefaultEngine() {
|
||||
if xormEngine != nil {
|
||||
_ = xormEngine.Close()
|
||||
xormEngine = nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitEngineWithMigration initializes a new xorm.Engine and sets it as the XORM's default context
|
||||
// This function must never call .Sync() if the provided migration function fails.
|
||||
// When called from the "doctor" command, the migration function is a version check
|
||||
// that prevents the doctor from fixing anything in the database if the migration level
|
||||
// is different from the expected value.
|
||||
func InitEngineWithMigration(ctx context.Context, migrateFunc func(context.Context, *xorm.Engine) error) (err error) {
|
||||
if err = InitEngine(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = xormEngine.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
preprocessDatabaseCollation(xormEngine)
|
||||
|
||||
// We have to run migrateFunc here in case the user is re-running installation on a previously created DB.
|
||||
// If we do not then table schemas will be changed and there will be conflicts when the migrations run properly.
|
||||
//
|
||||
// Installation should only be being re-run if users want to recover an old database.
|
||||
// However, we should think carefully about should we support re-install on an installed instance,
|
||||
// as there may be other problems due to secret reinitialization.
|
||||
if err = migrateFunc(ctx, xormEngine); err != nil {
|
||||
return fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
|
||||
if err = SyncAllTables(); err != nil {
|
||||
return fmt.Errorf("sync database struct error: %w", err)
|
||||
}
|
||||
|
||||
for _, initFunc := range registeredInitFuncs {
|
||||
if err := initFunc(); err != nil {
|
||||
return fmt.Errorf("initFunc failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
83
models/db/engine_test.go
Normal file
83
models/db/engine_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
issues_model "code.gitea.io/gitea/models/issues"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
_ "code.gitea.io/gitea/cmd" // for TestPrimaryKeys
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDumpDatabase(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
type Version struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Version int64
|
||||
}
|
||||
assert.NoError(t, db.GetEngine(t.Context()).Sync(new(Version)))
|
||||
|
||||
for _, dbType := range setting.SupportedDatabaseTypes {
|
||||
assert.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), dbType))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteOrphanedObjects(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
countBefore, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.GetEngine(t.Context()).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003})
|
||||
assert.NoError(t, err)
|
||||
|
||||
orphaned, err := db.CountOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, orphaned)
|
||||
|
||||
err = db.DeleteOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
|
||||
assert.NoError(t, err)
|
||||
|
||||
countAfter, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, countBefore, countAfter)
|
||||
}
|
||||
|
||||
func TestPrimaryKeys(t *testing.T) {
|
||||
// Some dbs require that all tables have primary keys, see
|
||||
// https://github.com/go-gitea/gitea/issues/21086
|
||||
// https://github.com/go-gitea/gitea/issues/16802
|
||||
// To avoid creating tables without primary key again, this test will check them.
|
||||
// Import "code.gitea.io/gitea/cmd" to make sure each db.RegisterModel in init functions has been called.
|
||||
|
||||
beans, err := db.NamesToBean()
|
||||
require.NoError(t, err)
|
||||
|
||||
whitelist := map[string]string{
|
||||
"the_table_name_to_skip_checking": "Write a note here to explain why",
|
||||
}
|
||||
|
||||
for _, bean := range beans {
|
||||
table, err := db.GetXORMEngineForTesting().TableInfo(bean)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if why, ok := whitelist[table.Name]; ok {
|
||||
t.Logf("ignore %q because %q", table.Name, why)
|
||||
continue
|
||||
}
|
||||
assert.NotEmpty(t, table.PrimaryKeys, "table %q has no primary key", table.Name)
|
||||
}
|
||||
}
|
||||
74
models/db/error.go
Normal file
74
models/db/error.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// ErrCancelled represents an error due to context cancellation
|
||||
type ErrCancelled struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// IsErrCancelled checks if an error is a ErrCancelled.
|
||||
func IsErrCancelled(err error) bool {
|
||||
_, ok := err.(ErrCancelled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrCancelled) Error() string {
|
||||
return "Cancelled: " + err.Message
|
||||
}
|
||||
|
||||
// ErrCancelledf returns an ErrCancelled for the provided format and args
|
||||
func ErrCancelledf(format string, args ...any) error {
|
||||
return ErrCancelled{
|
||||
fmt.Sprintf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrSSHDisabled represents an "SSH disabled" error.
|
||||
type ErrSSHDisabled struct{}
|
||||
|
||||
// IsErrSSHDisabled checks if an error is a ErrSSHDisabled.
|
||||
func IsErrSSHDisabled(err error) bool {
|
||||
_, ok := err.(ErrSSHDisabled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSSHDisabled) Error() string {
|
||||
return "SSH is disabled"
|
||||
}
|
||||
|
||||
// ErrNotExist represents a non-exist error.
|
||||
type ErrNotExist struct {
|
||||
Resource string
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrNotExist checks if an error is an ErrNotExist
|
||||
func IsErrNotExist(err error) bool {
|
||||
_, ok := err.(ErrNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNotExist) Error() string {
|
||||
name := "record"
|
||||
if err.Resource != "" {
|
||||
name = err.Resource
|
||||
}
|
||||
|
||||
if err.ID != 0 {
|
||||
return fmt.Sprintf("%s does not exist [id: %d]", name, err.ID)
|
||||
}
|
||||
return name + " does not exist"
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
172
models/db/index.go
Normal file
172
models/db/index.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
)
|
||||
|
||||
// ResourceIndex represents a resource index which could be used as issue/release and others
|
||||
// We can create different tables i.e. issue_index, release_index, etc.
|
||||
type ResourceIndex struct {
|
||||
GroupID int64 `xorm:"pk"`
|
||||
MaxIndex int64 `xorm:"index"`
|
||||
}
|
||||
|
||||
var ErrGetResourceIndexFailed = errors.New("get resource index failed")
|
||||
|
||||
// SyncMaxResourceIndex sync the max index with the resource
|
||||
func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
|
||||
// try to update the max_index and acquire the write-lock for the record
|
||||
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=? WHERE group_id=? AND max_index<?", tableName), maxIndex, groupID, maxIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// if nothing is updated, the record might not exist or might be larger, it's safe to try to insert it again and then check whether the record exists
|
||||
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, ?)", tableName), groupID, maxIndex)
|
||||
var savedIdx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&savedIdx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if the record still doesn't exist, there must be some errors (insert error)
|
||||
if !has {
|
||||
if errIns == nil {
|
||||
return errors.New("impossible error when SyncMaxResourceIndex, insert succeeded but no record is saved")
|
||||
}
|
||||
return errIns
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postgresGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
res, err := GetEngine(ctx).Query(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
|
||||
"VALUES (?,1) ON CONFLICT (group_id) DO UPDATE SET max_index = %s.max_index+1 RETURNING max_index",
|
||||
tableName, tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return 0, ErrGetResourceIndexFailed
|
||||
}
|
||||
return strconv.ParseInt(string(res[0]["max_index"]), 10, 64)
|
||||
}
|
||||
|
||||
func mysqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
if _, err := GetEngine(ctx).Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
|
||||
"VALUES (?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
|
||||
tableName), groupID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var idx int64
|
||||
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if idx == 0 {
|
||||
return 0, errors.New("cannot get the correct index")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func mssqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
if _, err := GetEngine(ctx).Exec(fmt.Sprintf(`
|
||||
MERGE INTO %s WITH (HOLDLOCK) AS target
|
||||
USING (SELECT %d AS group_id) AS source
|
||||
(group_id)
|
||||
ON target.group_id = source.group_id
|
||||
WHEN MATCHED
|
||||
THEN UPDATE
|
||||
SET max_index = max_index + 1
|
||||
WHEN NOT MATCHED
|
||||
THEN INSERT (group_id, max_index)
|
||||
VALUES (%d, 1);
|
||||
`, tableName, groupID, groupID)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var idx int64
|
||||
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if idx == 0 {
|
||||
return 0, errors.New("cannot get the correct index")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
// GetNextResourceIndex generates a resource index, it must run in the same transaction where the resource is created
|
||||
func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
switch {
|
||||
case setting.Database.Type.IsPostgreSQL():
|
||||
return postgresGetNextResourceIndex(ctx, tableName, groupID)
|
||||
case setting.Database.Type.IsMySQL():
|
||||
return mysqlGetNextResourceIndex(ctx, tableName, groupID)
|
||||
case setting.Database.Type.IsMSSQL():
|
||||
return mssqlGetNextResourceIndex(ctx, tableName, groupID)
|
||||
}
|
||||
|
||||
e := GetEngine(ctx)
|
||||
|
||||
// try to update the max_index to next value, and acquire the write-lock for the record
|
||||
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected == 0 {
|
||||
// this slow path is only for the first time of creating a resource index
|
||||
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, 0)", tableName), groupID)
|
||||
res, err = e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// if the update still can not update any records, the record must not exist and there must be some errors (insert error)
|
||||
if affected == 0 {
|
||||
if errIns == nil {
|
||||
return 0, errors.New("impossible error when GetNextResourceIndex, insert and update both succeeded but no record is updated")
|
||||
}
|
||||
return 0, errIns
|
||||
}
|
||||
}
|
||||
|
||||
// now, the new index is in database (protected by the transaction and write-lock)
|
||||
var newIdx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&newIdx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !has {
|
||||
return 0, errors.New("impossible error when GetNextResourceIndex, upsert succeeded but no record can be selected")
|
||||
}
|
||||
return newIdx, nil
|
||||
}
|
||||
|
||||
// DeleteResourceIndex delete resource index
|
||||
func DeleteResourceIndex(ctx context.Context, tableName string, groupID int64) error {
|
||||
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
|
||||
return err
|
||||
}
|
||||
126
models/db/index_test.go
Normal file
126
models/db/index_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestIndex db.ResourceIndex
|
||||
|
||||
func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
e := db.GetEngine(ctx)
|
||||
var idx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !has {
|
||||
return 0, errors.New("no record")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func TestSyncMaxResourceIndex(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&TestIndex{}))
|
||||
|
||||
err := db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 51)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// sync new max index
|
||||
maxIndex, err := getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 51, maxIndex)
|
||||
|
||||
// smaller index doesn't change
|
||||
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 30)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 51, maxIndex)
|
||||
|
||||
// larger index changes
|
||||
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 62)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 62, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 84, maxIndex)
|
||||
return errors.New("test rollback")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex) // the max index doesn't change because the transaction was rolled back
|
||||
}
|
||||
|
||||
func TestGetNextResourceIndex(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&TestIndex{}))
|
||||
|
||||
// create a new record
|
||||
maxIndex, err := db.GetNextResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, maxIndex)
|
||||
|
||||
// increase the existing record
|
||||
maxIndex, err = db.GetNextResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 4, maxIndex)
|
||||
return errors.New("test rollback")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex) // the max index doesn't change because the transaction was rolled back
|
||||
}
|
||||
59
models/db/install/db.go
Normal file
59
models/db/install/db.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package install
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
)
|
||||
|
||||
// CheckDatabaseConnection checks the database connection
|
||||
func CheckDatabaseConnection(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).Exec("SELECT 1")
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMigrationVersion gets the database migration version
|
||||
func GetMigrationVersion(ctx context.Context) (int64, error) {
|
||||
var installedDbVersion int64
|
||||
x := db.GetEngine(ctx)
|
||||
exist, err := x.IsTableExist("version")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !exist {
|
||||
return 0, nil
|
||||
}
|
||||
_, err = x.Table("version").Cols("version").Get(&installedDbVersion)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return installedDbVersion, nil
|
||||
}
|
||||
|
||||
// HasPostInstallationUsers checks whether there are users after installation
|
||||
func HasPostInstallationUsers(ctx context.Context) (bool, error) {
|
||||
x := db.GetEngine(ctx)
|
||||
exist, err := x.IsTableExist("user")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !exist {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// if there are 2 or more users in database, we consider there are users created after installation
|
||||
threshold := 2
|
||||
if !setting.IsProd {
|
||||
// to debug easily, with non-prod RUN_MODE, we only check the count to 1
|
||||
threshold = 1
|
||||
}
|
||||
res, err := x.Table("user").Cols("id").Limit(threshold).Query()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(res) >= threshold, nil
|
||||
}
|
||||
43
models/db/iterate.go
Normal file
43
models/db/iterate.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// Iterate iterates all the Bean object
|
||||
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
|
||||
var start int
|
||||
batchSize := setting.Database.IterateBufferSize
|
||||
sess := GetEngine(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
beans := make([]*Bean, 0, batchSize)
|
||||
if cond != nil {
|
||||
sess = sess.Where(cond)
|
||||
}
|
||||
if err := sess.Limit(batchSize, start).Find(&beans); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(beans) == 0 {
|
||||
return nil
|
||||
}
|
||||
start += len(beans)
|
||||
|
||||
for _, bean := range beans {
|
||||
if err := f(ctx, bean); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
44
models/db/iterate_test.go
Normal file
44
models/db/iterate_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIterate(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
|
||||
|
||||
cnt, err := db.GetEngine(t.Context()).Count(&repo_model.RepoUnit{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var repoUnitCnt int
|
||||
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repo *repo_model.RepoUnit) error {
|
||||
repoUnitCnt++
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, cnt, repoUnitCnt)
|
||||
|
||||
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
|
||||
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has {
|
||||
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
215
models/db/list.go
Normal file
215
models/db/list.go
Normal file
@@ -0,0 +1,215 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"code.gitea.io/gitea/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultMaxInSize represents default variables number on IN () in SQL
|
||||
DefaultMaxInSize = 50
|
||||
defaultFindSliceSize = 10
|
||||
)
|
||||
|
||||
// Paginator is the base for different ListOptions types
|
||||
type Paginator interface {
|
||||
GetSkipTake() (skip, take int)
|
||||
IsListAll() bool
|
||||
}
|
||||
|
||||
// SetSessionPagination sets pagination for a database session
|
||||
func SetSessionPagination(sess Engine, p Paginator) *xorm.Session {
|
||||
skip, take := p.GetSkipTake()
|
||||
|
||||
return sess.Limit(take, skip)
|
||||
}
|
||||
|
||||
// ListOptions options to paginate results
|
||||
type ListOptions struct {
|
||||
PageSize int
|
||||
Page int // start from 1
|
||||
ListAll bool // if true, then PageSize and Page will not be taken
|
||||
}
|
||||
|
||||
var ListOptionsAll = ListOptions{ListAll: true}
|
||||
|
||||
var (
|
||||
_ Paginator = &ListOptions{}
|
||||
_ FindOptions = ListOptions{}
|
||||
)
|
||||
|
||||
// GetSkipTake returns the skip and take values
|
||||
func (opts *ListOptions) GetSkipTake() (skip, take int) {
|
||||
opts.SetDefaultValues()
|
||||
return (opts.Page - 1) * opts.PageSize, opts.PageSize
|
||||
}
|
||||
|
||||
func (opts ListOptions) GetPage() int {
|
||||
return opts.Page
|
||||
}
|
||||
|
||||
func (opts ListOptions) GetPageSize() int {
|
||||
return opts.PageSize
|
||||
}
|
||||
|
||||
// IsListAll indicates PageSize and Page will be ignored
|
||||
func (opts ListOptions) IsListAll() bool {
|
||||
return opts.ListAll
|
||||
}
|
||||
|
||||
// SetDefaultValues sets default values
|
||||
func (opts *ListOptions) SetDefaultValues() {
|
||||
if opts.PageSize <= 0 {
|
||||
opts.PageSize = setting.API.DefaultPagingNum
|
||||
}
|
||||
if opts.PageSize > setting.API.MaxResponseItems {
|
||||
opts.PageSize = setting.API.MaxResponseItems
|
||||
}
|
||||
if opts.Page <= 0 {
|
||||
opts.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
func (opts ListOptions) ToConds() builder.Cond {
|
||||
return builder.NewCond()
|
||||
}
|
||||
|
||||
// AbsoluteListOptions absolute options to paginate results
|
||||
type AbsoluteListOptions struct {
|
||||
skip int
|
||||
take int
|
||||
}
|
||||
|
||||
var _ Paginator = &AbsoluteListOptions{}
|
||||
|
||||
// NewAbsoluteListOptions creates a list option with applied limits
|
||||
func NewAbsoluteListOptions(skip, take int) *AbsoluteListOptions {
|
||||
if skip < 0 {
|
||||
skip = 0
|
||||
}
|
||||
if take <= 0 {
|
||||
take = setting.API.DefaultPagingNum
|
||||
}
|
||||
if take > setting.API.MaxResponseItems {
|
||||
take = setting.API.MaxResponseItems
|
||||
}
|
||||
return &AbsoluteListOptions{skip, take}
|
||||
}
|
||||
|
||||
// IsListAll will always return false
|
||||
func (opts *AbsoluteListOptions) IsListAll() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSkipTake returns the skip and take values
|
||||
func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) {
|
||||
return opts.skip, opts.take
|
||||
}
|
||||
|
||||
// FindOptions represents a find options
|
||||
type FindOptions interface {
|
||||
GetPage() int
|
||||
GetPageSize() int
|
||||
IsListAll() bool
|
||||
ToConds() builder.Cond
|
||||
}
|
||||
|
||||
type JoinFunc func(sess Engine) error
|
||||
|
||||
type FindOptionsJoin interface {
|
||||
ToJoins() []JoinFunc
|
||||
}
|
||||
|
||||
type FindOptionsOrder interface {
|
||||
ToOrders() string
|
||||
}
|
||||
|
||||
// Find represents a common find function which accept an options interface
|
||||
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if orderOpt, ok := opts.(FindOptionsOrder); ok {
|
||||
if order := orderOpt.ToOrders(); order != "" {
|
||||
sess.OrderBy(order)
|
||||
}
|
||||
}
|
||||
|
||||
page, pageSize := opts.GetPage(), opts.GetPageSize()
|
||||
if !opts.IsListAll() && pageSize > 0 {
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
sess.Limit(pageSize, (page-1)*pageSize)
|
||||
}
|
||||
|
||||
findPageSize := defaultFindSliceSize
|
||||
if pageSize > 0 {
|
||||
findPageSize = pageSize
|
||||
}
|
||||
objects := make([]*T, 0, findPageSize)
|
||||
if err := sess.Find(&objects); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Count represents a common count function which accept an options interface
|
||||
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var object T
|
||||
return sess.Count(&object)
|
||||
}
|
||||
|
||||
// FindAndCount represents a common findandcount function which accept an options interface
|
||||
func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
page, pageSize := opts.GetPage(), opts.GetPageSize()
|
||||
if !opts.IsListAll() && pageSize > 0 && page >= 1 {
|
||||
sess.Limit(pageSize, (page-1)*pageSize)
|
||||
}
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if orderOpt, ok := opts.(FindOptionsOrder); ok {
|
||||
if order := orderOpt.ToOrders(); order != "" {
|
||||
sess.OrderBy(order)
|
||||
}
|
||||
}
|
||||
|
||||
findPageSize := defaultFindSliceSize
|
||||
if pageSize > 0 {
|
||||
findPageSize = pageSize
|
||||
}
|
||||
objects := make([]*T, 0, findPageSize)
|
||||
cnt, err := sess.FindAndCount(&objects)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return objects, cnt, nil
|
||||
}
|
||||
52
models/db/list_test.go
Normal file
52
models/db/list_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
repo_model "code.gitea.io/gitea/models/repo"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type mockListOptions struct {
|
||||
db.ListOptions
|
||||
}
|
||||
|
||||
func (opts mockListOptions) IsListAll() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (opts mockListOptions) ToConds() builder.Cond {
|
||||
return builder.NewCond()
|
||||
}
|
||||
|
||||
func TestFind(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
|
||||
|
||||
var repoUnitCount int
|
||||
_, err := db.GetEngine(t.Context()).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, repoUnitCount)
|
||||
|
||||
opts := mockListOptions{}
|
||||
repoUnits, err := db.Find[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, repoUnits, repoUnitCount)
|
||||
|
||||
cnt, err := db.Count[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, repoUnitCount, cnt)
|
||||
|
||||
repoUnits, newCnt, err := db.FindAndCount[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cnt, newCnt)
|
||||
assert.Len(t, repoUnits, repoUnitCount)
|
||||
}
|
||||
107
models/db/log.go
Normal file
107
models/db/log.go
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"code.gitea.io/gitea/modules/log"
|
||||
|
||||
xormlog "xorm.io/xorm/log"
|
||||
)
|
||||
|
||||
// XORMLogBridge a logger bridge from Logger to xorm
|
||||
type XORMLogBridge struct {
|
||||
showSQL atomic.Bool
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewXORMLogger inits a log bridge for xorm
|
||||
func NewXORMLogger(showSQL bool) xormlog.Logger {
|
||||
l := &XORMLogBridge{logger: log.GetLogger("xorm")}
|
||||
l.showSQL.Store(showSQL)
|
||||
return l
|
||||
}
|
||||
|
||||
const stackLevel = 8
|
||||
|
||||
// Log a message with defined skip and at logging level
|
||||
func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...any) {
|
||||
l.logger.Log(skip+1, &log.Event{Level: level}, format, v...)
|
||||
}
|
||||
|
||||
// Debug show debug log
|
||||
func (l *XORMLogBridge) Debug(v ...any) {
|
||||
l.Log(stackLevel, log.DEBUG, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Debugf show debug log
|
||||
func (l *XORMLogBridge) Debugf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.DEBUG, format, v...)
|
||||
}
|
||||
|
||||
// Error show error log
|
||||
func (l *XORMLogBridge) Error(v ...any) {
|
||||
l.Log(stackLevel, log.ERROR, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Errorf show error log
|
||||
func (l *XORMLogBridge) Errorf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.ERROR, format, v...)
|
||||
}
|
||||
|
||||
// Info show information level log
|
||||
func (l *XORMLogBridge) Info(v ...any) {
|
||||
l.Log(stackLevel, log.INFO, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Infof show information level log
|
||||
func (l *XORMLogBridge) Infof(format string, v ...any) {
|
||||
l.Log(stackLevel, log.INFO, format, v...)
|
||||
}
|
||||
|
||||
// Warn show warning log
|
||||
func (l *XORMLogBridge) Warn(v ...any) {
|
||||
l.Log(stackLevel, log.WARN, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Warnf show warning log
|
||||
func (l *XORMLogBridge) Warnf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.WARN, format, v...)
|
||||
}
|
||||
|
||||
// Level get logger level
|
||||
func (l *XORMLogBridge) Level() xormlog.LogLevel {
|
||||
switch l.logger.GetLevel() {
|
||||
case log.TRACE, log.DEBUG:
|
||||
return xormlog.LOG_DEBUG
|
||||
case log.INFO:
|
||||
return xormlog.LOG_INFO
|
||||
case log.WARN:
|
||||
return xormlog.LOG_WARNING
|
||||
case log.ERROR:
|
||||
return xormlog.LOG_ERR
|
||||
case log.NONE:
|
||||
return xormlog.LOG_OFF
|
||||
}
|
||||
return xormlog.LOG_UNKNOWN
|
||||
}
|
||||
|
||||
// SetLevel set the logger level
|
||||
func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) {
|
||||
}
|
||||
|
||||
// ShowSQL set if record SQL
|
||||
func (l *XORMLogBridge) ShowSQL(show ...bool) {
|
||||
if len(show) == 0 {
|
||||
show = []bool{true}
|
||||
}
|
||||
l.showSQL.Store(show[0])
|
||||
}
|
||||
|
||||
// IsShowSQL if record SQL
|
||||
func (l *XORMLogBridge) IsShowSQL() bool {
|
||||
return l.showSQL.Load()
|
||||
}
|
||||
17
models/db/main_test.go
Normal file
17
models/db/main_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
_ "code.gitea.io/gitea/models"
|
||||
_ "code.gitea.io/gitea/models/repo"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
96
models/db/name.go
Normal file
96
models/db/name.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
// ErrNameReserved represents a "reserved name" error.
|
||||
type ErrNameReserved struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrNameReserved checks if an error is a ErrNameReserved.
|
||||
func IsErrNameReserved(err error) bool {
|
||||
_, ok := err.(ErrNameReserved)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNameReserved) Error() string {
|
||||
return fmt.Sprintf("name is reserved [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNameReserved) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// ErrNamePatternNotAllowed represents a "pattern not allowed" error.
|
||||
type ErrNamePatternNotAllowed struct {
|
||||
Pattern string
|
||||
}
|
||||
|
||||
// IsErrNamePatternNotAllowed checks if an error is an ErrNamePatternNotAllowed.
|
||||
func IsErrNamePatternNotAllowed(err error) bool {
|
||||
_, ok := err.(ErrNamePatternNotAllowed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNamePatternNotAllowed) Error() string {
|
||||
return fmt.Sprintf("name pattern is not allowed [pattern: %s]", err.Pattern)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNamePatternNotAllowed) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// ErrNameCharsNotAllowed represents a "character not allowed in name" error.
|
||||
type ErrNameCharsNotAllowed struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrNameCharsNotAllowed checks if an error is an ErrNameCharsNotAllowed.
|
||||
func IsErrNameCharsNotAllowed(err error) bool {
|
||||
_, ok := err.(ErrNameCharsNotAllowed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNameCharsNotAllowed) Error() string {
|
||||
return fmt.Sprintf("name is invalid [%s]: must be valid alpha or numeric or dash(-_) or dot characters", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNameCharsNotAllowed) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// IsUsableName checks if name is reserved or pattern of name is not allowed
|
||||
// based on given reserved names and patterns.
|
||||
// Names are exact match, patterns can be a prefix or suffix match with placeholder '*'.
|
||||
func IsUsableName(reservedNames, reservedPatterns []string, name string) error {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
if utf8.RuneCountInString(name) == 0 {
|
||||
return util.NewInvalidArgumentErrorf("name is empty")
|
||||
}
|
||||
|
||||
if slices.Contains(reservedNames, name) {
|
||||
return ErrNameReserved{name}
|
||||
}
|
||||
|
||||
for _, pat := range reservedPatterns {
|
||||
if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) ||
|
||||
(pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) {
|
||||
return ErrNamePatternNotAllowed{pat}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
14
models/db/paginator/main_test.go
Normal file
14
models/db/paginator/main_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package paginator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
7
models/db/paginator/paginator.go
Normal file
7
models/db/paginator/paginator.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package paginator
|
||||
|
||||
// dummy only. in the future, the models/db/list_options.go should be moved here to decouple from db package
|
||||
// otherwise the unit test will cause cycle import
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user