# Source code for datamasque.client.discovery

import logging
import zipfile
from io import BufferedIOBase, BytesIO, TextIOBase
from pathlib import Path
from typing import Iterator, Optional, Union

from datamasque.client.base import BaseClient, UploadFile
from datamasque.client.exceptions import (
    AsyncRulesetGenerationInProgressError,
    DataMasqueException,
    FailedToStartError,
)
from datamasque.client.models.connection import ConnectionId
from datamasque.client.models.data_selection import (
    SelectedColumns,
    SelectedData,
    SelectedFileData,
)
from datamasque.client.models.discovery import (
    FileDiscoveryResult,
    FileRulesetGenerationRequest,
    RulesetGenerationRequest,
    SchemaDiscoveryPage,
    SchemaDiscoveryRequest,
    SchemaDiscoveryResult,
)
from datamasque.client.models.ruleset import Ruleset
from datamasque.client.models.runs import RunId
from datamasque.client.models.status import AsyncRulesetGenerationTaskStatus

logger = logging.getLogger(__name__)


class DiscoveryClient(BaseClient):
    """Schema-discovery and ruleset-generation API methods. Mixed into `DataMasqueClient`."""

    def start_async_ruleset_generation(self, connection_id: ConnectionId, selected_data: SelectedData) -> None:
        """
        Starts async ruleset generation using the most recent discovery results on the given connection.

        If the connection is a database connection, `selected_data` should be of type `SelectedColumns`.
        If the connection is a file connection, `selected_data` should be of type `SelectedFileData`.

        Generation runs asynchronously on the server.
        Poll `get_async_ruleset_generation_task_status` until it returns `AsyncRulesetGenerationTaskStatus.finished`,
        then call `get_generated_rulesets` to retrieve the resulting `Ruleset`.

        Raises:
            ValueError: If `selected_data` is falsy, or if any `UserSelection` in a `SelectedFileData`
                lacks a non-empty `locators` or `files` list.
            TypeError: If `selected_data` is neither `SelectedColumns` nor `SelectedFileData`.
        """
        if not selected_data:
            raise ValueError("`selected_data` is a required argument to `start_async_ruleset_generation`.")

        data: dict = {}
        if isinstance(selected_data, SelectedColumns):
            data["selected_columns"] = selected_data.columns
            if selected_data.hash_columns is not None:
                # Nested mapping schema -> table -> config; each config is dumped without unset (None) fields.
                data["hash_columns"] = {
                    schema: {table: cfg.model_dump(exclude_none=True) for table, cfg in tables.items()}
                    for schema, tables in selected_data.hash_columns.items()
                }
        elif isinstance(selected_data, SelectedFileData):
            # Validate every selection up front so nothing is sent if any entry is incomplete.
            for user_selection in selected_data.user_selections:
                if not (user_selection.locators and user_selection.files):
                    raise ValueError(
                        "Each `UserSelection` in `SelectedFileData.user_selections` "
                        "must have a non-null list of `locators` and `files` to be selected for."
                    )
            data["selected_data"] = [s.model_dump() for s in selected_data.user_selections]
        else:
            raise TypeError(
                f"The argument `selected_data` to `start_async_ruleset_generation` was of an invalid type, "
                f"expected `SelectedColumns` or `SelectedFileData`, got {type(selected_data)}."
            )

        self.make_request(method="POST", path=f"/api/async-generate-ruleset/{connection_id}/", data=data)

    def start_async_ruleset_generation_from_csv(
        self,
        connection_id: ConnectionId,
        csv_content: Union[str, bytes, TextIOBase, BufferedIOBase],
        target_size_bytes: Optional[int] = None,
    ) -> None:
        """
        Generate ruleset(s) from the schema discovery CSV file obtained from `get_db_discovery_result_report()`.

        `target_size_bytes` is an optional integer specifying the approximate size in bytes of each
        generated ruleset. It is only sent to the server when not `None`.

        `csv_content` can be:
        - A string (e.g. from `get_db_discovery_result_report()`)
        - Bytes
        - A text file handle (e.g. `open(path)`)
        - A binary file handle (e.g. `open(path, 'rb')`)

        Text input (strings and text handles) is encoded as UTF-8 before upload.

        Generation runs asynchronously on the server.
        Poll `get_async_ruleset_generation_task_status` until it returns `AsyncRulesetGenerationTaskStatus.finished`,
        then call `get_generated_rulesets` to retrieve the resulting `Ruleset` objects.
        """
        # Normalise every accepted input into a binary file-like object for the multipart upload.
        content: BufferedIOBase
        if isinstance(csv_content, str):
            content = BytesIO(csv_content.encode())
        elif isinstance(csv_content, bytes):
            content = BytesIO(csv_content)
        elif isinstance(csv_content, TextIOBase):
            # Text handles are read fully and re-encoded, since the upload needs bytes.
            content = BytesIO(csv_content.read().encode())
        else:
            # Binary handles are passed through as-is.
            content = csv_content

        files = [
            UploadFile(
                field_name="csv_or_zip_file",
                filename="ruleset.csv",
                content=content,
                content_type="text/csv",
            ),
        ]
        self.make_request(
            method="POST",
            path=f"/api/async-generate-ruleset/{connection_id}/from-csv/",
            data={"target_size_bytes": target_size_bytes} if target_size_bytes is not None else None,
            files=files,
        )

    def get_async_ruleset_generation_task_status(self, connection_id: ConnectionId) -> AsyncRulesetGenerationTaskStatus:
        """
        Queries the status of an async ruleset generation task.

        Raises:
            DataMasqueException: If the server response contains no (or an empty) `status` value.
        """
        response = self.make_request(method="GET", path=f"/api/async-generate-ruleset/{connection_id}/")
        response_data = response.json()
        status = response_data.get("status")
        if not status:
            raise DataMasqueException("Attempted to get an async ruleset generation task status but none was given.")
        return AsyncRulesetGenerationTaskStatus(status)

    def get_generated_rulesets(self, connection_id: ConnectionId) -> list[Ruleset]:
        """
        Return the `Ruleset` objects produced by a previously-started async ruleset generation.

        Use for all three async-RG flows:
        - Database masking from a schema-discovery CSV (`start_async_ruleset_generation_from_csv`)
          - returns one or more rulesets
        - Database masking from a column selection (`start_async_ruleset_generation` with `SelectedColumns`)
          - returns a list containing one ruleset
        - File masking from a file/locator selection (`start_async_ruleset_generation` with `SelectedFileData`)
          - returns a list containing one ruleset

        Note that the ruleset(s) have autogenerated names, which you may want to customize before uploading.

        Raises:
            AsyncRulesetGenerationInProgressError: If the task hasn't finished yet.
            DataMasqueException: If the task failed, or if it reports `finished` but no ruleset
                could be retrieved (empty archive or missing inline ruleset).
        """
        status = self.get_async_ruleset_generation_task_status(connection_id)
        if status is AsyncRulesetGenerationTaskStatus.failed:
            logger.error("Ruleset generation failed for connection: %s", connection_id)
            raise DataMasqueException(f"Ruleset generation failed for connection: {connection_id}")
        if status is not AsyncRulesetGenerationTaskStatus.finished:
            logger.error(
                "Ruleset generation is still in progress for connection: %s. Status: `%s`",
                connection_id,
                status.value,
            )
            raise AsyncRulesetGenerationInProgressError(
                f"Ruleset generation in progress or not ready. Current status: `{status.value}`."
            )

        # The download-rulesets endpoint returns a ZIP attachment for the CSV flow,
        # or issues a 303 redirect back to the task-status endpoint for the column / file flows
        # (which carries the generated ruleset inline as `generated_ruleset`).
        # `requests` follows the 303 transparently, so we distinguish by the presence of
        # a `Content-Disposition: attachment` header, which Django's `FileResponse` sets on the ZIP response.
        response = self.make_request(
            method="GET",
            path=f"/api/async-generate-ruleset/{connection_id}/download-rulesets/",
        )

        if "attachment" in response.headers.get("Content-Disposition", "").lower():
            rulesets = []
            with zipfile.ZipFile(BytesIO(response.content)) as zip_file:
                for file_info in zip_file.infolist():
                    # Only YAML members are rulesets; match the extension case-insensitively
                    # and skip directory entries.
                    if file_info.is_dir() or not file_info.filename.lower().endswith((".yml", ".yaml")):
                        continue
                    with zip_file.open(file_info) as file:
                        yaml_content = file.read().decode("utf-8")
                    rulesets.append(Ruleset(name=Path(file_info.filename).stem, yaml=yaml_content))
            # Mirror the inline branch below: a finished task that yields nothing is an error,
            # not an empty success.
            if not rulesets:
                raise DataMasqueException(
                    f"Ruleset generation for connection {connection_id} reported `finished` "
                    f"but the downloaded archive contained no ruleset YAML files."
                )
            return rulesets

        generated = response.json().get("generated_ruleset")
        if not generated:
            raise DataMasqueException(
                f"Ruleset generation for connection {connection_id} reported `finished` "
                f"but no ruleset was returned on the task-status record."
            )
        return [Ruleset(name="generated_ruleset", yaml=generated)]

    def start_schema_discovery_run(self, discovery_config: SchemaDiscoveryRequest) -> RunId:
        """
        Starts a schema discovery run with the given configuration.

        Args:
            discovery_config: A `SchemaDiscoveryRequest` with connection ID and optional settings.

        Returns:
            RunId: The ID of the started discovery run

        Raises:
            FailedToStartError: If run fails to start (any response status other than 201)
        """
        data = discovery_config.model_dump(exclude_none=True, mode="json")
        # `require_status_check=False` so the status can be inspected here and mapped to `FailedToStartError`.
        response = self.make_request(
            "POST",
            "/api/schema-discovery/",
            data=data,
            require_status_check=False,
        )

        if response.status_code == 201:
            run_data = response.json()
            logger.info("Schema discovery run %s started successfully", run_data["id"])
            return RunId(run_data["id"])

        # Error bodies are not guaranteed to be JSON; decoding them unconditionally would
        # replace the documented `FailedToStartError` with a JSON decode error.
        try:
            error_detail = response.json()
        except ValueError:
            error_detail = response.text
        logger.error("Schema discovery run failed to start: %s", error_detail)
        raise FailedToStartError(
            f"Schema discovery run failed to start "
            f"(server responded with status {response.status_code}: {response.text}).",
            response=response,
        )

    def iter_schema_discovery_results(self, run_id: RunId) -> Iterator[SchemaDiscoveryResult]:
        """Lazily iterate all schema discovery results for a run via the paginated v2 endpoint."""
        return self._iter_paginated(
            f"/api/schema-discovery/v2/{run_id}/",
            model=SchemaDiscoveryResult,
        )

    def list_schema_discovery_results(self, run_id: RunId) -> list[SchemaDiscoveryResult]:
        """Returns all schema discovery results for a run (materialises `iter_schema_discovery_results`)."""
        return list(self.iter_schema_discovery_results(run_id))

    def get_schema_discovery_page(self, run_id: RunId, *, limit: int = 50, offset: int = 0) -> SchemaDiscoveryPage:
        """
        Returns a single page of schema discovery results including `table_metadata`.

        Use this when you need the table-constraint metadata alongside the results.
        `limit` and `offset` are forwarded as query parameters.
        """
        response = self.make_request(
            "GET",
            f"/api/schema-discovery/v2/{run_id}/",
            params={"limit": limit, "offset": offset},
        )
        return SchemaDiscoveryPage.model_validate(response.json())

    def generate_ruleset(self, generation_request: RulesetGenerationRequest) -> str:
        """
        Generates database-masking ruleset YAML from the most recent discovery run on the given connection.

        `generation_request` is a `RulesetGenerationRequest`. The response body is returned decoded as UTF-8.
        """
        data = generation_request.model_dump(exclude_none=True, mode="json")
        response = self.make_request("POST", "/api/generate-ruleset/v2/", data=data)
        return response.content.decode("utf-8")

    def generate_file_ruleset(self, generation_request: FileRulesetGenerationRequest) -> str:
        """
        Generates file-masking ruleset YAML from the most recent file-data-discovery run on the given connection.

        `generation_request` is a `FileRulesetGenerationRequest`. The response body is returned decoded as UTF-8.
        """
        data = generation_request.model_dump(exclude_none=True, mode="json")
        response = self.make_request("POST", "/api/generate-file-ruleset/", data=data)
        return response.content.decode("utf-8")

    def get_file_data_discovery_report(self, run_id: RunId) -> list[FileDiscoveryResult]:
        """Returns the file-data-discovery results for the specified run."""
        # Leading slash matches every other endpoint path used by this client.
        response = self.make_request("GET", f"/api/runs/{run_id}/file-discovery-results/")
        return [FileDiscoveryResult.model_validate(d) for d in response.json()]