aboutsummaryrefslogtreecommitdiff
path: root/service/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'service/auth.go')
-rw-r--r--service/auth.go208
1 files changed, 109 insertions, 99 deletions
diff --git a/service/auth.go b/service/auth.go
index 909a9a2..78934fd 100644
--- a/service/auth.go
+++ b/service/auth.go
@@ -3,7 +3,6 @@ package service
import (
"context"
"errors"
- "io"
"mime/multipart"
"bloat/model"
@@ -11,28 +10,28 @@ import (
)
var (
- ErrInvalidSession = errors.New("invalid session")
- ErrInvalidCSRFToken = errors.New("invalid csrf token")
+ errInvalidSession = errors.New("invalid session")
+ errInvalidCSRFToken = errors.New("invalid csrf token")
)
-type authService struct {
- sessionRepo model.SessionRepository
- appRepo model.AppRepository
+type as struct {
+ sessionRepo model.SessionRepo
+ appRepo model.AppRepo
Service
}
-func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppRepository, s Service) Service {
- return &authService{sessionRepo, appRepo, s}
+func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service {
+ return &as{sessionRepo, appRepo, s}
}
-func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) {
+func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) {
sessionID, ok := ctx.Value("session_id").(string)
if !ok || len(sessionID) < 1 {
- return nil, ErrInvalidSession
+ return errInvalidSession
}
session, err := s.sessionRepo.Get(sessionID)
if err != nil {
- return nil, ErrInvalidSession
+ return errInvalidSession
}
client, err := s.appRepo.Get(session.InstanceDomain)
if err != nil {
@@ -44,152 +43,163 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error
ClientSecret: client.ClientSecret,
AccessToken: session.AccessToken,
})
- c = &model.Client{Client: mc, Session: session}
- return c, nil
+ if c == nil {
+ c = &model.Client{}
+ }
+ c.Client = mc
+ c.Session = session
+ return nil
}
func checkCSRF(ctx context.Context, c *model.Client) (err error) {
- csrfToken, ok := ctx.Value("csrf_token").(string)
- if !ok || csrfToken != c.Session.CSRFToken {
- return ErrInvalidCSRFToken
+ token, ok := ctx.Value("csrf_token").(string)
+ if !ok || token != c.Session.CSRFToken {
+ return errInvalidCSRFToken
}
return nil
}
-func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
- redirectUrl string, sessionID string, err error) {
- return s.Service.GetAuthUrl(ctx, instance)
+func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) {
+ s.authenticateClient(ctx, c)
+ s.Service.ServeErrorPage(ctx, c, err)
}
-func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client,
- code string) (token string, err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) {
+ return s.Service.ServeSigninPage(ctx, c)
+}
+
+func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string,
+ maxID string, minID string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
+ return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID)
+}
- token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code)
+func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
+ return s.Service.ServeThreadPage(ctx, c, id, reply)
+}
- c.Session.AccessToken = token
- err = s.sessionRepo.Add(c.Session)
+func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
-
- return
-}
-
-func (s *authService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) {
- c, _ = s.getClient(ctx)
- s.Service.ServeErrorPage(ctx, client, c, err)
-}
-
-func (s *authService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) {
- return s.Service.ServeSigninPage(ctx, client)
+ return s.Service.ServeLikedByPage(ctx, c, id)
}
-func (s *authService) ServeTimelinePage(ctx context.Context, client io.Writer,
- c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID)
+ return s.Service.ServeRetweetedByPage(ctx, c, id)
}
-func (s *authService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeFollowingPage(ctx context.Context, c *model.Client, id string,
+ maxID string, minID string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeThreadPage(ctx, client, c, id, reply)
+ return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID)
}
-func (s *authService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeFollowersPage(ctx context.Context, c *model.Client, id string,
+ maxID string, minID string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID)
+ return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID)
}
-func (s *authService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client,
+ maxID string, minID string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID)
+ return s.Service.ServeNotificationPage(ctx, c, maxID, minID)
}
-func (s *authService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string,
+ maxID string, minID string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeAboutPage(ctx, client, c)
+ return s.Service.ServeUserPage(ctx, c, id, maxID, minID)
}
-func (s *authService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeEmojiPage(ctx, client, c)
+ return s.Service.ServeAboutPage(ctx, c)
}
-func (s *authService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeLikedByPage(ctx, client, c, id)
+ return s.Service.ServeEmojiPage(ctx, c)
}
-func (s *authService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string,
+ qType string, offset int) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeRetweetedByPage(ctx, client, c, id)
+ return s.Service.ServeSearchPage(ctx, c, q, qType, offset)
}
-func (s *authService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID)
+ return s.Service.ServeSettingsPage(ctx, c)
}
-func (s *authService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string,
+ sessionID string, err error) {
+ return s.Service.NewSession(ctx, instance)
+}
+
+func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string,
+ code string) (token string, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
- return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID)
-}
-func (s *authService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) {
- c, err = s.getClient(ctx)
+ token, err = s.Service.Signin(ctx, c, c.Session.ID, code)
if err != nil {
return
}
- return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset)
-}
-func (s *authService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
- c, err = s.getClient(ctx)
+ c.Session.AccessToken = token
+ err = s.sessionRepo.Add(c.Session)
if err != nil {
return
}
- return s.Service.ServeSettingsPage(ctx, client, c)
+
+ return
}
-func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) Post(ctx context.Context, c *model.Client, content string,
+ replyToID string, format string, visibility string, isNSFW bool,
+ files []*multipart.FileHeader) (id string, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -197,11 +207,11 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod
if err != nil {
return
}
- return s.Service.SaveSettings(ctx, client, c, settings)
+ return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files)
}
-func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
- c, err = s.getClient(ctx)
+func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -209,11 +219,11 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien
if err != nil {
return
}
- return s.Service.Like(ctx, client, c, id)
+ return s.Service.Like(ctx, c, id)
}
-func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
- c, err = s.getClient(ctx)
+func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -221,11 +231,11 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli
if err != nil {
return
}
- return s.Service.UnLike(ctx, client, c, id)
+ return s.Service.UnLike(ctx, c, id)
}
-func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
- c, err = s.getClient(ctx)
+func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -233,11 +243,11 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl
if err != nil {
return
}
- return s.Service.Retweet(ctx, client, c, id)
+ return s.Service.Retweet(ctx, c, id)
}
-func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
- c, err = s.getClient(ctx)
+func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -245,11 +255,11 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.
if err != nil {
return
}
- return s.Service.UnRetweet(ctx, client, c, id)
+ return s.Service.UnRetweet(ctx, c, id)
}
-func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) {
- c, err = s.getClient(ctx)
+func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -257,11 +267,11 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.
if err != nil {
return
}
- return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files)
+ return s.Service.Follow(ctx, c, id)
}
-func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -269,11 +279,11 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli
if err != nil {
return
}
- return s.Service.Follow(ctx, client, c, id)
+ return s.Service.UnFollow(ctx, c, id)
}
-func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
- c, err = s.getClient(ctx)
+func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) {
+ err = s.authenticateClient(ctx, c)
if err != nil {
return
}
@@ -281,5 +291,5 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C
if err != nil {
return
}
- return s.Service.UnFollow(ctx, client, c, id)
+ return s.Service.SaveSettings(ctx, c, settings)
}