# Source code for simple_aws_redshift.connect

# -*- coding: utf-8 -*-

"""
Redshift connection parameters and utility functions.
"""

import typing as T
import dataclasses
from datetime import datetime

try:
    import redshift_connector
except ImportError:  # pragma: no cover
    pass
try:
    import sqlalchemy as sa
    from .dialect import RedshiftPostgresDialect, driver_name, dialect_name
except ImportError:  # pragma: no cover
    pass

from func_args.api import REQ, OPT, remove_optional, BaseModel

from .redshift.api import (
    RedshiftCluster,
    get_redshift_cluster,
)
from .redshift_serverless.api import (
    RedshiftServerlessNamespace,
    RedshiftServerlessWorkgroup,
    get_namespace,
    get_workgroup,
)

if T.TYPE_CHECKING:  # pragma: no cover
    from mypy_boto3_redshift.client import RedshiftClient
    from mypy_boto3_redshift_serverless.client import RedshiftServerlessClient


@dataclasses.dataclass
class BaseRedshiftConnectionParams(BaseModel):
    """
    Base class for Redshift connection parameters.

    Holds the endpoint and credentials needed to reach a Redshift database
    and offers helpers that turn them into either a native
    ``redshift_connector`` connection or a SQLAlchemy engine.
    """

    host: str = dataclasses.field(default=REQ)  # endpoint hostname
    port: int = dataclasses.field(default=REQ)  # endpoint port
    username: str = dataclasses.field(default=REQ)  # passed as ``user``
    password: str = dataclasses.field(default=REQ)
    database: str = dataclasses.field(default=REQ)  # database to connect to

    def get_connection(
        self,
        timeout: int = 3,
        *,
        is_serverless: bool = True,
    ) -> "redshift_connector.Connection":
        """
        Create a Redshift connection using the parameters.

        :param timeout: connection timeout, forwarded unchanged to
            ``redshift_connector.connect(timeout=...)``.
        :param is_serverless: forwarded to
            ``redshift_connector.connect(is_serverless=...)``. Defaults to
            ``True`` to keep this method's historical behavior; pass
            ``False`` when the endpoint is a provisioned cluster.

        :return: A redshift_connector.Connection object.
        """
        return redshift_connector.connect(
            host=self.host,
            port=self.port,
            user=self.username,
            password=self.password,
            database=self.database,
            is_serverless=is_serverless,
            timeout=timeout,
        )

    @property
    def sqlalchemy_db_url(self) -> "sa.URL":
        """
        SQLAlchemy URL built from these parameters.

        The driver name is ``"{dialect_name}+{driver_name}"`` so the URL
        resolves to this package's Redshift dialect. ``sa.URL.create``
        handles escaping of special characters in the password.
        """
        url = sa.URL.create(
            drivername=f"{dialect_name}+{driver_name}",
            username=self.username,
            password=self.password,
            host=self.host,
            port=self.port,
            database=self.database,
        )
        return url

    def get_engine(self, **kwargs) -> "sa.Engine":
        """
        Create a SQLAlchemy engine for :attr:`sqlalchemy_db_url`.

        :param kwargs: extra keyword arguments forwarded to
            ``sa.create_engine``.
        """
        return sa.create_engine(self.sqlalchemy_db_url, **kwargs)
@dataclasses.dataclass
class RedshiftClusterConnectionParams(BaseRedshiftConnectionParams):
    """
    Parameters for connecting to a provisioned Redshift cluster.

    Inherits from :class:`BaseRedshiftConnectionParams` and additionally keeps
    the temporary-credential timing information returned by
    ``get_cluster_credentials_with_iam`` plus the cluster object itself.
    """

    expiration: datetime = dataclasses.field(default=REQ)
    next_refresh_time: datetime = dataclasses.field(default=REQ)
    cluster: RedshiftCluster = dataclasses.field(default=REQ)

    def get_connection(
        self,
        timeout: int = 3,
        *,
        is_serverless: bool = False,
    ) -> "redshift_connector.Connection":
        """
        Create a Redshift connection to the provisioned cluster.

        A provisioned cluster is not a serverless endpoint, so this defaults
        ``is_serverless`` to ``False`` (``redshift_connector``'s own default)
        instead of the ``True`` used by the base class.

        :param timeout: connection timeout, forwarded unchanged to
            ``redshift_connector.connect(timeout=...)``.
        :param is_serverless: forwarded to
            ``redshift_connector.connect(is_serverless=...)``.

        :return: A redshift_connector.Connection object.
        """
        return redshift_connector.connect(
            host=self.host,
            port=self.port,
            user=self.username,
            password=self.password,
            database=self.database,
            is_serverless=is_serverless,
            timeout=timeout,
        )

    @classmethod
    def new(
        cls,
        redshift_client: "RedshiftClient",
        db_name: str = OPT,
        cluster_identifier: str = OPT,
        duration_seconds: int = OPT,
        custom_domain_name: str = OPT,
    ):
        """
        Create a new instance of :class:`RedshiftClusterConnectionParams`
        based on the Redshift cluster identifier.

        Looks up the cluster to get its endpoint, then requests temporary
        IAM-based database credentials for it.

        :param redshift_client: boto3.client("redshift") object
        :param db_name: The name of the database to connect to. If omitted,
            credentials are requested without ``DbName`` and the connection
            falls back to the cluster's own database name.
        :param cluster_identifier: The identifier of the Redshift cluster.
        :param duration_seconds: Optional duration in seconds for the credentials.
        :param custom_domain_name: Optional custom domain name for the connection.
        """
        cluster = get_redshift_cluster(
            redshift_client=redshift_client,
            cluster_identifier=cluster_identifier,
        )
        kwargs = dict(
            DbName=db_name,
            ClusterIdentifier=cluster_identifier,
            DurationSeconds=duration_seconds,
            CustomDomainName=custom_domain_name,
        )
        # remove_optional drops arguments left at the OPT sentinel so boto3
        # only receives the parameters the caller actually supplied.
        response = redshift_client.get_cluster_credentials_with_iam(
            **remove_optional(**kwargs)
        )
        if db_name is OPT:
            # Previously the OPT sentinel itself leaked into ``database``.
            # NOTE(review): assumes RedshiftCluster exposes the cluster's
            # database name as ``db_name`` — confirm; otherwise None is used.
            database = getattr(cluster, "db_name", None)
        else:
            database = db_name
        return cls(
            host=cluster.endpoint_address,
            port=cluster.endpoint_port,
            username=response["DbUser"],
            password=response["DbPassword"],
            database=database,
            cluster=cluster,
            expiration=response["Expiration"],
            next_refresh_time=response["NextRefreshTime"],
        )
@dataclasses.dataclass
class RedshiftServerlessConnectionParams(BaseRedshiftConnectionParams):
    """
    Parameters for connecting to a Redshift Serverless workgroup.

    Extends :class:`BaseRedshiftConnectionParams` with the timing
    information of the temporary credentials returned by
    ``get_credentials``, plus the namespace and workgroup objects the
    connection was derived from.
    """

    expiration: datetime = dataclasses.field(default=REQ)
    next_refresh_time: datetime = dataclasses.field(default=REQ)
    namespace: RedshiftServerlessNamespace = dataclasses.field(default=REQ)
    workgroup: RedshiftServerlessWorkgroup = dataclasses.field(default=REQ)

    @classmethod
    def new(
        cls,
        redshift_serverless_client: "RedshiftServerlessClient",
        namespace_name: str,
        workgroup_name: str,
        custom_domain_name: str = OPT,
        duration_seconds: int = OPT,
    ):
        """
        Build a :class:`RedshiftServerlessConnectionParams` for the given
        Redshift Serverless namespace and workgroup.

        The namespace supplies the database name, the workgroup supplies the
        endpoint, and ``get_credentials`` supplies a temporary user/password.

        :param redshift_serverless_client: boto3.client("redshift-serverless") object
        :param namespace_name: The name of the Redshift serverless namespace.
        :param workgroup_name: The name of the Redshift serverless workgroup.
        :param custom_domain_name: Optional custom domain name for the connection.
        :param duration_seconds: Optional duration in seconds for the credentials.
        """
        ns = get_namespace(
            redshift_serverless_client=redshift_serverless_client,
            namespace_name=namespace_name,
        )
        wg = get_workgroup(
            redshift_serverless_client=redshift_serverless_client,
            workgroup_name=workgroup_name,
        )
        # Arguments still at the OPT sentinel are stripped before the API call.
        request = remove_optional(
            dbName=ns.db_name,
            workgroupName=workgroup_name,
            customDomainName=custom_domain_name,
            durationSeconds=duration_seconds,
        )
        creds = redshift_serverless_client.get_credentials(**request)
        return cls(
            host=wg.endpoint_address,
            port=wg.endpoint_port,
            username=creds["dbUser"],
            password=creds["dbPassword"],
            database=ns.db_name,
            expiration=creds["expiration"],
            next_refresh_time=creds["nextRefreshTime"],
            namespace=ns,
            workgroup=wg,
        )