aboutsummaryrefslogtreecommitdiff
path: root/service/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'service/auth.go')
-rw-r--r--service/auth.go25
1 files changed, 7 insertions, 18 deletions
diff --git a/service/auth.go b/service/auth.go
index 0209273..13e9c50 100644
--- a/service/auth.go
+++ b/service/auth.go
@@ -23,17 +23,9 @@ func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppReposi
return &authService{sessionRepo, appRepo, s}
}
-func getSessionID(ctx context.Context) (sessionID string, err error) {
+func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) {
sessionID, ok := ctx.Value("session_id").(string)
if !ok || len(sessionID) < 1 {
- return "", ErrInvalidSession
- }
- return sessionID, nil
-}
-
-func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) {
- sessionID, err := getSessionID(ctx)
- if err != nil {
return nil, ErrInvalidSession
}
session, err := s.sessionRepo.Get(sessionID)
@@ -50,7 +42,7 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error
ClientSecret: client.ClientSecret,
AccessToken: session.AccessToken,
})
- c = &model.Client{Client: mc}
+ c = &model.Client{Client: mc, Session: session}
return c, nil
}
@@ -61,21 +53,18 @@ func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client,
code string) (token string, err error) {
- sessionID, err = getSessionID(ctx)
- if err != nil {
- return
- }
c, err = s.getClient(ctx)
if err != nil {
return
}
- token, err = s.Service.GetUserToken(ctx, sessionID, c, code)
+ token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code)
if err != nil {
return
}
- err = s.sessionRepo.Update(sessionID, token)
+ c.Session.AccessToken = token
+ err = s.sessionRepo.Add(c.Session)
if err != nil {
return
}
@@ -168,12 +157,12 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.
return s.Service.UnRetweet(ctx, client, c, id)
}
-func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, files []*multipart.FileHeader) (id string, err error) {
+func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, visibility string, files []*multipart.FileHeader) (id string, err error) {
c, err = s.getClient(ctx)
if err != nil {
return
}
- return s.Service.PostTweet(ctx, client, c, content, replyToID, files)
+ return s.Service.PostTweet(ctx, client, c, content, replyToID, visibility, files)
}
func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {