# Source code for swift.hub.push_to_hub

# Copyright (c) Alibaba, Inc. and its affiliates.

import concurrent.futures
import os
import shutil
from multiprocessing import Manager, Process, Value

from swift.utils.logger import get_logger
from .api import HubApi
from .constants import DEFAULT_REPOSITORY_REVISION, ModelVisibility

logger = get_logger()

# Process pool used by push_to_hub_async; each submitted upload runs in a worker process.
_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)

# Bookkeeping for push_to_hub_in_queue / wait_for_done, all keyed by queue name.
_queues = {}  # queue name -> Manager().Queue() feeding that queue's worker process
_flags = {}  # queue name -> shared Value('b') that is True while the worker handles an item
_tasks = {}  # queue name -> the worker Process running submit_task
_manager = None  # multiprocessing Manager, created lazily by push_to_hub_in_queue


def _api_push_to_hub(repo_name,
                     output_dir,
                     token,
                     private=True,
                     commit_message='',
                     tag=None,
                     source_repo='',
                     ignore_file_pattern=None,
                     revision=DEFAULT_REPOSITORY_REVISION):
    """Make one attempt to upload ``output_dir`` to the hub repo ``repo_name``.

    Logs in with ``token`` through ``HubApi`` and calls ``HubApi.push_model``. Any
    exception raised along the way is logged and swallowed, so callers can decide
    whether to retry.

    Args:
        repo_name: Target repo; also passed to ``push_model`` as ``chinese_name``.
        output_dir: Local directory to upload.
        token: The user api token used to log in.
        private: Upload with ``ModelVisibility.PRIVATE`` when truthy, otherwise
            with ``ModelVisibility.PUBLIC``.
        commit_message: The commit message.
        tag: The tag of this commit.
        source_repo: Passed to ``push_model`` as ``original_model_id``.
        ignore_file_pattern: Passed through to ``push_model``.
        revision: The branch to commit to.

    Returns:
        True if the upload finished without raising, False otherwise.
    """
    try:
        hub = HubApi()
        hub.login(token)
        upload_options = dict(
            visibility=ModelVisibility.PRIVATE if private else ModelVisibility.PUBLIC,
            chinese_name=repo_name,
            commit_message=commit_message,
            tag=tag,
            original_model_id=source_repo,
            ignore_file_pattern=ignore_file_pattern,
            revision=revision,
        )
        hub.push_model(repo_name, output_dir, **upload_options)
        if not commit_message:
            commit_message = 'No commit message'
        logger.info(f'Successfully upload the model to {repo_name} with message: {commit_message}')
        return True
    except Exception as err:
        logger.error(f'Error happens when uploading model {repo_name} with message: {commit_message}: {err}')
        return False


def push_to_hub(repo_name,
                output_dir,
                token=None,
                private=True,
                retry=3,
                commit_message='',
                tag=None,
                source_repo='',
                ignore_file_pattern=None,
                revision=DEFAULT_REPOSITORY_REVISION):
    """Upload a local checkpoint directory to a ModelHub repo, retrying on failure.

    Args:
        repo_name: The repo name for the modelhub repo.
        output_dir: The local output_dir for the checkpoint. Must be an existing
            directory containing ``configuration.json``, ``configuration.yaml`` or
            ``configuration.yml``.
        token: The user api token; when None, the ``MODELSCOPE_API_TOKEN``
            environment variable is used.
        private: If is a private repo, default True.
        retry: Number of upload attempts, default 3. A value <= 0 makes no
            attempt and returns False.
        commit_message: The commit message.
        tag: The tag of this commit.
        source_repo: The source repo (model id) which this model comes from.
        ignore_file_pattern: The file pattern to be ignored in uploading; when
            None, the ``UPLOAD_IGNORE_FILE_PATTERN`` environment variable is used.
        revision: The branch to commit to.

    Returns:
        The boolean value to represent whether the model is uploaded.

    Raises:
        AssertionError: If an argument fails validation. The checks raise
            explicitly instead of using ``assert`` so they still run under
            ``python -O``.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')

    if repo_name is None:
        raise AssertionError('`repo_name` must not be None.')
    if token is None:
        raise AssertionError('Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.')
    if not os.path.isdir(output_dir):
        raise AssertionError(f'`output_dir` is not a directory: {output_dir}')
    config_files = {'configuration.json', 'configuration.yaml', 'configuration.yml'}
    # A single directory listing is enough to check for any of the accepted config files.
    if config_files.isdisjoint(os.listdir(output_dir)):
        raise AssertionError(f'No configuration.json/yaml/yml found in {output_dir}')

    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    for _ in range(retry):
        # _api_push_to_hub swallows and logs its own errors and reports failure as False.
        if _api_push_to_hub(repo_name, output_dir, token, private, commit_message, tag, source_repo,
                            ignore_file_pattern, revision):
            return True
    logger.error(f'Failed to upload {output_dir} to {repo_name} after {retry} attempt(s).')
    return False
def push_to_hub_async(repo_name,
                      output_dir,
                      token=None,
                      private=True,
                      commit_message='',
                      tag=None,
                      source_repo='',
                      ignore_file_pattern=None,
                      revision=DEFAULT_REPOSITORY_REVISION):
    """Start uploading a local checkpoint directory in the module's process pool.

    Args:
        repo_name: The repo name for the modelhub repo.
        output_dir: The local output_dir for the checkpoint. Must be an existing
            directory containing ``configuration.json``, ``configuration.yaml`` or
            ``configuration.yml``.
        token: The user api token; when None, the ``MODELSCOPE_API_TOKEN``
            environment variable is used.
        private: If is a private repo, default True.
        commit_message: The commit message.
        tag: The tag of this commit.
        source_repo: The source repo (model id) which this model comes from.
        ignore_file_pattern: The file pattern to be ignored in uploading; when
            None, the ``UPLOAD_IGNORE_FILE_PATTERN`` environment variable is used.
        revision: The branch to commit to.

    Returns:
        A ``concurrent.futures.Future`` from ``_executor``. Its result is the
        boolean returned by a single ``_api_push_to_hub`` attempt; unlike
        ``push_to_hub``, there are no retries.

    Raises:
        AssertionError: If an argument fails validation. The checks raise
            explicitly instead of using ``assert`` so they still run under
            ``python -O``.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')

    if repo_name is None:
        raise AssertionError('`repo_name` must not be None.')
    if token is None:
        raise AssertionError('Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.')
    if not os.path.isdir(output_dir):
        raise AssertionError(f'`output_dir` is not a directory: {output_dir}')
    config_files = {'configuration.json', 'configuration.yaml', 'configuration.yml'}
    # A single directory listing is enough to check for any of the accepted config files.
    if config_files.isdisjoint(os.listdir(output_dir)):
        raise AssertionError(f'No configuration.json/yaml/yml found in {output_dir}')

    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    return _executor.submit(_api_push_to_hub, repo_name, output_dir, token, private, commit_message, tag,
                            source_repo, ignore_file_pattern, revision)
def submit_task(q, b):
    """Worker loop that uploads queued checkpoints one at a time.

    Runs as the target of the child process started by ``push_to_hub_in_queue``.

    Args:
        q: Queue of upload requests. Each item is a dict of ``push_to_hub``
            keyword arguments plus two optional control keys: ``done`` (exit the
            loop instead of uploading) and ``delete_dir`` (remove ``output_dir``
            after a successful upload).
        b: Shared ``multiprocessing.Value`` flag. It is False while the worker
            waits on ``q`` and True while it handles an item.
    """
    while True:
        b.value = False
        item = q.get()
        # Log the request, but never write the API token into the logs.
        logger.info({key: ('***' if key == 'token' else value) for key, value in item.items()})
        b.value = True
        if item.pop('done', False):
            break
        delete_dir = item.pop('delete_dir', False)
        output_dir = item.get('output_dir')
        try:
            # push_to_hub reports a failed upload by returning False rather than
            # raising, so check the result before deleting the only local copy.
            uploaded = push_to_hub(**item)
            if not uploaded:
                if delete_dir:
                    logger.error(f'Uploading {output_dir} failed, the local directory is kept.')
            elif delete_dir and os.path.exists(output_dir):
                shutil.rmtree(output_dir)
        except Exception as e:
            logger.error(e)


class UploadStrategy:
    """What ``push_to_hub_in_queue`` does with a new request while the worker is busy."""
    # Drop the new request and log an error.
    cancel = 'cancel'
    # Enqueue the new request behind the running one.
    wait = 'wait'


def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
    """Enqueue an upload on a named background queue.

    The first call for a ``queue_name`` lazily creates the shared ``Manager`` (if
    needed), a queue, a busy flag, and a worker process running ``submit_task``.
    Uploads on the same queue therefore run sequentially.

    Args:
        queue_name: Non-empty name of the queue.
        strategy: ``UploadStrategy.cancel`` drops this request when the worker's
            busy flag is set; ``UploadStrategy.wait`` always enqueues it.
            NOTE(review): only the busy flag is checked, not items still waiting
            in the queue.
        **kwargs: Arguments forwarded to ``push_to_hub``, plus ``delete_dir`` to
            remove ``output_dir`` after a successful upload, or ``done=True`` to
            make the worker exit once it reaches this item. A ``done`` item is
            always enqueued, regardless of ``strategy``.

    Raises:
        AssertionError: If ``queue_name`` is None or empty.
    """
    global _manager
    if queue_name is None or len(queue_name) == 0:
        raise AssertionError('Please specify a valid queue name!')
    if _manager is None:
        _manager = Manager()
    if queue_name not in _queues:
        _queues[queue_name] = _manager.Queue()
        _flags[queue_name] = Value('b', False)
        worker = Process(target=submit_task, args=(_queues[queue_name], _flags[queue_name]))
        worker.start()
        _tasks[queue_name] = worker
    queue = _queues[queue_name]
    busy = _flags[queue_name]
    if kwargs.get('done', False):
        queue.put(kwargs)
    elif busy.value and strategy == UploadStrategy.cancel:
        logger.error(f'Another uploading is running, '
                     f'this uploading with message {kwargs.get("commit_message")} will be canceled.')
    else:
        queue.put(kwargs)


def wait_for_done(queue_name):
    """Let the worker of ``queue_name`` finish its queued uploads, then join it.

    A ``done`` sentinel is enqueued first, so the join cannot block forever when
    the caller never sent ``push_to_hub_in_queue(queue_name, done=True)``. Items
    already queued are still processed before the sentinel. Does nothing for an
    unknown queue name.
    """
    process = _tasks.pop(queue_name, None)
    if process is None:
        return
    queue = _queues.get(queue_name)
    if queue is not None and process.is_alive():
        # An extra sentinel is harmless: the worker stops at the first one it reads.
        queue.put({'done': True})
    process.join()
    _queues.pop(queue_name, None)
    _flags.pop(queue_name, None)