Source code for datamasque.client.models.connection

"""Connection configuration models for the DataMasque API."""

from enum import Enum
from typing import Any, Callable, Literal, NewType, Optional

from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator

from datamasque.client.exceptions import DataMasqueException
from datamasque.client.models.files import FileId

# Distinct static type for connection identifiers; at runtime a `ConnectionId` is a plain `str`.
ConnectionId = NewType("ConnectionId", str)


def unwrap_connection_id(value: Any) -> Any:
    """
    Reduce a `ConnectionConfig` to its `id`, leaving any other value untouched.

    Request-model validators call this so that callers may supply either a
    `ConnectionId` or a complete `ConnectionConfig`.

    Raises `ValueError` when the config's `id` is `None`
    (i.e. the connection has not been created on the server yet).
    """
    # Anything that is not a config object is passed straight through.
    if not isinstance(value, ConnectionConfig):
        return value

    connection_id = value.id
    if connection_id is None:
        raise ValueError("Connection has not been created yet (id is None)")
    return connection_id
class DatabaseType(Enum):
    """
    Database engines that a `DatabaseConnectionConfig` can target.

    Each member's value is the string used for the `db_type` key of a connection payload.
    Note that a few member names differ from their values (`sql_server`, `dynamodb`).
    """

    postgres = "postgres"
    mysql = "mysql"
    oracle = "oracle"
    mariadb = "mariadb"
    sql_server = "mssql"  # member name differs from the value
    redshift = "redshift"
    dynamodb = "dynamo_db"  # member name differs from the value
    db2_luw = "db2_luw"
    db2i = "db2i"
    mssql_linked = "mssql_linked"
    snowflake = "snowflake"
    mongodb = "mongodb"
    databricks_lakebase = "databricks_lakebase"
    databricks = "databricks"
class SnowflakeStageLocation(str, Enum):
    """
    Where the external stage of a Snowflake connection is stored.

    Members are also `str` instances, so they compare equal to their plain string values.
    """

    local = "local"  # Not supported for production use
    aws_s3 = "aws_s3"
    azure_blob_storage = "azure_blob_storage"
class SseSelection(Enum):
    """
    Server-side encryption choices for DynamoDB.

    Mirrors the available options in the AWS console for DynamoDB Server-Side Encryption.
    """

    dynamodb_owned = "dynamodb_owned"
    aws_managed = "aws_managed"
    account_managed = "account_managed"
    use_source = "use_source"
class SseConfig(BaseModel):
    """
    Server-side encryption settings for a DynamoDB connection.

    A `kms_key_id` must be supplied if and only if `selection` is
    `SseSelection.account_managed`; every other selection requires it to be `None`.
    Unknown fields are rejected (`extra="forbid"`).
    """

    model_config = ConfigDict(extra="forbid")

    selection: SseSelection
    # Required when `selection` is `account_managed`; must be None otherwise
    kms_key_id: Optional[str] = None

    @model_validator(mode="after")
    def _validate_kms_key(self) -> "SseConfig":
        """Enforce that `kms_key_id` is present exactly when the key is account-managed."""
        account_managed = self.selection is SseSelection.account_managed
        has_key = self.kms_key_id is not None

        if account_managed and not has_key:
            raise ValueError(
                "A KMS key ID must be specified when the SSE key is stored in your account, and owned "
                "and managed by you."
            )
        if has_key and not account_managed:
            raise ValueError(
                "A KMS key ID can only be specified when the SSE key is stored in your account, and "
                "owned and managed by you."
            )
        return self
class ConnectionConfig(BaseModel):
    """
    Common base for every connection configuration model.

    To turn an API response into the matching concrete subclass,
    call `validate_connection(payload)`.
    """

    # Unknown keys are kept rather than rejected, and fields may be populated
    # by their Python name as well as by alias.
    model_config = ConfigDict(populate_by_name=True, extra="allow")

    name: str
    id: Optional[ConnectionId] = None
class DynamoConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a DynamoDB table.

    `mask_type` and `db_type` are fixed discriminators (`"database"` / `"dynamo_db"`),
    and `database_type` always reports `DatabaseType.dynamodb`.
    """

    s3_bucket_name: Optional[str] = None
    dynamo_append_datetime: bool = False
    dynamo_append_suffix: str = "-MASKED"
    dynamo_replace_tables: bool = True
    dynamo_default_region: Optional[str] = None
    dynamo_default_sse: SseConfig = SseConfig(selection=SseSelection.dynamodb_owned, kms_key_id=None)
    iam_role_arn: Optional[str] = None
    export_s3_prefix: Optional[str] = None

    mask_type: Literal["database"] = "database"
    db_type: Literal["dynamo_db"] = "dynamo_db"

    @property
    def database_type(self) -> DatabaseType:
        """The engine of this connection; always `DatabaseType.dynamodb`."""
        return DatabaseType.dynamodb

    @model_serializer(mode="wrap")
    def _serialize(self, handler: Callable) -> dict:
        """Serialize normally, then add the placeholder keys the server expects."""
        d = handler(self)
        # The admin server requires these placeholder fields for Dynamo connections.
        # `setdefault` leaves any value that is already present untouched.
        d.setdefault("host", "")
        d.setdefault("port", None)
        d.setdefault("user", "")
        d.setdefault("password", "")
        d.setdefault("database", "")
        d.setdefault("schema", "")
        return d

    @model_validator(mode="before")
    @classmethod
    def _strip_server_only_fields(cls, data: Any) -> Any:
        """
        Drop fields that come back from the server but aren't part of this model.

        Non-dict input is returned unchanged. For dict input a filtered copy is
        returned, so the mapping passed to `model_validate` is never mutated.
        """
        if not isinstance(data, dict):
            return data
        server_only = ("password_encrypted", "dbpassword")
        return {key: value for key, value in data.items() if key not in server_only}
class MongoConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a MongoDB instance.

    `mask_type` and `db_type` are fixed discriminators (`"database"` / `"mongodb"`),
    and `database_type` always reports `DatabaseType.mongodb`.
    """

    host: str = ""
    port: int = 27017
    database: str = ""
    user: str = ""
    password: Optional[str] = None
    auth_source: str = "admin"
    tls: bool = False
    direct_connection: bool = False
    replica_set: str = ""
    is_read_only: bool = False

    mask_type: Literal["database"] = "database"
    db_type: Literal["mongodb"] = "mongodb"

    @property
    def database_type(self) -> DatabaseType:
        """The engine of this connection; always `DatabaseType.mongodb`."""
        return DatabaseType.mongodb

    @model_serializer(mode="wrap")
    def _serialize(self, handler: Callable) -> dict:
        """Serialize normally, then rename the password key and omit falsy optional settings."""
        d = handler(self)

        # The server expects the password under the `dbpassword` key.
        # A missing or empty password is omitted entirely.
        password = d.pop("password", None)
        if password:
            d["dbpassword"] = password

        # These settings are only sent when they carry a truthy value.
        for key in ("tls", "direct_connection", "replica_set", "user"):
            if not d.get(key):
                d.pop(key, None)
        return d

    @model_validator(mode="before")
    @classmethod
    def _strip_encrypted_password(cls, data: Any) -> Any:
        """
        Drop the server-side password keys, which are not fields of this model.

        Non-dict input is returned unchanged. For dict input a filtered copy is
        returned, so the mapping passed to `model_validate` is never mutated.
        """
        if not isinstance(data, dict):
            return data
        server_only = ("password_encrypted", "dbpassword")
        return {key: value for key, value in data.items() if key not in server_only}
class SnowflakeConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a Snowflake database.

    Supports password authentication (`password`) and key-pair authentication
    (`snowflake_private_key` + optional `snowflake_private_key_passphrase`).

    `mask_type` and `db_type` are fixed discriminators (`"database"` / `"snowflake"`),
    and `database_type` always reports `DatabaseType.snowflake`.
    """

    database: str
    user: str
    snowflake_account_id: str
    snowflake_warehouse: str
    snowflake_storage_integration_name: str
    host: str = ""
    port: Optional[int] = None
    # Exposed as `db_schema` in Python; the payload key is `schema`.
    db_schema: Optional[str] = Field(default=None, alias="schema")
    snowflake_role: str = ""
    is_read_only: bool = False
    password: Optional[str] = None
    snowflake_private_key: Optional[FileId] = None
    snowflake_private_key_passphrase: Optional[str] = None
    snowflake_stage_location: Optional[SnowflakeStageLocation] = None
    s3_bucket_name: Optional[str] = None
    iam_role_arn: Optional[str] = None
    snowflake_azure_container_name: Optional[str] = None
    snowflake_azure_connection_string: Optional[str] = None
    snowflake_azure_connection_string_encrypted: Optional[str] = None

    mask_type: Literal["database"] = "database"
    db_type: Literal["snowflake"] = "snowflake"

    @property
    def database_type(self) -> DatabaseType:
        """The engine of this connection; always `DatabaseType.snowflake`."""
        return DatabaseType.snowflake

    @model_serializer(mode="wrap")
    def _serialize(self, handler: Callable) -> dict:
        """Serialize normally, then rename the password key and guarantee a `schema` entry."""
        d = handler(self)

        # The server expects the password under the `dbpassword` key.
        password = d.pop("password", None)
        if password is not None:
            d["dbpassword"] = password

        # Snowflake requires `schema` even when the user hasn't set one.
        if d.get("schema") is None:
            d["schema"] = ""
        return d

    @model_validator(mode="before")
    @classmethod
    def _strip_encrypted_password(cls, data: Any) -> Any:
        """
        Drop the server-side password keys, which are not fields of this model.

        Non-dict input is returned unchanged. For dict input a filtered copy is
        returned, so the mapping passed to `model_validate` is never mutated.
        """
        if not isinstance(data, dict):
            return data
        server_only = ("password_encrypted", "dbpassword")
        return {key: value for key, value in data.items() if key not in server_only}
class DatabaseConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a SQL database.

    Use `DynamoConnectionConfig` for DynamoDB, `SnowflakeConnectionConfig` for Snowflake,
    and `MongoConnectionConfig` for MongoDB.

    Incoming payloads may identify the engine with either `database_type` or `db_type`;
    serialized output always uses `db_type`.
    """

    host: str
    port: int
    database: str
    user: str
    password: Optional[str] = None
    database_type: DatabaseType
    engine_options: Optional[dict] = None
    # Exposed as `db_schema` in Python; the payload key is `schema`.
    db_schema: Optional[str] = Field(default=None, alias="schema")
    data_encoding: Optional[str] = None
    is_read_only: bool = False
    s3_bucket_name: Optional[str] = None
    s3_redshift_iam_role: Optional[str] = None

    mask_type: Literal["database"] = "database"

    @model_validator(mode="after")
    def _reject_special_engines(self) -> "DatabaseConnectionConfig":
        """Reject engines that have their own dedicated configuration class."""
        if self.database_type is DatabaseType.dynamodb:
            raise ValueError("For DynamoDB, use the DynamoConnectionConfig class instead")
        if self.database_type is DatabaseType.snowflake:
            raise ValueError("For Snowflake, use the SnowflakeConnectionConfig class instead")
        if self.database_type is DatabaseType.mongodb:
            raise ValueError("For MongoDB, use the MongoConnectionConfig class instead")
        if self.database_type is DatabaseType.databricks:
            raise ValueError("For Databricks SQL Warehouse, use the DatabricksConnectionConfig class instead")
        return self

    @property
    def db_type(self) -> str:
        """The string value of `database_type`, as used for the `db_type` payload key."""
        return self.database_type.value

    @model_serializer(mode="wrap")
    def _serialize(self, handler: Callable) -> dict:
        """Serialize normally, then reshape the result into what the server expects."""
        d = handler(self)

        # The server expects the password under the `dbpassword` key.
        password = d.pop("password", None)
        if password is not None:
            d["dbpassword"] = password

        # The engine is sent as `db_type`, never as `database_type`.
        d.pop("database_type", None)
        d["db_type"] = self.db_type

        # The server requires certain fields to be present or absent
        # depending on the engine type.
        engine = self.database_type
        if engine in {DatabaseType.mysql, DatabaseType.mariadb} or d.get("schema") is None:
            d["schema"] = ""
        if engine not in {DatabaseType.mysql, DatabaseType.mariadb, DatabaseType.oracle, DatabaseType.postgres}:
            d.pop("data_encoding", None)
        if engine is not DatabaseType.redshift:
            d.pop("s3_bucket_name", None)
            d.pop("s3_redshift_iam_role", None)
        if not d.get("engine_options"):
            d.pop("engine_options", None)
        return d

    @model_validator(mode="before")
    @classmethod
    def _normalize_incoming(cls, data: Any) -> Any:
        """
        Normalize a raw payload before field validation.

        Non-dict input is returned unchanged. For dict input all changes are made
        on a filtered copy, so the mapping passed to `model_validate` is never
        mutated; in particular its `db_type` key stays in place for callers that
        inspect the payload again afterwards.
        """
        if not isinstance(data, dict):
            return data

        # Copy while dropping the server-side password keys.
        server_only = ("password_encrypted", "dbpassword")
        data = {key: value for key, value in data.items() if key not in server_only}

        # Determine the engine type from whichever key is present.
        engine = data.get("database_type") or data.get("db_type", "")
        if isinstance(engine, DatabaseType):
            engine = engine.value

        # The API returns a `schema` value for engines that don't have schemas (MySQL/MariaDB).
        # Drop it so the model accurately reflects "not applicable".
        if engine in {DatabaseType.mysql.value, DatabaseType.mariadb.value}:
            data.pop("schema", None)

        # Map `db_type` → `database_type` for incoming payloads.
        if "db_type" in data and "database_type" not in data:
            data["database_type"] = data.pop("db_type")
        return data
class MssqlLinkedServerConnectionConfig(DatabaseConnectionConfig):
    """
    Connection configuration for a Microsoft SQL Server linked-server setup.

    Extends `DatabaseConnectionConfig` with a single extra field, `linked_server`.
    """

    linked_server: str = ""
class FileConnectionConfig(ConnectionConfig):
    """
    Abstract base for file-based connections.

    The flags `is_file_mask_source` and `is_file_mask_destination` decide whether the
    connection may act as the source of a masking run, as its destination, or as both.
    `mask_type` is the fixed discriminator `"file"`.
    """

    base_directory: str = ""
    is_file_mask_source: bool = False
    is_file_mask_destination: bool = False

    mask_type: Literal["file"] = "file"
class S3ConnectionConfig(FileConnectionConfig):
    """
    Connection configuration for an S3 bucket.

    `type` is the fixed discriminator `"s3_connection"`.
    """

    type: Literal["s3_connection"] = "s3_connection"

    bucket: str = ""
    iam_role_arn: Optional[str] = None
class AzureConnectionConfig(FileConnectionConfig):
    """
    Connection configuration for an Azure Blob Storage container.

    `connection_string` comes back encrypted from `list_connections` and is write-only in practice.
    `type` is the fixed discriminator `"azure_blob_connection"`.
    """

    type: Literal["azure_blob_connection"] = "azure_blob_connection"

    container: str = ""
    connection_string: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def _strip_encrypted_connection_string(cls, data: Any) -> Any:
        """
        Drop the encrypted connection string returned by the API.

        Non-dict input is returned unchanged. For dict input a filtered copy is
        returned, so the mapping passed to `model_validate` is never mutated.
        """
        if not isinstance(data, dict):
            return data
        # The API returns the encrypted form; drop it so `connection_string` stays None.
        return {key: value for key, value in data.items() if key != "connection_string_encrypted"}
class MountedShareConnectionConfig(FileConnectionConfig):
    """
    Connection configuration for a mounted file share.

    `type` is the fixed discriminator `"mounted_share_connection"`.
    """

    type: Literal["mounted_share_connection"] = "mounted_share_connection"
class DatabricksConnectionConfig(ConnectionConfig):
    """
    Connection configuration for a Databricks SQL Warehouse.

    `mask_type` and `db_type` are fixed discriminators (`"database"` / `"databricks"`),
    and `database_type` always reports `DatabaseType.databricks`.
    """

    server_hostname: str
    http_path: str
    access_token: Optional[str] = None
    catalog: str
    # Exposed as `db_schema` in Python; the payload key is `schema`.
    db_schema: Optional[str] = Field(default=None, alias="schema")
    is_read_only: bool = False
    version: str = "1.0"

    mask_type: Literal["database"] = "database"
    db_type: Literal["databricks"] = "databricks"

    @property
    def database_type(self) -> DatabaseType:
        """The engine of this connection; always `DatabaseType.databricks`."""
        return DatabaseType.databricks

    @model_validator(mode="before")
    @classmethod
    def _strip_encrypted_token(cls, data: Any) -> Any:
        """
        Drop the `access_token_encrypted` key, which is not a field of this model.

        Non-dict input is returned unchanged. For dict input a filtered copy is
        returned, so the mapping passed to `model_validate` is never mutated.
        """
        if not isinstance(data, dict):
            return data
        return {key: value for key, value in data.items() if key != "access_token_encrypted"}
# Maps the `type` value of a file-connection payload to its model class.
FILE_TYPE_MAP: dict[str, type[FileConnectionConfig]] = {
    "s3_connection": S3ConnectionConfig,
    "azure_blob_connection": AzureConnectionConfig,
    "mounted_share_connection": MountedShareConnectionConfig,
}

# Maps the `db_type` value of a database-connection payload to its dedicated model class.
# Engines not listed here use the default `DatabaseConnectionConfig`.
DB_TYPE_MAP: dict[str, type[ConnectionConfig]] = {
    DatabaseType.dynamodb.value: DynamoConnectionConfig,
    DatabaseType.mongodb.value: MongoConnectionConfig,
    DatabaseType.snowflake.value: SnowflakeConnectionConfig,
    DatabaseType.mssql_linked.value: MssqlLinkedServerConnectionConfig,
    DatabaseType.databricks.value: DatabricksConnectionConfig,
}
def validate_connection(payload: dict) -> ConnectionConfig:
    """
    Turn an API response payload into the matching concrete `ConnectionConfig` subclass.

    The payload's `mask_type` selects the family; within it, `db_type` (database) or
    `type` (file) selects the class. Database engines without a dedicated class
    fall back to `DatabaseConnectionConfig`.

    Raises `DataMasqueException` for an unknown `mask_type` or an unknown file `type`.
    """
    family = payload.get("mask_type")

    if family == "database":
        engine_key = payload.get("db_type", "")
        database_model = DB_TYPE_MAP.get(engine_key, DatabaseConnectionConfig)
        return database_model.model_validate(payload)

    if family != "file":
        raise DataMasqueException(f"Unexpected connection mask_type: {family}")

    file_kind = payload.get("type", "")
    file_model = FILE_TYPE_MAP.get(file_kind)
    if file_model is None:
        raise DataMasqueException(f"Unexpected file connection type: {file_kind}")
    return file_model.model_validate(payload)