aboutsummaryrefslogtreecommitdiff
path: root/service/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'service/auth.go')
-rw-r--r--service/auth.go43
1 files changed, 42 insertions, 1 deletions
diff --git a/service/auth.go b/service/auth.go
index e517383..909a9a2 100644
--- a/service/auth.go
+++ b/service/auth.go
@@ -11,7 +11,8 @@ import (
)
var (
- ErrInvalidSession = errors.New("invalid session")
+ ErrInvalidSession = errors.New("invalid session")
+ ErrInvalidCSRFToken = errors.New("invalid csrf token")
)
type authService struct {
@@ -47,6 +48,14 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error
return c, nil
}
+func checkCSRF(ctx context.Context, c *model.Client) (err error) {
+ csrfToken, ok := ctx.Value("csrf_token").(string)
+ if !ok || csrfToken != c.Session.CSRFToken {
+ return ErrInvalidCSRFToken
+ }
+ return nil
+}
+
func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
redirectUrl string, sessionID string, err error) {
return s.Service.GetAuthUrl(ctx, instance)
@@ -184,6 +193,10 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.SaveSettings(ctx, client, c, settings)
}
@@ -192,6 +205,10 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.Like(ctx, client, c, id)
}
@@ -200,6 +217,10 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.UnLike(ctx, client, c, id)
}
@@ -208,6 +229,10 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.Retweet(ctx, client, c, id)
}
@@ -216,6 +241,10 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.UnRetweet(ctx, client, c, id)
}
@@ -224,6 +253,10 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files)
}
@@ -232,6 +265,10 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.Follow(ctx, client, c, id)
}
@@ -240,5 +277,9 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C
if err != nil {
return
}
+ err = checkCSRF(ctx, c)
+ if err != nil {
+ return
+ }
return s.Service.UnFollow(ctx, client, c, id)
}