|
@@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
|
|
t.Parallel()
|
|
t.Parallel()
|
|
|
|
|
|
ctx := context.Background()
|
|
ctx := context.Background()
|
|
- db := &loginSourcesStore{
|
|
|
|
- DB: newTestDB(t, "loginSourcesStore"),
|
|
|
|
|
|
+ s := &LoginSourcesStore{
|
|
|
|
+ db: newTestDB(t, "LoginSourcesStore"),
|
|
}
|
|
}
|
|
|
|
|
|
for _, tc := range []struct {
|
|
for _, tc := range []struct {
|
|
name string
|
|
name string
|
|
- test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
|
|
|
|
+ test func(t *testing.T, ctx context.Context, s *LoginSourcesStore)
|
|
}{
|
|
}{
|
|
{"Create", loginSourcesCreate},
|
|
{"Create", loginSourcesCreate},
|
|
{"Count", loginSourcesCount},
|
|
{"Count", loginSourcesCount},
|
|
@@ -181,10 +181,10 @@ func TestLoginSources(t *testing.T) {
|
|
} {
|
|
} {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Cleanup(func() {
|
|
t.Cleanup(func() {
|
|
- err := clearTables(t, db.DB)
|
|
|
|
|
|
+ err := clearTables(t, s.db)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
})
|
|
})
|
|
- tc.test(t, ctx, db)
|
|
|
|
|
|
+ tc.test(t, ctx, s)
|
|
})
|
|
})
|
|
if t.Failed() {
|
|
if t.Failed() {
|
|
break
|
|
break
|
|
@@ -192,9 +192,9 @@ func TestLoginSources(t *testing.T) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesCreate(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
// Create first login source with name "GitHub"
|
|
// Create first login source with name "GitHub"
|
|
- source, err := db.Create(ctx,
|
|
|
|
|
|
+ source, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -208,20 +208,28 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Get it back and check the Created field
|
|
// Get it back and check the Created field
|
|
- source, err = db.GetByID(ctx, source.ID)
|
|
|
|
|
|
+ source, err = s.GetByID(ctx, source.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
- assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
|
|
|
- assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
|
|
|
|
|
+ assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
|
|
|
+ assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
|
|
|
|
|
- // Try create second login source with same name should fail
|
|
|
|
- _, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
|
|
|
|
|
|
+ // Try to create second login source with same name should fail.
|
|
|
|
+ _, err = s.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
|
|
wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
|
wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
|
assert.Equal(t, wantErr, err)
|
|
assert.Equal(t, wantErr, err)
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func setMockLoginSourceFilesStore(t *testing.T, s *LoginSourcesStore, mock loginSourceFilesStore) {
|
|
|
|
+ before := s.files
|
|
|
|
+ s.files = mock
|
|
|
|
+ t.Cleanup(func() {
|
|
|
|
+ s.files = before
|
|
|
|
+ })
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func loginSourcesCount(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
// Create two login sources, one in database and one as source file.
|
|
// Create two login sources, one in database and one as source file.
|
|
- _, err := db.Create(ctx,
|
|
|
|
|
|
+ _, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -236,14 +244,14 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
|
|
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock.LenFunc.SetDefaultReturn(2)
|
|
mock.LenFunc.SetDefaultReturn(2)
|
|
- setMockLoginSourceFilesStore(t, db, mock)
|
|
|
|
|
|
+ setMockLoginSourceFilesStore(t, s, mock)
|
|
|
|
|
|
- assert.Equal(t, int64(3), db.Count(ctx))
|
|
|
|
|
|
+ assert.Equal(t, int64(3), s.Count(ctx))
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesDeleteByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
t.Run("delete but in used", func(t *testing.T) {
|
|
t.Run("delete but in used", func(t *testing.T) {
|
|
- source, err := db.Create(ctx,
|
|
|
|
|
|
+ source, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -257,7 +265,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Create a user that uses this login source
|
|
// Create a user that uses this login source
|
|
- _, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
|
|
|
|
|
|
+ _, err = NewUsersStore(s.db).Create(ctx, "alice", "",
|
|
CreateUserOptions{
|
|
CreateUserOptions{
|
|
LoginSource: source.ID,
|
|
LoginSource: source.ID,
|
|
},
|
|
},
|
|
@@ -265,7 +273,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Delete the login source will result in error
|
|
// Delete the login source will result in error
|
|
- err = db.DeleteByID(ctx, source.ID)
|
|
|
|
|
|
+ err = s.DeleteByID(ctx, source.ID)
|
|
wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
|
|
wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
|
|
assert.Equal(t, wantErr, err)
|
|
assert.Equal(t, wantErr, err)
|
|
})
|
|
})
|
|
@@ -274,10 +282,10 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
|
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
|
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
|
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
|
})
|
|
})
|
|
- setMockLoginSourceFilesStore(t, db, mock)
|
|
|
|
|
|
+ setMockLoginSourceFilesStore(t, s, mock)
|
|
|
|
|
|
// Create a login source with name "GitHub2"
|
|
// Create a login source with name "GitHub2"
|
|
- source, err := db.Create(ctx,
|
|
|
|
|
|
+ source, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub2",
|
|
Name: "GitHub2",
|
|
@@ -291,24 +299,24 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Delete a non-existent ID is noop
|
|
// Delete a non-existent ID is noop
|
|
- err = db.DeleteByID(ctx, 9999)
|
|
|
|
|
|
+ err = s.DeleteByID(ctx, 9999)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// We should be able to get it back
|
|
// We should be able to get it back
|
|
- _, err = db.GetByID(ctx, source.ID)
|
|
|
|
|
|
+ _, err = s.GetByID(ctx, source.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Now delete this login source with ID
|
|
// Now delete this login source with ID
|
|
- err = db.DeleteByID(ctx, source.ID)
|
|
|
|
|
|
+ err = s.DeleteByID(ctx, source.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// We should get token not found error
|
|
// We should get token not found error
|
|
- _, err = db.GetByID(ctx, source.ID)
|
|
|
|
|
|
+ _, err = s.GetByID(ctx, source.ID)
|
|
wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
|
|
wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
|
|
assert.Equal(t, wantErr, err)
|
|
assert.Equal(t, wantErr, err)
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesGetByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
|
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
|
if id != 101 {
|
|
if id != 101 {
|
|
@@ -316,14 +324,14 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
|
|
}
|
|
}
|
|
return &LoginSource{ID: id}, nil
|
|
return &LoginSource{ID: id}, nil
|
|
})
|
|
})
|
|
- setMockLoginSourceFilesStore(t, db, mock)
|
|
|
|
|
|
+ setMockLoginSourceFilesStore(t, s, mock)
|
|
|
|
|
|
expConfig := &github.Config{
|
|
expConfig := &github.Config{
|
|
APIEndpoint: "https://api.github.com",
|
|
APIEndpoint: "https://api.github.com",
|
|
}
|
|
}
|
|
|
|
|
|
// Create a login source with name "GitHub"
|
|
// Create a login source with name "GitHub"
|
|
- source, err := db.Create(ctx,
|
|
|
|
|
|
+ source, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -335,16 +343,16 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Get the one in the database and test the read/write hooks
|
|
// Get the one in the database and test the read/write hooks
|
|
- source, err = db.GetByID(ctx, source.ID)
|
|
|
|
|
|
+ source, err = s.GetByID(ctx, source.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, expConfig, source.Provider.Config())
|
|
assert.Equal(t, expConfig, source.Provider.Config())
|
|
|
|
|
|
// Get the one in source file store
|
|
// Get the one in source file store
|
|
- _, err = db.GetByID(ctx, 101)
|
|
|
|
|
|
+ _, err = s.GetByID(ctx, 101)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesList(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
|
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
|
if opts.OnlyActivated {
|
|
if opts.OnlyActivated {
|
|
@@ -357,10 +365,10 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
{ID: 2},
|
|
{ID: 2},
|
|
}
|
|
}
|
|
})
|
|
})
|
|
- setMockLoginSourceFilesStore(t, db, mock)
|
|
|
|
|
|
+ setMockLoginSourceFilesStore(t, s, mock)
|
|
|
|
|
|
// Create two login sources in database, one activated and the other one not
|
|
// Create two login sources in database, one activated and the other one not
|
|
- _, err := db.Create(ctx,
|
|
|
|
|
|
+ _, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.PAM,
|
|
Type: auth.PAM,
|
|
Name: "PAM",
|
|
Name: "PAM",
|
|
@@ -370,7 +378,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
},
|
|
},
|
|
)
|
|
)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
- _, err = db.Create(ctx,
|
|
|
|
|
|
+ _, err = s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -383,17 +391,17 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// List all login sources
|
|
// List all login sources
|
|
- sources, err := db.List(ctx, ListLoginSourceOptions{})
|
|
|
|
|
|
+ sources, err := s.List(ctx, ListLoginSourceOptions{})
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 4, len(sources), "number of sources")
|
|
assert.Equal(t, 4, len(sources), "number of sources")
|
|
|
|
|
|
// Only list activated login sources
|
|
// Only list activated login sources
|
|
- sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
|
|
|
|
|
|
+ sources, err = s.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 2, len(sources), "number of sources")
|
|
assert.Equal(t, 2, len(sources), "number of sources")
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock := NewMockLoginSourceFilesStore()
|
|
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
|
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
|
|
mockFile := NewMockLoginSourceFileStore()
|
|
mockFile := NewMockLoginSourceFileStore()
|
|
@@ -407,10 +415,10 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|
},
|
|
},
|
|
}
|
|
}
|
|
})
|
|
})
|
|
- setMockLoginSourceFilesStore(t, db, mock)
|
|
|
|
|
|
+ setMockLoginSourceFilesStore(t, s, mock)
|
|
|
|
|
|
// Create two login sources both have default on
|
|
// Create two login sources both have default on
|
|
- source1, err := db.Create(ctx,
|
|
|
|
|
|
+ source1, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.PAM,
|
|
Type: auth.PAM,
|
|
Name: "PAM",
|
|
Name: "PAM",
|
|
@@ -421,7 +429,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|
},
|
|
},
|
|
)
|
|
)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
- source2, err := db.Create(ctx,
|
|
|
|
|
|
+ source2, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -435,23 +443,23 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Set source 1 as default
|
|
// Set source 1 as default
|
|
- err = db.ResetNonDefault(ctx, source1)
|
|
|
|
|
|
+ err = s.ResetNonDefault(ctx, source1)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
// Verify the default state
|
|
// Verify the default state
|
|
- source1, err = db.GetByID(ctx, source1.ID)
|
|
|
|
|
|
+ source1, err = s.GetByID(ctx, source1.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.True(t, source1.IsDefault)
|
|
assert.True(t, source1.IsDefault)
|
|
|
|
|
|
- source2, err = db.GetByID(ctx, source2.ID)
|
|
|
|
|
|
+ source2, err = s.GetByID(ctx, source2.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.False(t, source2.IsDefault)
|
|
assert.False(t, source2.IsDefault)
|
|
}
|
|
}
|
|
|
|
|
|
-func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
|
|
|
|
|
|
+func loginSourcesSave(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
|
|
t.Run("save to database", func(t *testing.T) {
|
|
t.Run("save to database", func(t *testing.T) {
|
|
// Create a login source with name "GitHub"
|
|
// Create a login source with name "GitHub"
|
|
- source, err := db.Create(ctx,
|
|
|
|
|
|
+ source, err := s.Create(ctx,
|
|
CreateLoginSourceOptions{
|
|
CreateLoginSourceOptions{
|
|
Type: auth.GitHub,
|
|
Type: auth.GitHub,
|
|
Name: "GitHub",
|
|
Name: "GitHub",
|
|
@@ -468,10 +476,10 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
source.Provider = github.NewProvider(&github.Config{
|
|
source.Provider = github.NewProvider(&github.Config{
|
|
APIEndpoint: "https://api2.github.com",
|
|
APIEndpoint: "https://api2.github.com",
|
|
})
|
|
})
|
|
- err = db.Save(ctx, source)
|
|
|
|
|
|
+ err = s.Save(ctx, source)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
|
|
|
|
- source, err = db.GetByID(ctx, source.ID)
|
|
|
|
|
|
+ source, err = s.GetByID(ctx, source.ID)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
assert.False(t, source.IsActived)
|
|
assert.False(t, source.IsActived)
|
|
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
|
|
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
|
|
@@ -485,7 +493,7 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
|
|
}),
|
|
}),
|
|
File: mockFile,
|
|
File: mockFile,
|
|
}
|
|
}
|
|
- err := db.Save(ctx, source)
|
|
|
|
|
|
+ err := s.Save(ctx, source)
|
|
require.NoError(t, err)
|
|
require.NoError(t, err)
|
|
mockrequire.Called(t, mockFile.SaveFunc)
|
|
mockrequire.Called(t, mockFile.SaveFunc)
|
|
})
|
|
})
|