From 2af37d47783aac8c650ffd1578e2297b5784c73d Mon Sep 17 00:00:00 2001 From: r Date: Tue, 28 Jan 2020 17:51:00 +0000 Subject: Refactor everything --- service/auth.go | 208 ++++++------- service/logging.go | 190 ++++++------ service/service.go | 811 +++++++++++++++++++++++---------------------------- service/transport.go | 650 ++++++++++++++++++++++------------------- 4 files changed, 916 insertions(+), 943 deletions(-) (limited to 'service') diff --git a/service/auth.go b/service/auth.go index 909a9a2..78934fd 100644 --- a/service/auth.go +++ b/service/auth.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "io" "mime/multipart" "bloat/model" @@ -11,28 +10,28 @@ import ( ) var ( - ErrInvalidSession = errors.New("invalid session") - ErrInvalidCSRFToken = errors.New("invalid csrf token") + errInvalidSession = errors.New("invalid session") + errInvalidCSRFToken = errors.New("invalid csrf token") ) -type authService struct { - sessionRepo model.SessionRepository - appRepo model.AppRepository +type as struct { + sessionRepo model.SessionRepo + appRepo model.AppRepo Service } -func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppRepository, s Service) Service { - return &authService{sessionRepo, appRepo, s} +func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service { + return &as{sessionRepo, appRepo, s} } -func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) { +func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) { sessionID, ok := ctx.Value("session_id").(string) if !ok || len(sessionID) < 1 { - return nil, ErrInvalidSession + return errInvalidSession } session, err := s.sessionRepo.Get(sessionID) if err != nil { - return nil, ErrInvalidSession + return errInvalidSession } client, err := s.appRepo.Get(session.InstanceDomain) if err != nil { @@ -44,152 +43,163 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error ClientSecret: client.ClientSecret, AccessToken: session.AccessToken, }) - c = &model.Client{Client: mc, Session: session} - return c, nil + if c == nil { + c = &model.Client{} + } + c.Client = mc + c.Session = session + return nil } func checkCSRF(ctx context.Context, c *model.Client) (err error) { - csrfToken, ok := ctx.Value("csrf_token").(string) - if !ok || csrfToken != c.Session.CSRFToken { - return ErrInvalidCSRFToken + token, ok := ctx.Value("csrf_token").(string) + if !ok || token != c.Session.CSRFToken { + return errInvalidCSRFToken } return nil } -func (s *authService) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { - return s.Service.GetAuthUrl(ctx, instance) +func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) { + s.authenticateClient(ctx, c) + s.Service.ServeErrorPage(ctx, c, err) } -func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, - code string) (token string, err error) { - c, err = s.getClient(ctx) +func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) { + return s.Service.ServeSigninPage(ctx, c) +} + +func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } + return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID) +} - token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code) +func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } + return s.Service.ServeThreadPage(ctx, c, id, reply) +} - c.Session.AccessToken = token - err = s.sessionRepo.Add(c.Session) +func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - - return -} - -func (s *authService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { - c, _ = s.getClient(ctx) - s.Service.ServeErrorPage(ctx, client, c, err) -} - -func (s *authService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { - return s.Service.ServeSigninPage(ctx, client) + return s.Service.ServeLikedByPage(ctx, c, id) } -func (s *authService) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) + return s.Service.ServeRetweetedByPage(ctx, c, id) } -func (s *authService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeFollowingPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeThreadPage(ctx, client, c, id, reply) + return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID) } -func (s *authService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeFollowersPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) + return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID) } -func (s *authService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeNotificationPage(ctx, c, maxID, minID) } -func (s *authService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeAboutPage(ctx, client, c) + return s.Service.ServeUserPage(ctx, c, id, maxID, minID) } -func (s *authService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeEmojiPage(ctx, client, c) + return s.Service.ServeAboutPage(ctx, c) } -func (s *authService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeLikedByPage(ctx, client, c, id) + return s.Service.ServeEmojiPage(ctx, c) } -func (s *authService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string, + qType string, offset int) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeRetweetedByPage(ctx, client, c, id) + return s.Service.ServeSearchPage(ctx, c, q, qType, offset) } -func (s *authService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) +func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeSettingsPage(ctx, c) } -func (s *authService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - c, err = s.getClient(ctx) +func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string, + sessionID string, err error) { + return s.Service.NewSession(ctx, instance) +} + +func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string, + code string) (token string, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } - return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) -} -func (s *authService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { - c, err = s.getClient(ctx) + token, err = s.Service.Signin(ctx, c, c.Session.ID, code) if err != nil { return } - return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) -} -func (s *authService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - c, err = s.getClient(ctx) + c.Session.AccessToken = token + err = s.sessionRepo.Add(c.Session) if err != nil { return } - return s.Service.ServeSettingsPage(ctx, client, c) + + return } -func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { - c, err = s.getClient(ctx) +func (s *as) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -197,11 +207,11 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod if err != nil { return } - return s.Service.SaveSettings(ctx, client, c, settings) + return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files) } -func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -209,11 +219,11 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien if err != nil { return } - return s.Service.Like(ctx, client, c, id) + return s.Service.Like(ctx, c, id) } -func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -221,11 +231,11 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } - return s.Service.UnLike(ctx, client, c, id) + return s.Service.UnLike(ctx, c, id) } -func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -233,11 +243,11 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl if err != nil { return } - return s.Service.Retweet(ctx, client, c, id) + return s.Service.Retweet(ctx, c, id) } -func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - c, err = s.getClient(ctx) +func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -245,11 +255,11 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } - return s.Service.UnRetweet(ctx, client, c, id) + return s.Service.UnRetweet(ctx, c, id) } -func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { - c, err = s.getClient(ctx) +func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -257,11 +267,11 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model. if err != nil { return } - return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) + return s.Service.Follow(ctx, c, id) } -func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -269,11 +279,11 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli if err != nil { return } - return s.Service.Follow(ctx, client, c, id) + return s.Service.UnFollow(ctx, c, id) } -func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - c, err = s.getClient(ctx) +func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) { + err = s.authenticateClient(ctx, c) if err != nil { return } @@ -281,5 +291,5 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C if err != nil { return } - return s.Service.UnFollow(ctx, client, c, id) + return s.Service.SaveSettings(ctx, c, settings) } diff --git a/service/logging.go b/service/logging.go index cafd815..e4f8985 100644 --- a/service/logging.go +++ b/service/logging.go @@ -2,7 +2,6 @@ package service import ( "context" - "io" "log" "mime/multipart" "time" @@ -10,206 +9,215 @@ import ( "bloat/model" ) -type loggingService struct { +type ls struct { logger *log.Logger Service } func NewLoggingService(logger *log.Logger, s Service) Service { - return &loggingService{logger, s} + return &ls{logger, s} } -func (s *loggingService) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { +func (s *ls) ServeErrorPage(ctx context.Context, c *model.Client, err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", - "GetAuthUrl", instance, time.Since(begin), err) + s.logger.Printf("method=%v, err=%v, took=%v\n", + "ServeErrorPage", err, time.Since(begin)) }(time.Now()) - return s.Service.GetAuthUrl(ctx, instance) + s.Service.ServeErrorPage(ctx, c, err) } -func (s *loggingService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, - code string) (token string, err error) { +func (s *ls) ServeSigninPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, session_id=%v, code=%v, took=%v, err=%v\n", - "GetUserToken", sessionID, code, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeSigninPage", time.Since(begin), err) }(time.Now()) - return s.Service.GetUserToken(ctx, sessionID, c, code) + return s.Service.ServeSigninPage(ctx, c) } -func (s *loggingService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (s *ls) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, err=%v, took=%v\n", - "ServeErrorPage", err, time.Since(begin)) + s.logger.Printf("method=%v, type=%v, took=%v, err=%v\n", + "ServeTimelinePage", tType, time.Since(begin), err) }(time.Now()) - s.Service.ServeErrorPage(ctx, client, c, err) + return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID) } -func (s *loggingService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { +func (s *ls) ServeThreadPage(ctx context.Context, c *model.Client, id string, + reply bool) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeSigninPage", time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeThreadPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeSigninPage(ctx, client) + return s.Service.ServeThreadPage(ctx, c, id, reply) } -func (s *loggingService) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { +func (s *ls) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, timeline_type=%v, max_id=%v, since_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeTimelinePage", timelineType, maxID, sinceID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeLikedByPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) + return s.Service.ServeLikedByPage(ctx, c, id) } -func (s *loggingService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { +func (s *ls) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, reply=%v, took=%v, err=%v\n", - "ServeThreadPage", id, reply, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeRetweetedByPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeThreadPage(ctx, client, c, id, reply) + return s.Service.ServeRetweetedByPage(ctx, c, id) } -func (s *loggingService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { +func (s *ls) ServeFollowingPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeNotificationPage", maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeFollowingPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) + return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID) } -func (s *loggingService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeFollowersPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeUserPage", id, maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeFollowersPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID) } -func (s *loggingService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeAboutPage", time.Since(begin), err) + "ServeNotificationPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeAboutPage(ctx, client, c) + return s.Service.ServeNotificationPage(ctx, c, maxID, minID) } -func (s *loggingService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) ServeUserPage(ctx context.Context, c *model.Client, id string, + maxID string, minID string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeEmojiPage", time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "ServeUserPage", id, time.Since(begin), err) }(time.Now()) - return s.Service.ServeEmojiPage(ctx, client, c) + return s.Service.ServeUserPage(ctx, c, id, maxID, minID) } -func (s *loggingService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "ServeLikedByPage", id, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeAboutPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeLikedByPage(ctx, client, c, id) + return s.Service.ServeAboutPage(ctx, c) } -func (s *loggingService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "ServeRetweetedByPage", id, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeEmojiPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeRetweetedByPage(ctx, client, c, id) + return s.Service.ServeEmojiPage(ctx, c) } -func (s *loggingService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeSearchPage(ctx context.Context, c *model.Client, q string, + qType string, offset int) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeFollowingPage", id, maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeSearchPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeSearchPage(ctx, c, q, qType, offset) } -func (s *loggingService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", - "ServeFollowersPage", id, maxID, minID, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "ServeSettingsPage", time.Since(begin), err) }(time.Now()) - return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) + return s.Service.ServeSettingsPage(ctx, c) } -func (s *loggingService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { +func (s *ls) NewSession(ctx context.Context, instance string) (redirectUrl string, + sessionID string, err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, q=%v, type=%v, offset=%v, took=%v, err=%v\n", - "ServeSearchPage", q, qType, offset, time.Since(begin), err) + s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", + "NewSession", instance, time.Since(begin), err) }(time.Now()) - return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) + return s.Service.NewSession(ctx, instance) } -func (s *loggingService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) Signin(ctx context.Context, c *model.Client, sessionID string, + code string) (token string, err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeSettingsPage", time.Since(begin), err) + s.logger.Printf("method=%v, session_id=%v, took=%v, err=%v\n", + "Signin", sessionID, time.Since(begin), err) }(time.Now()) - return s.Service.ServeSettingsPage(ctx, client, c) + return s.Service.Signin(ctx, c, sessionID, code) } -func (s *loggingService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { +func (s *ls) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, took=%v, err=%v\n", - "SaveSettings", time.Since(begin), err) + "Post", time.Since(begin), err) }(time.Now()) - return s.Service.SaveSettings(ctx, client, c, settings) + return s.Service.Post(ctx, c, content, replyToID, format, + visibility, isNSFW, files) } -func (s *loggingService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "Like", id, time.Since(begin), err) }(time.Now()) - return s.Service.Like(ctx, client, c, id) + return s.Service.Like(ctx, c, id) } -func (s *loggingService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "UnLike", id, time.Since(begin), err) }(time.Now()) - return s.Service.UnLike(ctx, client, c, id) + return s.Service.UnLike(ctx, c, id) } -func (s *loggingService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "Retweet", id, time.Since(begin), err) }(time.Now()) - return s.Service.Retweet(ctx, client, c, id) + return s.Service.Retweet(ctx, c, id) } -func (s *loggingService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", "UnRetweet", id, time.Since(begin), err) }(time.Now()) - return s.Service.UnRetweet(ctx, client, c, id) + return s.Service.UnRetweet(ctx, c, id) } -func (s *loggingService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { +func (s *ls) Follow(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, content=%v, reply_to_id=%v, format=%v, visibility=%v, is_nsfw=%v, took=%v, err=%v\n", - "PostTweet", content, replyToID, format, visibility, isNSFW, time.Since(begin), err) + s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", + "Follow", id, time.Since(begin), err) }(time.Now()) - return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) + return s.Service.Follow(ctx, c, id) } -func (s *loggingService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { defer func(begin time.Time) { s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Follow", id, time.Since(begin), err) + "UnFollow", id, time.Since(begin), err) }(time.Now()) - return s.Service.Follow(ctx, client, c, id) + return s.Service.UnFollow(ctx, c, id) } -func (s *loggingService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) { defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnFollow", id, time.Since(begin), err) + s.logger.Printf("method=%v, took=%v, err=%v\n", + "SaveSettings", time.Since(begin), err) }(time.Now()) - return s.Service.UnFollow(ctx, client, c, id) + return s.Service.SaveSettings(ctx, c, settings) } diff --git a/service/service.go b/service/service.go index c9fccb4..7ad860f 100644 --- a/service/service.go +++ b/service/service.go @@ -1,14 +1,10 @@ package service import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" "mime/multipart" - "net/http" "net/url" "strings" @@ -19,37 +15,35 @@ import ( ) var ( - ErrInvalidArgument = errors.New("invalid argument") - ErrInvalidToken = errors.New("invalid token") - ErrInvalidClient = errors.New("invalid client") - ErrInvalidTimeline = errors.New("invalid timeline") + errInvalidArgument = errors.New("invalid argument") ) type Service interface { - GetAuthUrl(ctx context.Context, instance string) (url string, sessionID string, err error) - GetUserToken(ctx context.Context, sessionID string, c *model.Client, token string) (accessToken string, err error) - ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) - ServeSigninPage(ctx context.Context, client io.Writer) (err error) - ServeTimelinePage(ctx context.Context, client io.Writer, c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) - ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) - ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) - ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) - ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) - ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) - SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) - Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) - PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) - Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) - UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) + ServeErrorPage(ctx context.Context, c *model.Client, err error) + ServeSigninPage(ctx context.Context, c *model.Client) (err error) + ServeTimelinePage(ctx context.Context, c *model.Client, tType string, maxID string, minID string) (err error) + ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) + ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) + ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) + ServeFollowingPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeFollowersPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeNotificationPage(ctx context.Context, c *model.Client, maxID string, minID string) (err error) + ServeUserPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) + ServeAboutPage(ctx context.Context, c *model.Client) (err error) + ServeEmojiPage(ctx context.Context, c *model.Client) (err error) + ServeSearchPage(ctx context.Context, c *model.Client, q string, qType string, offset int) (err error) + ServeSettingsPage(ctx context.Context, c *model.Client) (err error) + NewSession(ctx context.Context, instance string) (redirectUrl string, sessionID string, err error) + Signin(ctx context.Context, c *model.Client, sessionID string, code string) (token string, err error) + Post(ctx context.Context, c *model.Client, content string, replyToID string, format string, + visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) + Like(ctx context.Context, c *model.Client, id string) (count int64, err error) + UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) + Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) + UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) + Follow(ctx context.Context, c *model.Client, id string) (err error) + UnFollow(ctx context.Context, c *model.Client, id string) (err error) + SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) } type service struct { @@ -59,13 +53,19 @@ type service struct { customCSS string postFormats []model.PostFormat renderer renderer.Renderer - sessionRepo model.SessionRepository - appRepo model.AppRepository + sessionRepo model.SessionRepo + appRepo model.AppRepo } -func NewService(clientName string, clientScope string, clientWebsite string, - customCSS string, postFormats []model.PostFormat, renderer renderer.Renderer, - sessionRepo model.SessionRepository, appRepo model.AppRepository) Service { +func NewService(clientName string, + clientScope string, + clientWebsite string, + customCSS string, + postFormats []model.PostFormat, + renderer renderer.Renderer, + sessionRepo model.SessionRepo, + appRepo model.AppRepo, +) Service { return &service{ clientName: clientName, clientScope: clientScope, @@ -96,137 +96,75 @@ func getRendererContext(c *model.Client) *renderer.Context { } } -func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( - redirectUrl string, sessionID string, err error) { - var instanceURL string - if strings.HasPrefix(instance, "https://") { - instanceURL = instance - instance = strings.TrimPrefix(instance, "https://") - } else { - instanceURL = "https://" + instance - } - - sessionID, err = util.NewSessionId() - if err != nil { - return - } - csrfToken, err := util.NewCSRFToken() - if err != nil { - return - } - session := model.Session{ - ID: sessionID, - InstanceDomain: instance, - CSRFToken: csrfToken, - Settings: *model.NewSettings(), - } - err = svc.sessionRepo.Add(session) - if err != nil { +func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, + val string, number int) { + if key == nil { return } - app, err := svc.appRepo.Get(instance) - if err != nil { - if err != model.ErrAppNotFound { - return - } - - var mastoApp *mastodon.Application - mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{ - Server: instanceURL, - ClientName: svc.clientName, - Scopes: svc.clientScope, - Website: svc.clientWebsite, - RedirectURIs: svc.clientWebsite + "/oauth_callback", - }) - if err != nil { - return - } - - app = model.App{ - InstanceDomain: instance, - InstanceURL: instanceURL, - ClientID: mastoApp.ClientID, - ClientSecret: mastoApp.ClientSecret, - } - - err = svc.appRepo.Add(app) - if err != nil { - return - } - } - - u, err := url.Parse("/oauth/authorize") - if err != nil { + keyStr, ok := key.(string) + if !ok { return } - q := make(url.Values) - q.Set("scope", "read write follow") - q.Set("client_id", app.ClientID) - q.Set("response_type", "code") - q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") - u.RawQuery = q.Encode() - - redirectUrl = instanceURL + u.String() + _, ok = m[keyStr] + if !ok { + m[keyStr] = []mastodon.ReplyInfo{} + } - return + m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) } -func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model.Client, - code string) (token string, err error) { - if len(code) < 1 { - err = ErrInvalidArgument - return +func (svc *service) getCommonData(ctx context.Context, c *model.Client, + title string) (data *renderer.CommonData, err error) { + + data = new(renderer.CommonData) + data.HeaderData = &renderer.HeaderData{ + Title: title + " - " + svc.clientName, + NotificationCount: 0, + CustomCSS: svc.customCSS, } - session, err := svc.sessionRepo.Get(sessionID) - if err != nil { + if c == nil || !c.Session.IsLoggedIn() { return } - app, err := svc.appRepo.Get(session.InstanceDomain) + notifications, err := c.GetNotifications(ctx, nil) if err != nil { - return + return nil, err } - data := &bytes.Buffer{} - err = json.NewEncoder(data).Encode(map[string]string{ - "client_id": app.ClientID, - "client_secret": app.ClientSecret, - "grant_type": "authorization_code", - "code": code, - "redirect_uri": svc.clientWebsite + "/oauth_callback", - }) - if err != nil { - return + var notificationCount int + for i := range notifications { + if notifications[i].Pleroma != nil && + !notifications[i].Pleroma.IsSeen { + notificationCount++ + } } - resp, err := http.Post(app.InstanceURL+"/oauth/token", "application/json", data) + u, err := c.GetAccountCurrentUser(ctx) if err != nil { - return + return nil, err } - defer resp.Body.Close() - var res struct { - AccessToken string `json:"access_token"` + data.NavbarData = &renderer.NavbarData{ + User: u, + NotificationCount: notificationCount, } - err = json.NewDecoder(resp.Body).Decode(&res) - if err != nil { - return - } + data.HeaderData.NotificationCount = notificationCount + data.HeaderData.CSRFToken = c.Session.CSRFToken - return res.AccessToken, nil + return } -func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (svc *service) ServeErrorPage(ctx context.Context, c *model.Client, err error) { var errStr string if err != nil { errStr = err.Error() } - commonData, err := svc.getCommonData(ctx, client, nil, "error") + commonData, err := svc.getCommonData(ctx, nil, "error") if err != nil { return } @@ -237,12 +175,13 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod } rCtx := getRendererContext(c) - - svc.renderer.RenderErrorPage(rCtx, client, data) + svc.renderer.RenderErrorPage(rCtx, c.Writer, data) } -func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { - commonData, err := svc.getCommonData(ctx, client, nil, "signin") +func (svc *service) ServeSigninPage(ctx context.Context, c *model.Client) ( + err error) { + + commonData, err := svc.getCommonData(ctx, nil, "signin") if err != nil { return } @@ -252,26 +191,23 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err } rCtx := getRendererContext(nil) - return svc.renderer.RenderSigninPage(rCtx, client, data) + return svc.renderer.RenderSigninPage(rCtx, c.Writer, data) } -func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, - c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { - - var hasNext, hasPrev bool - var nextLink, prevLink string +func (svc *service) ServeTimelinePage(ctx context.Context, c *model.Client, + tType string, maxID string, minID string) (err error) { + var nextLink, prevLink, title string + var statuses []*mastodon.Status var pg = mastodon.Pagination{ MaxID: maxID, MinID: minID, Limit: 20, } - var statuses []*mastodon.Status - var title string - switch timelineType { + switch tType { default: - return ErrInvalidTimeline + return errInvalidArgument case "home": statuses, err = c.GetTimelineHome(ctx, &pg) title = "Timeline" @@ -293,29 +229,31 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, } if len(maxID) > 0 && len(statuses) > 0 { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/$s?min_id=%s", timelineType, statuses[0].ID) + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", tType, + statuses[0].ID) } + if len(minID) > 0 && len(pg.MinID) > 0 { - newStatuses, err := c.GetTimelineHome(ctx, &mastodon.Pagination{MinID: pg.MinID, Limit: 20}) + newPg := &mastodon.Pagination{MinID: pg.MinID, Limit: 20} + newStatuses, err := c.GetTimelineHome(ctx, newPg) if err != nil { return err } - newStatusesLen := len(newStatuses) - if newStatusesLen == 20 { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, pg.MinID) + newLen := len(newStatuses) + if newLen == 20 { + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", + tType, pg.MinID) } else { - i := 20 - newStatusesLen - 1 + i := 20 - newLen - 1 if len(statuses) > i { - hasPrev = true - prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, statuses[i].ID) + prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", + tType, statuses[i].ID) } } } + if len(pg.MaxID) > 0 { - hasNext = true - nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", timelineType, pg.MaxID) + nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", tType, pg.MaxID) } postContext := model.PostContext{ @@ -323,7 +261,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, Formats: svc.postFormats, } - commonData, err := svc.getCommonData(ctx, client, c, timelineType+" timeline ") + commonData, err := svc.getCommonData(ctx, c, tType+" timeline ") if err != nil { return } @@ -331,24 +269,21 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, data := &renderer.TimelineData{ Title: title, Statuses: statuses, - HasNext: hasNext, NextLink: nextLink, - HasPrev: hasPrev, PrevLink: prevLink, PostContext: postContext, CommonData: commonData, } + rCtx := getRendererContext(c) + return svc.renderer.RenderTimelinePage(rCtx, c.Writer, data) +} - err = svc.renderer.RenderTimelinePage(rCtx, client, data) - if err != nil { - return - } +func (svc *service) ServeThreadPage(ctx context.Context, c *model.Client, + id string, reply bool) (err error) { - return -} + var postContext model.PostContext -func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { status, err := c.GetStatus(ctx, id) if err != nil { return @@ -359,19 +294,19 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo return } - var postContext model.PostContext if reply { var content string + var visibility string if u.ID != status.Account.ID { content += "@" + status.Account.Acct + " " } for i := range status.Mentions { - if status.Mentions[i].ID != u.ID && status.Mentions[i].ID != status.Account.ID { + if status.Mentions[i].ID != u.ID && + status.Mentions[i].ID != status.Account.ID { content += "@" + status.Mentions[i].Acct + " " } } - var visibility string if c.Session.Settings.CopyScope { s, err := c.GetStatus(ctx, id) if err != nil { @@ -400,16 +335,15 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo } statuses := append(append(context.Ancestors, status), context.Descendants...) - - replyMap := make(map[string][]mastodon.ReplyInfo) + replies := make(map[string][]mastodon.ReplyInfo) for i := range statuses { statuses[i].ShowReplies = true - statuses[i].ReplyMap = replyMap - addToReplyMap(replyMap, statuses[i].InReplyToID, statuses[i].ID, i+1) + statuses[i].ReplyMap = replies + addToReplyMap(replies, statuses[i].InReplyToID, statuses[i].ID, i+1) } - commonData, err := svc.getCommonData(ctx, client, c, "post by "+status.Account.DisplayName) + commonData, err := svc.getCommonData(ctx, c, "post by "+status.Account.DisplayName) if err != nil { return } @@ -417,224 +351,182 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo data := &renderer.ThreadData{ Statuses: statuses, PostContext: postContext, - ReplyMap: replyMap, + ReplyMap: replies, CommonData: commonData, } - rCtx := getRendererContext(c) - err = svc.renderer.RenderThreadPage(rCtx, client, data) - if err != nil { - return - } - - return + rCtx := getRendererContext(c) + return svc.renderer.RenderThreadPage(rCtx, c.Writer, data) } -func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { - var hasNext bool - var nextLink string +func (svc *service) ServeLikedByPage(ctx context.Context, c *model.Client, + id string) (err error) { - var pg = mastodon.Pagination{ - MaxID: maxID, - MinID: minID, - Limit: 20, + likers, err := c.GetFavouritedBy(ctx, id, nil) + if err != nil { + return } - notifications, err := c.GetNotifications(ctx, &pg) + commonData, err := svc.getCommonData(ctx, c, "likes") if err != nil { return } - var unreadCount int - for i := range notifications { - if notifications[i].Status != nil { - notifications[i].Status.CreatedAt = notifications[i].CreatedAt - switch notifications[i].Type { - case "reblog", "favourite": - notifications[i].Status.HideAccountInfo = true - } - } - if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { - unreadCount++ - } + data := &renderer.LikedByData{ + CommonData: commonData, + Users: likers, } - if unreadCount > 0 { - err := c.ReadNotifications(ctx, notifications[0].ID) - if err != nil { - return err - } - } + rCtx := getRendererContext(c) + return svc.renderer.RenderLikedByPage(rCtx, c.Writer, data) +} - if len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/notifications?max_id=" + pg.MaxID - } +func (svc *service) ServeRetweetedByPage(ctx context.Context, c *model.Client, + id string) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "notifications") + retweeters, err := c.GetRebloggedBy(ctx, id, nil) if err != nil { return } - data := &renderer.NotificationData{ - Notifications: notifications, - HasNext: hasNext, - NextLink: nextLink, - CommonData: commonData, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderNotificationPage(rCtx, client, data) + commonData, err := svc.getCommonData(ctx, c, "retweets") if err != nil { return } - return + data := &renderer.RetweetedByData{ + CommonData: commonData, + Users: retweeters, + } + + rCtx := getRendererContext(c) + return svc.renderer.RenderRetweetedByPage(rCtx, c.Writer, data) } -func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - user, err := c.GetAccount(ctx, id) - if err != nil { - return - } +func (svc *service) ServeFollowingPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { - var hasNext bool var nextLink string - var pg = mastodon.Pagination{ MaxID: maxID, MinID: minID, Limit: 20, } - statuses, err := c.GetAccountStatuses(ctx, id, &pg) + followings, err := c.GetAccountFollowing(ctx, id, &pg) if err != nil { return } - if len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/user/" + id + "?max_id=" + pg.MaxID + if len(followings) == 20 && len(pg.MaxID) > 0 { + nextLink = "/following/" + id + "?max_id=" + pg.MaxID } - commonData, err := svc.getCommonData(ctx, client, c, user.DisplayName) + commonData, err := svc.getCommonData(ctx, c, "following") if err != nil { return } - data := &renderer.UserData{ - User: user, - Statuses: statuses, - HasNext: hasNext, - NextLink: nextLink, + data := &renderer.FollowingData{ CommonData: commonData, - } - rCtx := getRendererContext(c) - - err = svc.renderer.RenderUserPage(rCtx, client, data) - if err != nil { - return + Users: followings, + NextLink: nextLink, } - return + rCtx := getRendererContext(c) + return svc.renderer.RenderFollowingPage(rCtx, c.Writer, data) } -func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "about") - if err != nil { - return - } +func (svc *service) ServeFollowersPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { - data := &renderer.AboutData{ - CommonData: commonData, + var nextLink string + var pg = mastodon.Pagination{ + MaxID: maxID, + MinID: minID, + Limit: 20, } - rCtx := getRendererContext(c) - err = svc.renderer.RenderAboutPage(rCtx, client, data) + followers, err := c.GetAccountFollowers(ctx, id, &pg) if err != nil { return } - return -} - -func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "emojis") - if err != nil { - return + if len(followers) == 20 && len(pg.MaxID) > 0 { + nextLink = "/followers/" + id + "?max_id=" + pg.MaxID } - emojis, err := c.GetInstanceEmojis(ctx) + commonData, err := svc.getCommonData(ctx, c, "followers") if err != nil { return } - data := &renderer.EmojiData{ - Emojis: emojis, + data := &renderer.FollowersData{ CommonData: commonData, + Users: followers, + NextLink: nextLink, } rCtx := getRendererContext(c) - - err = svc.renderer.RenderEmojiPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderFollowersPage(rCtx, c.Writer, data) } -func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - likers, err := c.GetFavouritedBy(ctx, id, nil) - if err != nil { - return +func (svc *service) ServeNotificationPage(ctx context.Context, c *model.Client, + maxID string, minID string) (err error) { + + var nextLink string + var unreadCount int + var pg = mastodon.Pagination{ + MaxID: maxID, + MinID: minID, + Limit: 20, } - commonData, err := svc.getCommonData(ctx, client, c, "likes") + notifications, err := c.GetNotifications(ctx, &pg) if err != nil { return } - data := &renderer.LikedByData{ - CommonData: commonData, - Users: likers, + for i := range notifications { + if notifications[i].Status != nil { + notifications[i].Status.CreatedAt = notifications[i].CreatedAt + switch notifications[i].Type { + case "reblog", "favourite": + notifications[i].Status.HideAccountInfo = true + } + } + if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { + unreadCount++ + } } - rCtx := getRendererContext(c) - err = svc.renderer.RenderLikedByPage(rCtx, client, data) - if err != nil { - return + if unreadCount > 0 { + err := c.ReadNotifications(ctx, notifications[0].ID) + if err != nil { + return err + } } - return -} - -func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - retweeters, err := c.GetRebloggedBy(ctx, id, nil) - if err != nil { - return + if len(pg.MaxID) > 0 { + nextLink = "/notifications?max_id=" + pg.MaxID } - commonData, err := svc.getCommonData(ctx, client, c, "retweets") + commonData, err := svc.getCommonData(ctx, c, "notifications") if err != nil { return } - data := &renderer.RetweetedByData{ - CommonData: commonData, - Users: retweeters, + data := &renderer.NotificationData{ + Notifications: notifications, + NextLink: nextLink, + CommonData: commonData, } rCtx := getRendererContext(c) - - err = svc.renderer.RenderRetweetedByPage(rCtx, client, data) - if err != nil { - return - } - - return + return svc.renderer.RenderNotificationPage(rCtx, c.Writer, data) } -func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - var hasNext bool +func (svc *service) ServeUserPage(ctx context.Context, c *model.Client, + id string, maxID string, minID string) (err error) { + var nextLink string var pg = mastodon.Pagination{ @@ -643,104 +535,91 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c Limit: 20, } - followings, err := c.GetAccountFollowing(ctx, id, &pg) + user, err := c.GetAccount(ctx, id) if err != nil { return } - if len(followings) == 20 && len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/following/" + id + "?max_id=" + pg.MaxID + statuses, err := c.GetAccountStatuses(ctx, id, &pg) + if err != nil { + return + } + + if len(pg.MaxID) > 0 { + nextLink = "/user/" + id + "?max_id=" + pg.MaxID } - commonData, err := svc.getCommonData(ctx, client, c, "following") + commonData, err := svc.getCommonData(ctx, c, user.DisplayName) if err != nil { return } - data := &renderer.FollowingData{ - CommonData: commonData, - Users: followings, - HasNext: hasNext, + data := &renderer.UserData{ + User: user, + Statuses: statuses, NextLink: nextLink, + CommonData: commonData, } rCtx := getRendererContext(c) + return svc.renderer.RenderUserPage(rCtx, c.Writer, data) +} - err = svc.renderer.RenderFollowingPage(rCtx, client, data) +func (svc *service) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "about") if err != nil { return } - return -} - -func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { - var hasNext bool - var nextLink string - - var pg = mastodon.Pagination{ - MaxID: maxID, - MinID: minID, - Limit: 20, + data := &renderer.AboutData{ + CommonData: commonData, } - followers, err := c.GetAccountFollowers(ctx, id, &pg) + rCtx := getRendererContext(c) + return svc.renderer.RenderAboutPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "emojis") if err != nil { return } - if len(followers) == 20 && len(pg.MaxID) > 0 { - hasNext = true - nextLink = "/followers/" + id + "?max_id=" + pg.MaxID - } - - commonData, err := svc.getCommonData(ctx, client, c, "followers") + emojis, err := c.GetInstanceEmojis(ctx) if err != nil { return } - data := &renderer.FollowersData{ + data := &renderer.EmojiData{ + Emojis: emojis, CommonData: commonData, - Users: followers, - HasNext: hasNext, - NextLink: nextLink, } - rCtx := getRendererContext(c) - err = svc.renderer.RenderFollowersPage(rCtx, client, data) - if err != nil { - return - } - - return + rCtx := getRendererContext(c) + return svc.renderer.RenderEmojiPage(rCtx, c.Writer, data) } -func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { - var hasNext bool +func (svc *service) ServeSearchPage(ctx context.Context, c *model.Client, + q string, qType string, offset int) (err error) { + var nextLink string + var title = "search" results, err := c.Search(ctx, q, qType, 20, true, offset) if err != nil { return } - switch qType { - case "accounts": - hasNext = len(results.Accounts) == 20 - case "statuses": - hasNext = len(results.Statuses) == 20 - } - - if hasNext { + if (qType == "accounts" && len(results.Accounts) == 20) || + (qType == "statuses" && len(results.Statuses) == 20) { offset += 20 nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", q, qType, offset) } - var title = "search" if len(q) > 0 { title += " \"" + q + "\"" } - commonData, err := svc.getCommonData(ctx, client, c, title) + + commonData, err := svc.getCommonData(ctx, c, title) if err != nil { return } @@ -751,21 +630,15 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo Type: qType, Users: results.Accounts, Statuses: results.Statuses, - HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c) - err = svc.renderer.RenderSearchPage(rCtx, client, data) - if err != nil { - return - } - - return + rCtx := getRendererContext(c) + return svc.renderer.RenderSearchPage(rCtx, c.Writer, data) } -func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { - commonData, err := svc.getCommonData(ctx, client, c, "settings") +func (svc *service) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { + commonData, err := svc.getCommonData(ctx, c, "settings") if err != nil { return } @@ -774,122 +647,125 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c * CommonData: commonData, Settings: &c.Session.Settings, } + rCtx := getRendererContext(c) + return svc.renderer.RenderSettingsPage(rCtx, c.Writer, data) +} - err = svc.renderer.RenderSettingsPage(rCtx, client, data) +func (svc *service) NewSession(ctx context.Context, instance string) ( + redirectUrl string, sessionID string, err error) { + + var instanceURL string + if strings.HasPrefix(instance, "https://") { + instanceURL = instance + instance = strings.TrimPrefix(instance, "https://") + } else { + instanceURL = "https://" + instance + } + + sessionID, err = util.NewSessionID() if err != nil { return } - return -} - -func (svc *service) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { - session, err := svc.sessionRepo.Get(c.Session.ID) + csrfToken, err := util.NewCSRFToken() if err != nil { return } - session.Settings = *settings + session := model.Session{ + ID: sessionID, + InstanceDomain: instance, + CSRFToken: csrfToken, + Settings: *model.NewSettings(), + } + err = svc.sessionRepo.Add(session) if err != nil { return } - return -} - -func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *model.Client, title string) (data *renderer.CommonData, err error) { - data = new(renderer.CommonData) - - data.HeaderData = &renderer.HeaderData{ - Title: title + " - " + svc.clientName, - NotificationCount: 0, - CustomCSS: svc.customCSS, - } + app, err := svc.appRepo.Get(instance) + if err != nil { + if err != model.ErrAppNotFound { + return + } - if c != nil && c.Session.IsLoggedIn() { - notifications, err := c.GetNotifications(ctx, nil) + mastoApp, err := mastodon.RegisterApp(ctx, &mastodon.AppConfig{ + Server: instanceURL, + ClientName: svc.clientName, + Scopes: svc.clientScope, + Website: svc.clientWebsite, + RedirectURIs: svc.clientWebsite + "/oauth_callback", + }) if err != nil { - return nil, err + return "", "", err } - var notificationCount int - for i := range notifications { - if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { - notificationCount++ - } + app = model.App{ + InstanceDomain: instance, + InstanceURL: instanceURL, + ClientID: mastoApp.ClientID, + ClientSecret: mastoApp.ClientSecret, } - u, err := c.GetAccountCurrentUser(ctx) + err = svc.appRepo.Add(app) if err != nil { - return nil, err - } - - data.NavbarData = &renderer.NavbarData{ - User: u, - NotificationCount: notificationCount, + return "", "", err } - - data.HeaderData.NotificationCount = notificationCount - data.HeaderData.CSRFToken = c.Session.CSRFToken } - return -} - -func (svc *service) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Favourite(ctx, id) + u, err := url.Parse("/oauth/authorize") if err != nil { return } - count = s.FavouritesCount - return -} -func (svc *service) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Unfavourite(ctx, id) - if err != nil { - return - } - count = s.FavouritesCount + q := make(url.Values) + q.Set("scope", "read write follow") + q.Set("client_id", app.ClientID) + q.Set("response_type", "code") + q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") + u.RawQuery = q.Encode() + + redirectUrl = instanceURL + u.String() + return } -func (svc *service) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Reblog(ctx, id) - if err != nil { +func (svc *service) Signin(ctx context.Context, c *model.Client, + sessionID string, code string) (token string, err error) { + + if len(code) < 1 { + err = errInvalidArgument return } - if s.Reblog != nil { - count = s.Reblog.ReblogsCount - } - return -} -func (svc *service) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { - s, err := c.Unreblog(ctx, id) + err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback") if err != nil { return } - count = s.ReblogsCount + token = c.GetAccessToken(ctx) + return } -func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { - var mediaIds []string +func (svc *service) Post(ctx context.Context, c *model.Client, content string, + replyToID string, format string, visibility string, isNSFW bool, + files []*multipart.FileHeader) (id string, err error) { + + var mediaIDs []string for _, f := range files { a, err := c.UploadMediaFromMultipartFileHeader(ctx, f) if err != nil { return "", err } - mediaIds = append(mediaIds, a.ID) + mediaIDs = append(mediaIDs, a.ID) } tweet := &mastodon.Toot{ Status: content, InReplyToID: replyToID, - MediaIDs: mediaIds, + MediaIDs: mediaIDs, ContentType: format, Visibility: visibility, Sensitive: isNSFW, @@ -903,29 +779,66 @@ func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Cl return s.ID, nil } -func (svc *service) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - _, err = c.AccountFollow(ctx, id) +func (svc *service) Like(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Favourite(ctx, id) + if err != nil { + return + } + count = s.FavouritesCount return } -func (svc *service) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { - _, err = c.AccountUnfollow(ctx, id) +func (svc *service) UnLike(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Unfavourite(ctx, id) + if err != nil { + return + } + count = s.FavouritesCount return } -func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, val string, number int) { - if key == nil { +func (svc *service) Retweet(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Reblog(ctx, id) + if err != nil { return } + if s.Reblog != nil { + count = s.Reblog.ReblogsCount + } + return +} - keyStr, ok := key.(string) - if !ok { +func (svc *service) UnRetweet(ctx context.Context, c *model.Client, id string) ( + count int64, err error) { + s, err := c.Unreblog(ctx, id) + if err != nil { return } - _, ok = m[keyStr] - if !ok { - m[keyStr] = []mastodon.ReplyInfo{} + count = s.ReblogsCount + return +} + +func (svc *service) Follow(ctx context.Context, c *model.Client, id string) (err error) { + _, err = c.AccountFollow(ctx, id) + return +} + +func (svc *service) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { + _, err = c.AccountUnfollow(ctx, id) + return +} + +func (svc *service) SaveSettings(ctx context.Context, c *model.Client, + settings *model.Settings) (err error) { + + session, err := svc.sessionRepo.Get(c.Session.ID) + if err != nil { + return } - m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) + session.Settings = *settings + return svc.sessionRepo.Add(session) } diff --git a/service/transport.go b/service/transport.go index e878f8d..fbab2e5 100644 --- a/service/transport.go +++ b/service/transport.go @@ -15,327 +15,292 @@ import ( "github.com/gorilla/mux" ) -var ( - ctx = context.Background() - cookieAge = "31536000" -) +func newClient(w io.Writer) *model.Client { + return &model.Client{ + Writer: w, + } +} + +func newCtxWithSesion(req *http.Request) context.Context { + ctx := context.Background() + sessionID, err := req.Cookie("session_id") + if err != nil { + return ctx + } + return context.WithValue(ctx, "session_id", sessionID.Value) +} + +func newCtxWithSesionCSRF(req *http.Request, csrfToken string) context.Context { + ctx := newCtxWithSesion(req) + return context.WithValue(ctx, "csrf_token", csrfToken) +} + +func getMultipartFormValue(mf *multipart.Form, key string) (val string) { + vals, ok := mf.Value[key] + if !ok { + return "" + } + if len(vals) < 1 { + return "" + } + return vals[0] +} + +func serveJson(w io.Writer, data interface{}) (err error) { + var d = make(map[string]interface{}) + d["data"] = data + return json.NewEncoder(w).Encode(d) +} func NewHandler(s Service, staticDir string) http.Handler { r := mux.NewRouter() - r.PathPrefix("/static").Handler(http.StripPrefix("/static", - http.FileServer(http.Dir(path.Join(".", staticDir))))) + rootPage := func(w http.ResponseWriter, req *http.Request) { + sessionID, _ := req.Cookie("session_id") - r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { location := "/signin" - - sessionID, _ := req.Cookie("session_id") if sessionID != nil && len(sessionID.Value) > 0 { location = "/timeline/home" } w.Header().Add("Location", location) w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { - err := s.ServeSigninPage(ctx, w) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) + } - r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { - instance := req.FormValue("instance") - url, sessionID, err := s.GetAuthUrl(ctx, instance) + signinPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := context.Background() + err := s.ServeSigninPage(ctx, c) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sessionID, - Expires: time.Now().Add(365 * 24 * time.Hour), - }) - - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) + timelinePage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + tType, _ := mux.Vars(req)["type"] + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") - r.HandleFunc("/oauth_callback", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - token := req.URL.Query().Get("code") - _, err := s.GetUserToken(ctx, "", nil, token) + err := s.ServeTimelinePage(ctx, c, tType, maxID, minID) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } + timelineOldPage := func(w http.ResponseWriter, req *http.Request) { w.Header().Add("Location", "/timeline/home") w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/timeline", func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Location", "/timeline/home") - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) - - r.HandleFunc("/timeline/{type}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - timelineType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - sinceID := req.URL.Query().Get("since_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeTimelinePage(ctx, w, nil, timelineType, maxID, sinceID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) + } - r.HandleFunc("/thread/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + threadPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] reply := req.URL.Query().Get("reply") - err := s.ServeThreadPage(ctx, w, nil, id, len(reply) > 1) + + err := s.ServeThreadPage(ctx, c, id, len(reply) > 1) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) + } - r.HandleFunc("/likedby/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + likedByPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] - err := s.ServeLikedByPage(ctx, w, nil, id) + err := s.ServeLikedByPage(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) + } - r.HandleFunc("/retweetedby/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + retweetedByPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] - err := s.ServeRetweetedByPage(ctx, w, nil, id) + err := s.ServeRetweetedByPage(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - - r.HandleFunc("/following/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + } + followingPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] maxID := req.URL.Query().Get("max_id") minID := req.URL.Query().Get("min_id") - err := s.ServeFollowingPage(ctx, w, nil, id, maxID, minID) + err := s.ServeFollowingPage(ctx, c, id, maxID, minID) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - - r.HandleFunc("/followers/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + } + followersPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] maxID := req.URL.Query().Get("max_id") minID := req.URL.Query().Get("min_id") - err := s.ServeFollowersPage(ctx, w, nil, id, maxID, minID) + err := s.ServeFollowersPage(ctx, c, id, maxID, minID) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - - r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + } - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + notificationsPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") - _, err := s.Like(ctx, w, nil, id) + err := s.ServeNotificationPage(ctx, c, maxID, minID) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - + userPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + maxID := req.URL.Query().Get("max_id") + minID := req.URL.Query().Get("min_id") - _, err := s.UnLike(ctx, w, nil, id) + err := s.ServeUserPage(ctx, c, id, maxID, minID) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + aboutPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) - _, err := s.Retweet(ctx, w, nil, id) + err := s.ServeAboutPage(ctx, c) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + emojisPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - - _, err := s.UnRetweet(ctx, w, nil, id) + err := s.ServeEmojiPage(ctx, c) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } + } - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID - } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + searchPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + q := req.URL.Query().Get("q") + qType := req.URL.Query().Get("type") + offsetStr := req.URL.Query().Get("offset") - id, _ := mux.Vars(req)["id"] - count, err := s.Like(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return + var offset int + var err error + if len(offsetStr) > 1 { + offset, err = strconv.Atoi(offsetStr) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } } - err = serveJson(w, count) + err = s.ServeSearchPage(ctx, c, q, qType, offset) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodPost) - - r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + } - id, _ := mux.Vars(req)["id"] - count, err := s.UnLike(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } + settingsPage := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) - err = serveJson(w, count) + err := s.ServeSettingsPage(ctx, c) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodPost) + } - r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + signin := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := context.Background() + instance := req.FormValue("instance") - id, _ := mux.Vars(req)["id"] - count, err := s.Retweet(ctx, w, nil, id) + url, sessionID, err := s.NewSession(ctx, instance) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - err = serveJson(w, count) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodPost) + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: sessionID, + Expires: time.Now().Add(365 * 24 * time.Hour), + }) - r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + w.Header().Add("Location", url) + w.WriteHeader(http.StatusFound) + } - id, _ := mux.Vars(req)["id"] - count, err := s.UnRetweet(ctx, w, nil, id) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } + oauthCallback := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesion(req) + token := req.URL.Query().Get("code") - err = serveJson(w, count) + _, err := s.Signin(ctx, c, "", token) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodPost) - r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Location", "/timeline/home") + w.WriteHeader(http.StatusFound) + } + + post := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) err := req.ParseMultipartForm(4 << 20) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(context.Background(), c, err) return } - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", + ctx := newCtxWithSesionCSRF(req, getMultipartFormValue(req.MultipartForm, "csrf_token")) - content := getMultipartFormValue(req.MultipartForm, "content") replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") format := getMultipartFormValue(req.MultipartForm, "format") visibility := getMultipartFormValue(req.MultipartForm, "visibility") isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") - files := req.MultipartForm.File["attachments"] - id, err := s.PostTweet(ctx, w, nil, content, replyToID, format, visibility, isNSFW, files) + id, err := s.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } @@ -345,131 +310,129 @@ func NewHandler(s Service, staticDir string) http.Handler { } w.Header().Add("Location", location) w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/notifications", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeNotificationPage(ctx, w, nil, maxID, minID) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - }).Methods(http.MethodGet) - - r.HandleFunc("/user/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + } + like := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") + retweetedByID := req.FormValue("retweeted_by_id") - err := s.ServeUserPage(ctx, w, nil, id, maxID, minID) + _, err := s.Like(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } + unlike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") - err := s.Follow(ctx, w, nil, id) + _, err := s.UnLike(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - w.Header().Add("Location", req.Header.Get("Referer")) + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) - - r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + } + retweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") - err := s.UnFollow(ctx, w, nil, id) + _, err := s.Retweet(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - w.Header().Add("Location", req.Header.Get("Referer")) + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID + } + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) + } - r.HandleFunc("/about", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + unretweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + retweetedByID := req.FormValue("retweeted_by_id") - err := s.ServeAboutPage(ctx, w, nil) + _, err := s.UnRetweet(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - r.HandleFunc("/emojis", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - - err := s.ServeEmojiPage(ctx, w, nil) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return + rID := id + if len(retweetedByID) > 0 { + rID = retweetedByID } - }).Methods(http.MethodGet) - - r.HandleFunc("/search", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - q := req.URL.Query().Get("q") - qType := req.URL.Query().Get("type") - offsetStr := req.URL.Query().Get("offset") + w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) + w.WriteHeader(http.StatusFound) + } - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) - if err != nil { - s.ServeErrorPage(ctx, w, nil, err) - return - } - } + follow := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] - err = s.ServeSearchPage(ctx, w, nil, q, qType, offset) + err := s.Follow(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) + w.Header().Add("Location", req.Header.Get("Referer")) + w.WriteHeader(http.StatusFound) + } + + unfollow := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] - err := s.ServeSettingsPage(ctx, w, nil) + err := s.UnFollow(ctx, c, id) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } - }).Methods(http.MethodGet) - r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + w.Header().Add("Location", req.Header.Get("Referer")) + w.WriteHeader(http.StatusFound) + } + settings := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) visibility := req.FormValue("visibility") copyScope := req.FormValue("copy_scope") == "true" threadInNewTab := req.FormValue("thread_in_new_tab") == "true" maskNSFW := req.FormValue("mask_nsfw") == "true" fluorideMode := req.FormValue("fluoride_mode") == "true" darkMode := req.FormValue("dark_mode") == "true" + settings := &model.Settings{ DefaultVisibility: visibility, CopyScope: copyScope, @@ -479,17 +442,17 @@ func NewHandler(s Service, staticDir string) http.Handler { DarkMode: darkMode, } - err := s.SaveSettings(ctx, w, nil, settings) + err := s.SaveSettings(ctx, c, settings) if err != nil { - s.ServeErrorPage(ctx, w, nil, err) + s.ServeErrorPage(ctx, c, err) return } w.Header().Add("Location", req.Header.Get("Referer")) w.WriteHeader(http.StatusFound) - }).Methods(http.MethodPost) + } - r.HandleFunc("/signout", func(w http.ResponseWriter, req *http.Request) { + signout := func(w http.ResponseWriter, req *http.Request) { // TODO remove session from database http.SetCookie(w, &http.Cookie{ Name: "session_id", @@ -498,32 +461,111 @@ func NewHandler(s Service, staticDir string) http.Handler { }) w.Header().Add("Location", "/") w.WriteHeader(http.StatusFound) - }).Methods(http.MethodGet) + } - return r -} + fLike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] -func getContextWithSession(ctx context.Context, req *http.Request) context.Context { - sessionID, err := req.Cookie("session_id") - if err != nil { - return ctx + count, err := s.Like(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } } - return context.WithValue(ctx, "session_id", sessionID.Value) -} -func getMultipartFormValue(mf *multipart.Form, key string) (val string) { - vals, ok := mf.Value[key] - if !ok { - return "" + fUnlike := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + count, err := s.UnLike(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } } - if len(vals) < 1 { - return "" + + fRetweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + count, err := s.Retweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } } - return vals[0] -} -func serveJson(w io.Writer, data interface{}) (err error) { - var d = make(map[string]interface{}) - d["data"] = data - return json.NewEncoder(w).Encode(d) + fUnretweet := func(w http.ResponseWriter, req *http.Request) { + c := newClient(w) + ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] + + count, err := s.UnRetweet(ctx, c, id) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + + err = serveJson(w, count) + if err != nil { + s.ServeErrorPage(ctx, c, err) + return + } + } + + r.HandleFunc("/", rootPage).Methods(http.MethodGet) + r.HandleFunc("/signin", signinPage).Methods(http.MethodGet) + r.HandleFunc("/timeline/{type}", timelinePage).Methods(http.MethodGet) + r.HandleFunc("/timeline", timelineOldPage).Methods(http.MethodGet) + r.HandleFunc("/thread/{id}", threadPage).Methods(http.MethodGet) + r.HandleFunc("/likedby/{id}", likedByPage).Methods(http.MethodGet) + r.HandleFunc("/retweetedby/{id}", retweetedByPage).Methods(http.MethodGet) + r.HandleFunc("/following/{id}", followingPage).Methods(http.MethodGet) + r.HandleFunc("/followers/{id}", followersPage).Methods(http.MethodGet) + r.HandleFunc("/notifications", notificationsPage).Methods(http.MethodGet) + r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet) + r.HandleFunc("/about", aboutPage).Methods(http.MethodGet) + r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet) + r.HandleFunc("/search", searchPage).Methods(http.MethodGet) + r.HandleFunc("/settings", settingsPage).Methods(http.MethodGet) + r.HandleFunc("/signin", signin).Methods(http.MethodPost) + r.HandleFunc("/oauth_callback", oauthCallback).Methods(http.MethodGet) + r.HandleFunc("/post", post).Methods(http.MethodPost) + r.HandleFunc("/like/{id}", like).Methods(http.MethodPost) + r.HandleFunc("/unlike/{id}", unlike).Methods(http.MethodPost) + r.HandleFunc("/retweet/{id}", retweet).Methods(http.MethodPost) + r.HandleFunc("/unretweet/{id}", unretweet).Methods(http.MethodPost) + r.HandleFunc("/follow/{id}", follow).Methods(http.MethodPost) + r.HandleFunc("/unfollow/{id}", unfollow).Methods(http.MethodPost) + r.HandleFunc("/settings", settings).Methods(http.MethodPost) + r.HandleFunc("/signout", signout).Methods(http.MethodGet) + r.HandleFunc("/fluoride/like/{id}", fLike).Methods(http.MethodPost) + r.HandleFunc("/fluoride/unlike/{id}", fUnlike).Methods(http.MethodPost) + r.HandleFunc("/fluoride/retweet/{id}", fRetweet).Methods(http.MethodPost) + r.HandleFunc("/fluoride/unretweet/{id}", fUnretweet).Methods(http.MethodPost) + r.PathPrefix("/static").Handler(http.StripPrefix("/static", + http.FileServer(http.Dir(path.Join(".", staticDir))))) + + return r } -- cgit v1.2.3