diff options
Diffstat (limited to 'service/auth.go')
-rw-r--r-- | service/auth.go | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/service/auth.go b/service/auth.go index e517383..909a9a2 100644 --- a/service/auth.go +++ b/service/auth.go @@ -11,7 +11,8 @@ import ( ) var ( - ErrInvalidSession = errors.New("invalid session") + ErrInvalidSession = errors.New("invalid session") + ErrInvalidCSRFToken = errors.New("invalid csrf token") ) type authService struct { @@ -47,6 +48,14 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error return c, nil } +func checkCSRF(ctx context.Context, c *model.Client) (err error) { + csrfToken, ok := ctx.Value("csrf_token").(string) + if !ok || csrfToken != c.Session.CSRFToken { + return ErrInvalidCSRFToken + } + return nil +} + func (s *authService) GetAuthUrl(ctx context.Context, instance string) ( redirectUrl string, sessionID string, err error) { return s.Service.GetAuthUrl(ctx, instance) @@ -184,6 +193,10 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.SaveSettings(ctx, client, c, settings) } @@ -192,6 +205,10 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Like(ctx, client, c, id) } @@ -200,6 +217,10 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnLike(ctx, client, c, id) } @@ -208,6 +229,10 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Retweet(ctx, client, c, id) } @@ -216,6 +241,10 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnRetweet(ctx, client, c, id) } @@ -224,6 +253,10 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) } @@ -232,6 +265,10 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.Follow(ctx, client, c, id) } @@ -240,5 +277,9 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C if err != nil { return } + err = checkCSRF(ctx, c) + if err != nil { + return + } return s.Service.UnFollow(ctx, client, c, id) } |