use std::{
    collections::{BTreeMap, BTreeSet},
    str::FromStr,
    string::ToString,
};
use async_trait::async_trait;
use mas_data_model::{Client, JwksOrJwksUri, User};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_storage::{oauth2::OAuth2ClientRepository, Clock};
use oauth2_types::{
    oidc::ApplicationType,
    requests::GrantType,
    scope::{Scope, ScopeToken},
};
use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
use rand::RngCore;
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::{tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError};
pub struct PgOAuth2ClientRepository<'c> {
    conn: &'c mut PgConnection,
}
impl<'c> PgOAuth2ClientRepository<'c> {
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct OAuth2ClientLookup {
    oauth2_client_id: Uuid,
    encrypted_client_secret: Option<String>,
    application_type: Option<String>,
    redirect_uris: Vec<String>,
    grant_type_authorization_code: bool,
    grant_type_refresh_token: bool,
    grant_type_client_credentials: bool,
    grant_type_device_code: bool,
    client_name: Option<String>,
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    jwks_uri: Option<String>,
    jwks: Option<serde_json::Value>,
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    initiate_login_uri: Option<String>,
}
impl TryInto<Client> for OAuth2ClientLookup {
    type Error = DatabaseInconsistencyError;
    #[allow(clippy::too_many_lines)] fn try_into(self) -> Result<Client, Self::Error> {
        let id = Ulid::from(self.oauth2_client_id);
        let redirect_uris: Result<Vec<Url>, _> =
            self.redirect_uris.iter().map(|s| s.parse()).collect();
        let redirect_uris = redirect_uris.map_err(|e| {
            DatabaseInconsistencyError::on("oauth2_clients")
                .column("redirect_uris")
                .row(id)
                .source(e)
        })?;
        let application_type = self
            .application_type
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("application_type")
                    .row(id)
                    .source(e)
            })?;
        let mut grant_types = Vec::new();
        if self.grant_type_authorization_code {
            grant_types.push(GrantType::AuthorizationCode);
        }
        if self.grant_type_refresh_token {
            grant_types.push(GrantType::RefreshToken);
        }
        if self.grant_type_client_credentials {
            grant_types.push(GrantType::ClientCredentials);
        }
        if self.grant_type_device_code {
            grant_types.push(GrantType::DeviceCode);
        }
        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
            DatabaseInconsistencyError::on("oauth2_clients")
                .column("logo_uri")
                .row(id)
                .source(e)
        })?;
        let client_uri = self
            .client_uri
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("client_uri")
                    .row(id)
                    .source(e)
            })?;
        let policy_uri = self
            .policy_uri
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("policy_uri")
                    .row(id)
                    .source(e)
            })?;
        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
            DatabaseInconsistencyError::on("oauth2_clients")
                .column("tos_uri")
                .row(id)
                .source(e)
        })?;
        let id_token_signed_response_alg = self
            .id_token_signed_response_alg
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("id_token_signed_response_alg")
                    .row(id)
                    .source(e)
            })?;
        let userinfo_signed_response_alg = self
            .userinfo_signed_response_alg
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("userinfo_signed_response_alg")
                    .row(id)
                    .source(e)
            })?;
        let token_endpoint_auth_method = self
            .token_endpoint_auth_method
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("token_endpoint_auth_method")
                    .row(id)
                    .source(e)
            })?;
        let token_endpoint_auth_signing_alg = self
            .token_endpoint_auth_signing_alg
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("token_endpoint_auth_signing_alg")
                    .row(id)
                    .source(e)
            })?;
        let initiate_login_uri = self
            .initiate_login_uri
            .map(|s| s.parse())
            .transpose()
            .map_err(|e| {
                DatabaseInconsistencyError::on("oauth2_clients")
                    .column("initiate_login_uri")
                    .row(id)
                    .source(e)
            })?;
        let jwks = match (self.jwks, self.jwks_uri) {
            (None, None) => None,
            (Some(jwks), None) => {
                let jwks = serde_json::from_value(jwks).map_err(|e| {
                    DatabaseInconsistencyError::on("oauth2_clients")
                        .column("jwks")
                        .row(id)
                        .source(e)
                })?;
                Some(JwksOrJwksUri::Jwks(jwks))
            }
            (None, Some(jwks_uri)) => {
                let jwks_uri = jwks_uri.parse().map_err(|e| {
                    DatabaseInconsistencyError::on("oauth2_clients")
                        .column("jwks_uri")
                        .row(id)
                        .source(e)
                })?;
                Some(JwksOrJwksUri::JwksUri(jwks_uri))
            }
            _ => {
                return Err(DatabaseInconsistencyError::on("oauth2_clients")
                    .column("jwks(_uri)")
                    .row(id))
            }
        };
        Ok(Client {
            id,
            client_id: id.to_string(),
            encrypted_client_secret: self.encrypted_client_secret,
            application_type,
            redirect_uris,
            grant_types,
            client_name: self.client_name,
            logo_uri,
            client_uri,
            policy_uri,
            tos_uri,
            jwks,
            id_token_signed_response_alg,
            userinfo_signed_response_alg,
            token_endpoint_auth_method,
            token_endpoint_auth_signing_alg,
            initiate_login_uri,
        })
    }
}
#[async_trait]
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
    type Error = DatabaseError;
    #[tracing::instrument(
        name = "db.oauth2_client.lookup",
        skip_all,
        fields(
            db.query.text,
            oauth2_client.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
        let res = sqlx::query_as!(
            OAuth2ClientLookup,
            r#"
                SELECT oauth2_client_id
                     , encrypted_client_secret
                     , application_type
                     , redirect_uris
                     , grant_type_authorization_code
                     , grant_type_refresh_token
                     , grant_type_client_credentials
                     , grant_type_device_code
                     , client_name
                     , logo_uri
                     , client_uri
                     , policy_uri
                     , tos_uri
                     , jwks_uri
                     , jwks
                     , id_token_signed_response_alg
                     , userinfo_signed_response_alg
                     , token_endpoint_auth_method
                     , token_endpoint_auth_signing_alg
                     , initiate_login_uri
                FROM oauth2_clients c
                WHERE oauth2_client_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;
        let Some(res) = res else { return Ok(None) };
        Ok(Some(res.try_into()?))
    }
    #[tracing::instrument(
        name = "db.oauth2_client.load_batch",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn load_batch(
        &mut self,
        ids: BTreeSet<Ulid>,
    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
        let res = sqlx::query_as!(
            OAuth2ClientLookup,
            r#"
                SELECT oauth2_client_id
                     , encrypted_client_secret
                     , application_type
                     , redirect_uris
                     , grant_type_authorization_code
                     , grant_type_refresh_token
                     , grant_type_client_credentials
                     , grant_type_device_code
                     , client_name
                     , logo_uri
                     , client_uri
                     , policy_uri
                     , tos_uri
                     , jwks_uri
                     , jwks
                     , id_token_signed_response_alg
                     , userinfo_signed_response_alg
                     , token_endpoint_auth_method
                     , token_endpoint_auth_signing_alg
                     , initiate_login_uri
                FROM oauth2_clients c
                WHERE oauth2_client_id = ANY($1::uuid[])
            "#,
            &ids,
        )
        .traced()
        .fetch_all(&mut *self.conn)
        .await?;
        res.into_iter()
            .map(|r| {
                r.try_into()
                    .map(|c: Client| (c.id, c))
                    .map_err(DatabaseError::from)
            })
            .collect()
    }
    #[tracing::instrument(
        name = "db.oauth2_client.add",
        skip_all,
        fields(
            db.query.text,
            client.id,
            client.name = client_name
        ),
        err,
    )]
    #[allow(clippy::too_many_lines)]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        redirect_uris: Vec<Url>,
        encrypted_client_secret: Option<String>,
        application_type: Option<ApplicationType>,
        grant_types: Vec<GrantType>,
        client_name: Option<String>,
        logo_uri: Option<Url>,
        client_uri: Option<Url>,
        policy_uri: Option<Url>,
        tos_uri: Option<Url>,
        jwks_uri: Option<Url>,
        jwks: Option<PublicJsonWebKeySet>,
        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
        initiate_login_uri: Option<Url>,
    ) -> Result<Client, Self::Error> {
        let now = clock.now();
        let id = Ulid::from_datetime_with_source(now.into(), rng);
        tracing::Span::current().record("client.id", tracing::field::display(id));
        let jwks_json = jwks
            .as_ref()
            .map(serde_json::to_value)
            .transpose()
            .map_err(DatabaseError::to_invalid_operation)?;
        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
        sqlx::query!(
            r#"
                INSERT INTO oauth2_clients
                    ( oauth2_client_id
                    , encrypted_client_secret
                    , application_type
                    , redirect_uris
                    , grant_type_authorization_code
                    , grant_type_refresh_token
                    , grant_type_client_credentials
                    , grant_type_device_code
                    , client_name
                    , logo_uri
                    , client_uri
                    , policy_uri
                    , tos_uri
                    , jwks_uri
                    , jwks
                    , id_token_signed_response_alg
                    , userinfo_signed_response_alg
                    , token_endpoint_auth_method
                    , token_endpoint_auth_signing_alg
                    , initiate_login_uri
                    , is_static
                    )
                VALUES
                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, FALSE)
            "#,
            Uuid::from(id),
            encrypted_client_secret,
            application_type.as_ref().map(ToString::to_string),
            &redirect_uris_array,
            grant_types.contains(&GrantType::AuthorizationCode),
            grant_types.contains(&GrantType::RefreshToken),
            grant_types.contains(&GrantType::ClientCredentials),
            grant_types.contains(&GrantType::DeviceCode),
            client_name,
            logo_uri.as_ref().map(Url::as_str),
            client_uri.as_ref().map(Url::as_str),
            policy_uri.as_ref().map(Url::as_str),
            tos_uri.as_ref().map(Url::as_str),
            jwks_uri.as_ref().map(Url::as_str),
            jwks_json,
            id_token_signed_response_alg
                .as_ref()
                .map(ToString::to_string),
            userinfo_signed_response_alg
                .as_ref()
                .map(ToString::to_string),
            token_endpoint_auth_method.as_ref().map(ToString::to_string),
            token_endpoint_auth_signing_alg
                .as_ref()
                .map(ToString::to_string),
            initiate_login_uri.as_ref().map(Url::as_str),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        let jwks = match (jwks, jwks_uri) {
            (None, None) => None,
            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
            _ => return Err(DatabaseError::invalid_operation()),
        };
        Ok(Client {
            id,
            client_id: id.to_string(),
            encrypted_client_secret,
            application_type,
            redirect_uris,
            grant_types,
            client_name,
            logo_uri,
            client_uri,
            policy_uri,
            tos_uri,
            jwks,
            id_token_signed_response_alg,
            userinfo_signed_response_alg,
            token_endpoint_auth_method,
            token_endpoint_auth_signing_alg,
            initiate_login_uri,
        })
    }
    #[tracing::instrument(
        name = "db.oauth2_client.upsert_static",
        skip_all,
        fields(
            db.query.text,
            client.id = %client_id,
        ),
        err,
    )]
    async fn upsert_static(
        &mut self,
        client_id: Ulid,
        client_auth_method: OAuthClientAuthenticationMethod,
        encrypted_client_secret: Option<String>,
        jwks: Option<PublicJsonWebKeySet>,
        jwks_uri: Option<Url>,
        redirect_uris: Vec<Url>,
    ) -> Result<Client, Self::Error> {
        let jwks_json = jwks
            .as_ref()
            .map(serde_json::to_value)
            .transpose()
            .map_err(DatabaseError::to_invalid_operation)?;
        let client_auth_method = client_auth_method.to_string();
        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
        sqlx::query!(
            r#"
                INSERT INTO oauth2_clients
                    ( oauth2_client_id
                    , encrypted_client_secret
                    , redirect_uris
                    , grant_type_authorization_code
                    , grant_type_refresh_token
                    , grant_type_client_credentials
                    , grant_type_device_code
                    , token_endpoint_auth_method
                    , jwks
                    , jwks_uri
                    , is_static
                    )
                VALUES
                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
                ON CONFLICT (oauth2_client_id)
                DO
                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
                             , redirect_uris = EXCLUDED.redirect_uris
                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
                             , grant_type_device_code = EXCLUDED.grant_type_device_code
                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
                             , jwks = EXCLUDED.jwks
                             , jwks_uri = EXCLUDED.jwks_uri
                             , is_static = TRUE
            "#,
            Uuid::from(client_id),
            encrypted_client_secret,
            &redirect_uris_array,
            true,
            true,
            true,
            true,
            client_auth_method,
            jwks_json,
            jwks_uri.as_ref().map(Url::as_str),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        let jwks = match (jwks, jwks_uri) {
            (None, None) => None,
            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
            _ => return Err(DatabaseError::invalid_operation()),
        };
        Ok(Client {
            id: client_id,
            client_id: client_id.to_string(),
            encrypted_client_secret,
            application_type: None,
            redirect_uris,
            grant_types: vec![
                GrantType::AuthorizationCode,
                GrantType::RefreshToken,
                GrantType::ClientCredentials,
            ],
            client_name: None,
            logo_uri: None,
            client_uri: None,
            policy_uri: None,
            tos_uri: None,
            jwks,
            id_token_signed_response_alg: None,
            userinfo_signed_response_alg: None,
            token_endpoint_auth_method: None,
            token_endpoint_auth_signing_alg: None,
            initiate_login_uri: None,
        })
    }
    #[tracing::instrument(
        name = "db.oauth2_client.all_static",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
        let res = sqlx::query_as!(
            OAuth2ClientLookup,
            r#"
                SELECT oauth2_client_id
                     , encrypted_client_secret
                     , application_type
                     , redirect_uris
                     , grant_type_authorization_code
                     , grant_type_refresh_token
                     , grant_type_client_credentials
                     , grant_type_device_code
                     , client_name
                     , logo_uri
                     , client_uri
                     , policy_uri
                     , tos_uri
                     , jwks_uri
                     , jwks
                     , id_token_signed_response_alg
                     , userinfo_signed_response_alg
                     , token_endpoint_auth_method
                     , token_endpoint_auth_signing_alg
                     , initiate_login_uri
                FROM oauth2_clients c
                WHERE is_static = TRUE
            "#,
        )
        .traced()
        .fetch_all(&mut *self.conn)
        .await?;
        res.into_iter()
            .map(|r| r.try_into().map_err(DatabaseError::from))
            .collect()
    }
    #[tracing::instrument(
        name = "db.oauth2_client.get_consent_for_user",
        skip_all,
        fields(
            db.query.text,
            %user.id,
            %client.id,
        ),
        err,
    )]
    async fn get_consent_for_user(
        &mut self,
        client: &Client,
        user: &User,
    ) -> Result<Scope, Self::Error> {
        let scope_tokens: Vec<String> = sqlx::query_scalar!(
            r#"
                SELECT scope_token
                FROM oauth2_consents
                WHERE user_id = $1 AND oauth2_client_id = $2
            "#,
            Uuid::from(user.id),
            Uuid::from(client.id),
        )
        .fetch_all(&mut *self.conn)
        .await?;
        let scope: Result<Scope, _> = scope_tokens
            .into_iter()
            .map(|s| ScopeToken::from_str(&s))
            .collect();
        let scope = scope.map_err(|e| {
            DatabaseInconsistencyError::on("oauth2_consents")
                .column("scope_token")
                .source(e)
        })?;
        Ok(scope)
    }
    #[tracing::instrument(
        name = "db.oauth2_client.give_consent_for_user",
        skip_all,
        fields(
            db.query.text,
            %user.id,
            %client.id,
            %scope,
        ),
        err,
    )]
    async fn give_consent_for_user(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: &User,
        scope: &Scope,
    ) -> Result<(), Self::Error> {
        let now = clock.now();
        let (tokens, ids): (Vec<String>, Vec<Uuid>) = scope
            .iter()
            .map(|token| {
                (
                    token.to_string(),
                    Uuid::from(Ulid::from_datetime_with_source(now.into(), rng)),
                )
            })
            .unzip();
        sqlx::query!(
            r#"
                INSERT INTO oauth2_consents
                    (oauth2_consent_id, user_id, oauth2_client_id, scope_token, created_at)
                SELECT id, $2, $3, scope_token, $5 FROM UNNEST($1::uuid[], $4::text[]) u(id, scope_token)
                ON CONFLICT (user_id, oauth2_client_id, scope_token) DO UPDATE SET refreshed_at = $5
            "#,
            &ids,
            Uuid::from(user.id),
            Uuid::from(client.id),
            &tokens,
            now,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        Ok(())
    }
    #[tracing::instrument(
        name = "db.oauth2_client.delete_by_id",
        skip_all,
        fields(
            db.query.text,
            client.id = %id,
        ),
        err,
    )]
    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.authorization_grants",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );
            sqlx::query!(
                r#"
                    DELETE FROM oauth2_authorization_grants
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.consents",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );
            sqlx::query!(
                r#"
                    DELETE FROM oauth2_consents
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.access_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );
            sqlx::query!(
                r#"
                    DELETE FROM oauth2_access_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.refresh_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );
            sqlx::query!(
                r#"
                    DELETE FROM oauth2_refresh_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.sessions",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );
            sqlx::query!(
                r#"
                    DELETE FROM oauth2_sessions
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        let res = sqlx::query!(
            r#"
                DELETE FROM oauth2_clients
                WHERE oauth2_client_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;
        DatabaseError::ensure_affected_rows(&res, 1)
    }
}