diff --git a/cmd/root.go b/cmd/root.go index 3692997..e68bb80 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -74,6 +74,11 @@ func initConfig() { logrus.Info("Using config file: ", viper.ConfigFileUsed()) } + viper.SetConfigName("secrets") + if err := viper.MergeInConfig(); err == nil { + logrus.Info("Using secret file: ", viper.ConfigFileUsed()) + } + viper.AutomaticEnv() } diff --git a/internal/billing/service.go b/internal/billing/service.go index 89dc609..c2a2b1b 100644 --- a/internal/billing/service.go +++ b/internal/billing/service.go @@ -9,6 +9,8 @@ import ( "github.com/stripe/stripe-go/v76" portalsession "github.com/stripe/stripe-go/v76/billingportal/session" "github.com/stripe/stripe-go/v76/checkout/session" + "github.com/stripe/stripe-go/v76/price" + "github.com/stripe/stripe-go/v76/subscription" "github.com/stripe/stripe-go/v76/webhook" ) @@ -22,8 +24,13 @@ type Config struct { type Service interface { PostConfig() StripeConfig CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error) + UpdateSubscribtion(team *entities.Team, priceLookupKey string) (*stripe.Subscription, error) + CancelSubscribtion(team *entities.Team) (*stripe.Subscription, error) GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error) GetLineItems(sessionID string) *session.LineItemIter + GetPrice(priceLookupKey string) (*stripe.Price, error) + GetCustomerSubscribtion(customerID string) *subscription.Iter + GetSubscribtion(subID string) (*stripe.Subscription, error) CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error) ConstructEvent(payload []byte, header string) (stripe.Event, error) } @@ -54,21 +61,26 @@ func (s *ServiceImpl) PostConfig() StripeConfig { } func (s *ServiceImpl) CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error) { + priceID, err := s.GetPrice(priceLookupKey) + if err != nil { + return nil, errors.Wrap(err, "failed to get price") + } + params := &stripe.CheckoutSessionParams{ - SuccessURL: stripe.String(s.Config.Domain + "/team/plan"), - CancelURL: stripe.String(s.Config.Domain + "/canceled.html"), + SuccessURL: stripe.String(s.Config.Domain + "/team/subscription"), + CancelURL: stripe.String(s.Config.Domain + "/team/subscription"), Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), ClientReferenceID: stripe.String(strconv.FormatUint(uint64(team.ID), 10)), LineItems: []*stripe.CheckoutSessionLineItemParams{ { - Price: stripe.String(priceLookupKey), + + Price: stripe.String(priceID.ID), Quantity: stripe.Int64(1), }, }, } - // Set Customer on session if already a customer if team.StripeCustomerID != nil { params.Customer = stripe.String(*team.StripeCustomerID) } @@ -76,6 +88,50 @@ func (s *ServiceImpl) CreateCheckoutSession(team *entities.Team, priceLookupKey return session.New(params) } +func (s *ServiceImpl) UpdateSubscribtion(team *entities.Team, priceLookupKey string) (*stripe.Subscription, error) { + // Set Customer on session if already a customer + + priceID, err := s.GetPrice(priceLookupKey) + if err != nil { + return nil, errors.Wrap(err, "failed to get price") + } + sub := s.GetCustomerSubscribtion(*team.StripeCustomerID) + sub.Next() + teamSubscription := sub.Subscription() + + params := &stripe.SubscriptionParams{ + Items: []*stripe.SubscriptionItemsParams{ + { + ID: stripe.String(teamSubscription.Items.Data[0].ID), + Price: stripe.String(priceID.ID), + }, + }, + } + result, err := subscription.Update(teamSubscription.ID, params) + if err != nil { + return nil, errors.Wrap(err, "failed to update subscription") + } + + return result, nil +} + +func (s *ServiceImpl) CancelSubscribtion(team *entities.Team) (*stripe.Subscription, error) { + // Set Customer on session if already a customer + + sub := s.GetCustomerSubscribtion(*team.StripeCustomerID) + sub.Next() + teamSubscription := sub.Subscription() + + params := &stripe.SubscriptionCancelParams{} + result, err := subscription.Cancel(teamSubscription.ID, params) + + if err != nil { + return nil, errors.Wrap(err, "failed to update subscription") + } + + return result, nil +} + func (s *ServiceImpl) GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error) { return session.Get(sessionID, &stripe.CheckoutSessionParams{}) } @@ -87,6 +143,36 @@ func (s *ServiceImpl) GetLineItems(sessionID string) *session.LineItemIter { return session.ListLineItems(params) } +func (s *ServiceImpl) GetPrice(priceLookupKey string) (*stripe.Price, error) { + params := &stripe.PriceListParams{ + LookupKeys: stripe.StringSlice([]string{ + priceLookupKey, + }), + } + i := price.List(params) + + var price *stripe.Price + for i.Next() { + p := i.Price() + price = p + } + if price == nil { + return nil, errors.New("Price not found for lookup key" + priceLookupKey) + } + return price, nil +} + +func (s *ServiceImpl) GetCustomerSubscribtion(customerID string) *subscription.Iter { + params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)} + return subscription.List(params) +} + +func (s *ServiceImpl) GetSubscribtion(subID string) (*stripe.Subscription, error) { + params := &stripe.SubscriptionParams{} + return subscription.Get(subID, params) + +} + func (s *ServiceImpl) CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error) { // For demonstration purposes, we're using the Checkout session to retrieve the customer ID. // Typically this is stored alongside the authenticated user in your database. diff --git a/internal/rest/controllers/teams/billing.go b/internal/rest/controllers/teams/billing.go index cc40508..6d6dba3 100644 --- a/internal/rest/controllers/teams/billing.go +++ b/internal/rest/controllers/teams/billing.go @@ -24,8 +24,6 @@ func (h *Handlers) PostCreateCheckoutSession(c hs.AuthenticatedContext) error { return echo.ErrBadRequest } - c.Log.Error(req.PriceLookupKey) - c.Log.Info(req.TeamID) team, err := h.TeamService.GetByID(c.Request().Context(), req.TeamID) if err != nil { @@ -34,14 +32,51 @@ func (h *Handlers) PostCreateCheckoutSession(c hs.AuthenticatedContext) error { return echo.ErrInternalServerError } - s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey) + if team.PaymentPlan == req.PriceLookupKey { + return c.JSON(http.StatusOK, "") + } + + if team.StripeCustomerID == nil { + if req.PriceLookupKey == "FREE" { + return c.JSON(http.StatusOK, "") + } + s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey) + if err != nil { + c.Log.WithError(err).Debug("create stripe checkout session") + + return echo.ErrInternalServerError + } + return c.JSON(http.StatusOK, s.URL) + } + + if req.PriceLookupKey == "FREE" { + _, err := h.BillingService.CancelSubscribtion(team) + if err != nil { + c.Log.WithError(err).Debug("cancel subscription") + + return echo.ErrInternalServerError + } + return c.JSON(http.StatusOK, "") + } + + if team.PaymentPlan == "FREE" { + s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey) + if err != nil { + c.Log.WithError(err).Debug("create stripe checkout session") + + return echo.ErrInternalServerError + } + return c.JSON(http.StatusOK, s.URL) + } + + _, err = h.BillingService.UpdateSubscribtion(team, req.PriceLookupKey) if err != nil { - c.Log.WithError(err).Debug("create stripe checkout session") + c.Log.WithError(err).Debug("update subscription") return echo.ErrInternalServerError } + return c.JSON(http.StatusOK, "") - return c.JSON(http.StatusOK, s.URL) } type GetCheckoutSession struct { diff --git a/internal/rest/controllers/webhooks/routes.go b/internal/rest/controllers/webhooks/routes.go index e56f156..ade1f45 100644 --- a/internal/rest/controllers/webhooks/routes.go +++ b/internal/rest/controllers/webhooks/routes.go @@ -24,6 +24,7 @@ func Register( ) { h := &Handlers{ BillingService: billingService, + TeamService: teamService, } root := e.Group( diff --git a/internal/rest/controllers/webhooks/stripe.go b/internal/rest/controllers/webhooks/stripe.go index 7e70725..8541d89 100644 --- a/internal/rest/controllers/webhooks/stripe.go +++ b/internal/rest/controllers/webhooks/stripe.go @@ -37,22 +37,13 @@ func (h *Handlers) handleWebhook(c hs.StripeContext) error { return echo.ErrBadRequest } - params := &stripe.CheckoutSessionParams{} - params.AddExpand("line_items") - - // Retrieve the session. If you require line items in the response, you may include them by expanding line_items. - // sessionWithLineItems, err := h.BillingService.GetCheckoutSession(session.ID) - // if err != nil { - // c.Log.WithError(err).Debug("Error getting checkout session") - // return echo.ErrBadRequest - // } - - c.Log.Error(session.ClientReferenceID) - c.Log.Error(session) - c.Log.Error(session.ID) - c.Log.Error(event.Data.Object["customer"].(string)) - lineItems := h.BillingService.GetLineItems(session.ID).List() - c.Log.Error(lineItems) + c.Log.Error(event.Data.Object["subscription"].(string)) + subscription, err := h.BillingService.GetSubscribtion(event.Data.Object["subscription"].(string)) + if err != nil { + c.Log.WithError(err).Debug("Error getting subscription") + } + c.Log.Error(subscription.Items.Data[0]) + c.Log.Error(subscription.Items.Data[0].Price.LookupKey) teamID, err := strconv.ParseUint(session.ClientReferenceID, 10, 32) if err != nil { @@ -60,24 +51,98 @@ func (h *Handlers) handleWebhook(c hs.StripeContext) error { return c.NoContent(http.StatusInternalServerError) } + c.Log.Error(c.Request().Context()) + c.Log.Error(uint(teamID)) customerTeam, err := h.TeamService.GetByID(c.Request().Context(), uint(teamID)) if err != nil { c.Log.WithError(err).Debug("Error getting team by id", teamID) return c.NoContent(http.StatusInternalServerError) } - // TODO if not same plan remove old plan - - if customerTeam.PaymentPlan == "TEST" { //lineItems.Price.Product.Name { - return c.NoContent(http.StatusOK) - } - customerID := event.Data.Object["customer"].(string) if customerTeam.StripeCustomerID == nil { customerTeam.StripeCustomerID = &customerID } - h.TeamService.UpdateTeam(c.Request().Context(), customerTeam) + customerTeam.PaymentPlan = subscription.Items.Data[0].Price.LookupKey + + err = h.TeamService.UpdateTeam(c.Request().Context(), customerTeam) + if err != nil { + c.Log.WithError(err).Debug("Error updating team") + return c.NoContent(http.StatusInternalServerError) + } + case "customer.subscription.updated": + var subscription stripe.Subscription + err := json.Unmarshal(event.Data.Raw, &subscription) + if err != nil { + c.Log.WithError(err).Debug("Error parsing webhook JSON") + return echo.ErrBadRequest + } + + team, err := h.TeamService.GetByStripeID(c.Request().Context(), subscription.Customer.ID) + if err != nil { + c.Log.WithError(err).Debug("Error getting team by stripe id") + return c.NoContent(http.StatusInternalServerError) + } + team.PaymentPlan = subscription.Items.Data[0].Price.LookupKey + + if subscription.Status == "canceled" { + team.PaymentPlan = "FREE" + } + + err = h.TeamService.UpdateTeam(c.Request().Context(), team) + if err != nil { + c.Log.WithError(err).Debug("Error updating team") + return c.NoContent(http.StatusInternalServerError) + } + + case "customer.subscription.deleted": + var subscription stripe.Subscription + err := json.Unmarshal(event.Data.Raw, &subscription) + if err != nil { + c.Log.WithError(err).Debug("Error parsing webhook JSON") + return echo.ErrBadRequest + } + + team, err := h.TeamService.GetByStripeID(c.Request().Context(), subscription.Customer.ID) + if err != nil { + c.Log.WithError(err).Debug("Error getting team by stripe id") + return c.NoContent(http.StatusInternalServerError) + } + team.PaymentPlan = subscription.Items.Data[0].Price.LookupKey + + if subscription.Status == "canceled" { + team.PaymentPlan = "FREE" + } + + err = h.TeamService.UpdateTeam(c.Request().Context(), team) + if err != nil { + c.Log.WithError(err).Debug("Error updating team") + return c.NoContent(http.StatusInternalServerError) + } + + case "customer.deleted": + var customer stripe.Customer + err := json.Unmarshal(event.Data.Raw, &customer) + if err != nil { + c.Log.WithError(err).Debug("Error parsing webhook JSON") + return echo.ErrBadRequest + } + + team, err := h.TeamService.GetByStripeID(c.Request().Context(), customer.ID) + if err != nil { + c.Log.WithError(err).Debug("Error getting team by stripe id") + return c.NoContent(http.StatusInternalServerError) + } + + team.PaymentPlan = "FREE" + team.StripeCustomerID = nil + + err = h.TeamService.UpdateTeam(c.Request().Context(), team) + if err != nil { + c.Log.WithError(err).Debug("Error updating team") + return c.NoContent(http.StatusInternalServerError) + } default: c.Log.WithField("event", event.Type).Debug("Unhandled event type")