Source code for mdai.client

import os
import threading
import re
import math
import zipfile
import requests
import urllib3.exceptions
from retrying import retry
from tqdm import tqdm
import arrow
from .preprocess import Project


def retry_on_http_error(exception):
    """Returns True if the exception is a retryable HTTP or connection error."""
    valid_exceptions = [
        requests.exceptions.HTTPError,
        requests.exceptions.ConnectionError,
        urllib3.exceptions.HTTPError,
    ]
    return any(isinstance(exception, e) for e in valid_exceptions)
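
# Usage sketch (illustrative only, not part of this module): this predicate is
# meant to be passed as `retry_on_exception` to the `retrying` library's
# decorator, so that only transient HTTP/connection failures trigger a retry.
# The function name `fetch` and its parameters are hypothetical.
#
# from retrying import retry
#
# @retry(retry_on_exception=retry_on_http_error, stop_max_attempt_number=3)
# def fetch(url):
#     r = requests.get(url)
#     r.raise_for_status()
#     return r
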
ANNOTATIONS_IMPORT_DEFAULT_CHUNK_SIZE = 100000
class Client:
    """Client for communicating with MD.ai backend API.

    Communication is via user access tokens (in MD.ai Hub, Settings -> User Access Tokens).
    """

    def __init__(self, domain="public.md.ai", access_token=None):
        domain_pattern = r"^\w+\.md\.ai(:\d+)?$"
        if not re.match(domain_pattern, domain):
            raise ValueError(f"domain {domain} is invalid: should be format *.md.ai")
        self.domain = domain
        self.access_token = access_token
        self.session = requests.Session()
        self._test_endpoint()
    def project(self, project_id, path=".", force_download=False, annotations_only=False):
        """Initializes Project class given project id.

        Arguments:
            project_id: hash ID of project.
            path: directory used for data.
            force_download: if True, re-downloads data even if a cached copy exists.
            annotations_only: if True, downloads annotations only and skips images.
        """
        if path == ".":
            print("Using working directory for data.")
        else:
            os.makedirs(path, exist_ok=True)
            print(f"Using path '{path}' for data.")

        data_manager_kwargs = {
            "domain": self.domain,
            "project_id": project_id,
            "path": path,
            "session": self.session,
            "headers": self._create_headers(),
            "force_download": force_download,
        }

        annotations_data_manager = ProjectDataManager("annotations", **data_manager_kwargs)
        annotations_data_manager.create_data_export_job()

        if not annotations_only:
            images_data_manager = ProjectDataManager("images", **data_manager_kwargs)
            images_data_manager.create_data_export_job()

        annotations_data_manager.wait_until_ready()

        if not annotations_only:
            images_data_manager.wait_until_ready()
            p = Project(
                annotations_fp=annotations_data_manager.data_path,
                images_dir=images_data_manager.data_path,
            )
            return p
        else:
            print("No project created. Downloaded annotations only.")
            return None
    def load_model_annotations(self):
        """Deprecated: use `import_annotations` instead."""
        print("Deprecated method: use `import_annotations` instead.")
    def import_annotations(
        self,
        annotations,
        project_id,
        dataset_id,
        model_id=None,
        chunk_size=ANNOTATIONS_IMPORT_DEFAULT_CHUNK_SIZE,
    ):
        """Import annotations into project.

        For example, this method can be used to load machine learning model results into
        project as annotations, or quickly populate metadata labels.

        Arguments:
            annotations: list of annotations to load.
            project_id: hash ID of project.
            dataset_id: hash ID of dataset.
            model_id: hash ID of machine learning model.
            chunk_size: number of annotations to load as a chunk.
        """
        # Validate required arguments before starting any import jobs.
        if not annotations:
            print("No annotations provided.")
            return
        if not project_id:
            print("project_id is required.")
            return
        if not dataset_id:
            print("dataset_id is required.")
            return

        num_chunks = math.ceil(len(annotations) / chunk_size)
        if num_chunks > 1:
            print(f"Importing {len(annotations)} total annotations in {num_chunks} chunks...")

        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            annotations_chunk = annotations[start:end]
            manager = AnnotationsImportManager(
                annotations=annotations_chunk,
                project_id=project_id,
                dataset_id=dataset_id,
                model_id=model_id,
                session=self.session,
                domain=self.domain,
                headers=self._create_headers(),
            )
            manager.create_job()
            manager.wait_until_ready()
    def _create_headers(self):
        headers = {}
        if self.access_token:
            headers["x-access-token"] = self.access_token
        return headers

    def _test_endpoint(self):
        """Checks endpoint for validity and authorization."""
        test_endpoint = f"https://{self.domain}/api/test"
        r = self.session.get(test_endpoint, headers=self._create_headers())
        if r.status_code == 200:
            print(f"Successfully authenticated to {self.domain}.")
        else:
            raise Exception("Authorization error. Make sure your access token is valid.")

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _gql(self, query, variables=None):
        """Executes GraphQL query."""
        gql_endpoint = f"https://{self.domain}/api/graphql"
        headers = self._create_headers()
        headers["Accept"] = "application/json"
        headers["Content-Type"] = "application/json"
        data = {"query": query, "variables": variables}
        r = self.session.post(gql_endpoint, headers=headers, json=data)
        if r.status_code != 200:
            r.raise_for_status()
        body = r.json()
        data = body.get("data")
        errors = body.get("errors")
        return data, errors
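
# Example usage of Client (a minimal sketch; the token, project, and dataset
# IDs below are hypothetical placeholders, and the annotation list is elided):
#
# client = Client(domain="public.md.ai", access_token="<YOUR_ACCESS_TOKEN>")
# project = client.project("<PROJECT_ID>", path="./data")
# client.import_annotations(
#     annotations=[...],  # list of annotation dicts
#     project_id="<PROJECT_ID>",
#     dataset_id="<DATASET_ID>",
# )
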
class ProjectDataManager:
    """Manager for project data exports and downloads."""

    def __init__(
        self,
        data_type,
        domain=None,
        project_id=None,
        path=".",
        session=None,
        headers=None,
        force_download=False,
    ):
        if data_type not in ["images", "annotations"]:
            raise ValueError("data_type must be 'images' or 'annotations'.")
        if not domain:
            raise ValueError("domain is not specified.")
        if not project_id:
            raise ValueError("project_id is not specified.")
        if not os.path.exists(path):
            raise OSError(f"Path '{path}' does not exist.")

        self.data_type = data_type
        self.force_download = force_download
        self.domain = domain
        self.project_id = project_id
        self.path = path
        if session and isinstance(session, requests.Session):
            self.session = session
        else:
            self.session = requests.Session()
        self.headers = headers

        # path for downloaded data
        self.data_path = None

        # ready threading event
        self._ready = threading.Event()
    def create_data_export_job(self):
        """Create data export job through MD.ai API.

        This is an async operation. Status code of 202 indicates successful creation of job.
        """
        endpoint = f"https://{self.domain}/api/data-export/{self.data_type}"
        params = self._get_data_export_params()
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code == 202:
            msg = f"Preparing {self.data_type} export for project {self.project_id}..."
            print(msg.ljust(100))
            self._check_data_export_job_progress()
        else:
            if r.status_code == 401:
                msg = (
                    f"Project {self.project_id} at domain {self.domain}"
                    + " does not exist or you do not have sufficient permissions for access."
                )
                print(msg)
            self._on_data_export_job_error()
    def wait_until_ready(self):
        self._ready.wait()
    def _get_data_export_params(self):
        if self.data_type == "images":
            params = {"projectHashId": self.project_id, "exportFormat": "zip"}
        elif self.data_type == "annotations":
            # TODO: restrict to assigned labelgroup
            params = {
                "projectHashId": self.project_id,
                "labelGroupNum": None,
                "exportFormat": "json",
            }
        return params

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _check_data_export_job_progress(self):
        """Poll for data export job progress."""
        endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/progress"
        params = self._get_data_export_params()
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        try:
            body = r.json()
            status = body["status"]
        except (TypeError, KeyError):
            self._on_data_export_job_error()
            return

        if status == "running":
            try:
                progress = int(body["progress"])
            except (TypeError, ValueError):
                progress = 0
            try:
                time_remaining = int(body["timeRemaining"])
            except (TypeError, ValueError):
                time_remaining = 0

            # print formatted progress info
            if progress > 0 and progress <= 100 and time_remaining > 0:
                if time_remaining > 45:
                    time_remaining_fmt = (
                        arrow.now().shift(seconds=time_remaining).humanize(only_distance=True)
                    )
                else:
                    # arrow humanizes <= 45 to 'in seconds' or 'just now',
                    # so we will opt to be explicit instead.
                    time_remaining_fmt = f"in {time_remaining} seconds"
                end_char = "\r" if progress < 100 else "\n"
                msg = (
                    f"Exporting {self.data_type} for project {self.project_id}..."
                    + f"{progress}% (time remaining: {time_remaining_fmt})."
                )
                print(msg.ljust(100), end=end_char, flush=True)

            # run progress check at 1s intervals so long as status == 'running' and progress < 100
            if progress < 100:
                t = threading.Timer(1.0, self._check_data_export_job_progress)
                t.start()
                return
        elif status == "done":
            self._on_data_export_job_done()
        elif status == "error":
            self._on_data_export_job_error()

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _on_data_export_job_done(self):
        endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/done"
        params = self._get_data_export_params()
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        try:
            file_keys = r.json()["fileKeys"]
            if file_keys:
                data_path = self._get_data_path(file_keys)
                if self.force_download or not os.path.exists(data_path):
                    # download in separate thread
                    t = threading.Thread(target=self._download_files, args=(file_keys,))
                    t.start()
                else:
                    # use existing data
                    self.data_path = data_path
                    print(f"Using cached {self.data_type} data for project {self.project_id}.")
                    # fire ready threading.Event
                    self._ready.set()
        except (TypeError, KeyError):
            self._on_data_export_job_error()

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _on_data_export_job_error(self):
        endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/error"
        params = self._get_data_export_params()
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        print(f"Error exporting {self.data_type} for project {self.project_id}.")
        # fire ready threading.Event
        self._ready.set()

    def _get_data_path(self, file_keys):
        if self.data_type == "images":
            # should be folder for zip file:
            # xxxx.zip -> xxxx/
            # xxxx_part1of3.zip -> xxxx/
            images_dir = re.sub(r"(_part\d+of\d+)?\.\S+$", "", file_keys[0])
            return os.path.join(self.path, images_dir)
        elif self.data_type == "annotations":
            # annotations export will be single file
            annotations_fp = file_keys[0]
            return os.path.join(self.path, annotations_fp)

    def _download_files(self, file_keys):
        """Downloads files via signed URL requested from MD.ai API."""
        for file_key in file_keys:
            print(f"Downloading file: {file_key}")
            filepath = os.path.join(self.path, file_key)
            key = requests.utils.quote(file_key)
            url = f"https://{self.domain}/api/project-files/signedurl/get?key={key}"

            # stream response so we can display progress bar
            r = requests.get(url, stream=True, headers=self.headers)

            # total size in bytes
            total_size = int(r.headers.get("content-length", 0))
            block_size = 32 * 1024
            wrote = 0
            with open(filepath, "wb") as f:
                with tqdm(total=total_size, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
                    for chunk in r.iter_content(block_size):
                        f.write(chunk)
                        wrote += len(chunk)
                        # update by actual bytes written so the bar stays
                        # accurate on the final, possibly partial, chunk
                        pbar.update(len(chunk))
            if total_size != 0 and wrote != total_size:
                raise IOError(f"Error downloading file {file_key}.")

            if self.data_type == "images":
                # unzip archive
                print(f"Extracting archive: {file_key}")
                with zipfile.ZipFile(filepath, "r") as f:
                    f.extractall(self.path)

        self.data_path = self._get_data_path(file_keys)
        print(f"Success: {self.data_type} data for project {self.project_id} ready.")
        # fire ready threading.Event
        self._ready.set()
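
# Internal flow sketch (assumption: ProjectDataManager is normally driven by
# Client.project() rather than used directly; the project ID and token below
# are hypothetical placeholders):
#
# manager = ProjectDataManager(
#     "annotations",
#     domain="public.md.ai",
#     project_id="<PROJECT_ID>",
#     path=".",
#     headers={"x-access-token": "<YOUR_ACCESS_TOKEN>"},
# )
# manager.create_data_export_job()  # async; kicks off server-side export + polling
# manager.wait_until_ready()        # blocks until download/extraction completes
# print(manager.data_path)          # local path to the exported data
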
class AnnotationsImportManager:
    """Manager for importing annotations."""

    def __init__(
        self,
        annotations=None,
        project_id=None,
        dataset_id=None,
        model_id=None,
        session=None,
        domain=None,
        headers=None,
    ):
        if not domain:
            raise ValueError("domain is not specified.")
        if not project_id:
            raise ValueError("project_id is not specified.")

        self.annotations = annotations
        self.project_id = project_id
        self.dataset_id = dataset_id
        self.model_id = model_id
        if session and isinstance(session, requests.Session):
            self.session = session
        else:
            self.session = requests.Session()
        self.domain = domain
        self.headers = headers
        self.job_id = None

        # ready threading event
        self._ready = threading.Event()
    def create_job(self):
        """Create annotations import job through MD.ai API.

        This is an async operation. Status code of 202 indicates successful creation of job.
        """
        endpoint = f"https://{self.domain}/api/data-import/annotations"
        params = {
            "projectHashId": self.project_id,
            "datasetHashId": self.dataset_id,
            "modelHashId": self.model_id,
            "annotations": self.annotations,
        }
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code == 202:
            self.job_id = r.json()["jobId"]
            msg = f"Importing {len(self.annotations)} annotations into project {self.project_id}"
            msg += f", dataset {self.dataset_id}"
            if self.model_id:
                msg += f", model {self.model_id}..."
            else:
                msg += "..."
            print(msg.ljust(100))
            self._check_job_progress()
        else:
            print(r.status_code)
            if r.status_code == 401:
                msg = "Provided IDs are invalid, or you do not have sufficient permissions."
                print(msg)
            self._on_job_error()
    def wait_until_ready(self):
        self._ready.wait()
    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _check_job_progress(self):
        """Poll for annotations import job progress."""
        endpoint = f"https://{self.domain}/api/data-import/annotations/progress"
        params = {"projectHashId": self.project_id, "jobId": self.job_id}
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        try:
            body = r.json()
            status = body["status"]
        except (TypeError, KeyError):
            self._on_job_error()
            return

        if status == "running":
            try:
                progress = int(body["progress"])
            except (TypeError, ValueError):
                progress = 0
            try:
                time_remaining = int(body["timeRemaining"])
            except (TypeError, ValueError):
                time_remaining = 0

            # print formatted progress info
            if progress > 0 and progress <= 100 and time_remaining > 0:
                if time_remaining > 45:
                    time_remaining_fmt = (
                        arrow.now().shift(seconds=time_remaining).humanize(only_distance=True)
                    )
                else:
                    # arrow humanizes <= 45 to 'in seconds' or 'just now',
                    # so we will opt to be explicit instead.
                    time_remaining_fmt = f"in {time_remaining} seconds"
                end_char = "\r" if progress < 100 else "\n"
                msg = (
                    f"Annotations import for project {self.project_id}..."
                    + f"{progress}% (time remaining: {time_remaining_fmt})."
                )
                print(msg.ljust(100), end=end_char, flush=True)

            # run progress check at 1s intervals so long as status == 'running' and progress < 100
            if progress < 100:
                t = threading.Timer(1.0, self._check_job_progress)
                t.start()
                return
        elif status == "done":
            self._on_job_done()
        elif status == "error":
            self._on_job_error()

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _on_job_done(self):
        endpoint = f"https://{self.domain}/api/data-import/annotations/done"
        params = {"projectHashId": self.project_id, "jobId": self.job_id}
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        print(f"Successfully imported annotations into project {self.project_id}.")
        # fire ready threading.Event
        self._ready.set()

    @retry(
        retry_on_exception=retry_on_http_error,
        wait_exponential_multiplier=100,
        wait_exponential_max=1000,
        stop_max_attempt_number=10,
    )
    def _on_job_error(self):
        endpoint = f"https://{self.domain}/api/data-import/annotations/error"
        params = {"projectHashId": self.project_id, "jobId": self.job_id}
        r = self.session.post(endpoint, json=params, headers=self.headers)
        if r.status_code != 200:
            r.raise_for_status()
        print(f"Error importing annotations into project {self.project_id}.")
        # fire ready threading.Event
        self._ready.set()
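
# Flow sketch (assumption: AnnotationsImportManager is normally driven by
# Client.import_annotations(); the annotation dict shape shown is an assumed
# illustration, not a documented schema, and all IDs are placeholders):
#
# manager = AnnotationsImportManager(
#     annotations=[{"labelId": "<LABEL_ID>", "SOPInstanceUID": "..."}],
#     project_id="<PROJECT_ID>",
#     dataset_id="<DATASET_ID>",
#     domain="public.md.ai",
#     headers={"x-access-token": "<YOUR_ACCESS_TOKEN>"},
# )
# manager.create_job()        # async; returns after the 202 job creation
# manager.wait_until_ready()  # blocks until the import job finishes or errors
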