# Copyright (c) Alibaba, Inc. and its affiliates.
import concurrent.futures
import os
import shutil
from multiprocessing import Manager, Process, Value
from swift.utils.logger import get_logger
from .api import HubApi
from .constants import DEFAULT_REPOSITORY_REVISION, ModelVisibility
# Logger shared by every helper in this module.
logger = get_logger()
# Process pool used by ``push_to_hub_async``; it has at most 8 worker processes.
_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
# Per-queue state for ``push_to_hub_in_queue``, all keyed by queue name:
_queues = {}  # queue name -> Manager-backed queue of upload task dicts
_flags = {}  # queue name -> shared Value('b') flag, True while the worker handles a task
_tasks = {}  # queue name -> worker Process running ``submit_task`` on that queue
# multiprocessing Manager, created lazily by the first ``push_to_hub_in_queue`` call.
_manager = None
def _api_push_to_hub(repo_name,
                     output_dir,
                     token,
                     private=True,
                     commit_message='',
                     tag=None,
                     source_repo='',
                     ignore_file_pattern=None,
                     revision=DEFAULT_REPOSITORY_REVISION):
    """Make a single attempt to upload ``output_dir`` to the hub repo ``repo_name``.

    Any exception raised while logging in or pushing is caught and logged, so callers
    (the retry loop in ``push_to_hub`` and the process pool in ``push_to_hub_async``)
    only ever see a boolean outcome.

    Args:
        repo_name: Target repo name; also passed as the repo's ``chinese_name``.
        output_dir: Local directory to upload.
        token: API token used for ``HubApi.login``.
        private: Upload with PRIVATE visibility if truthy, PUBLIC otherwise.
        commit_message: Commit message for the push.
        tag: Optional tag for the commit.
        source_repo: Model id this model originates from (``original_model_id``).
        ignore_file_pattern: File pattern excluded from the upload.
        revision: Branch to commit to.

    Returns:
        True if ``push_model`` completed without raising, False otherwise.
    """
    try:
        hub = HubApi()
        hub.login(token)
        if private:
            visibility = ModelVisibility.PRIVATE
        else:
            visibility = ModelVisibility.PUBLIC
        hub.push_model(
            repo_name,
            output_dir,
            visibility=visibility,
            chinese_name=repo_name,
            commit_message=commit_message,
            tag=tag,
            original_model_id=source_repo,
            ignore_file_pattern=ignore_file_pattern,
            revision=revision)
        # Only the success log replaces an empty message with a placeholder.
        if not commit_message:
            commit_message = 'No commit message'
        logger.info(f'Successfully upload the model to {repo_name} with message: {commit_message}')
    except Exception as err:
        logger.error(f'Error happens when uploading model {repo_name} with message: {commit_message}: {err}')
        return False
    return True
def push_to_hub(repo_name,
                output_dir,
                token=None,
                private=True,
                retry=3,
                commit_message='',
                tag=None,
                source_repo='',
                ignore_file_pattern=None,
                revision=DEFAULT_REPOSITORY_REVISION):
    """Upload a local checkpoint directory to a modelhub repo, retrying on failure.

    Args:
        repo_name: The repo name for the modelhub repo.
        output_dir: The local output_dir for the checkpoint. It must be an existing
            directory that directly contains ``configuration.json``,
            ``configuration.yaml`` or ``configuration.yml``.
        token: The user api token. If None, the ``MODELSCOPE_API_TOKEN`` environment
            variable is used.
        private: Whether the repo is private, default True.
        retry: Maximum number of upload attempts, default 3. A value <= 0 means no
            attempt is made.
        commit_message: The commit message.
        tag: The tag of this commit.
        source_repo: The source repo (model id) which this model comes from.
        ignore_file_pattern: The file pattern to be ignored in uploading. If None, the
            ``UPLOAD_IGNORE_FILE_PATTERN`` environment variable is used.
        revision: The branch to commit to.

    Returns:
        True as soon as one attempt succeeds, False if no attempt succeeded.

    Raises:
        AssertionError: If ``repo_name`` is None, no token is available, or
            ``output_dir`` is not a directory containing a configuration file.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
    assert repo_name is not None
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(output_dir)
    # List the directory once instead of once per accepted configuration file name.
    entries = os.listdir(output_dir)
    assert any(name in entries for name in ('configuration.json', 'configuration.yaml', 'configuration.yml'))
    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    for _ in range(retry):
        if _api_push_to_hub(repo_name, output_dir, token, private, commit_message, tag, source_repo,
                            ignore_file_pattern, revision):
            return True
    # Per-attempt errors are logged by _api_push_to_hub; also report the final outcome
    # (this covers retry <= 0 too, which previously failed without any log line).
    logger.error(f'Giving up uploading {output_dir} to {repo_name}: no attempt succeeded (retry={retry}).')
    return False
def push_to_hub_async(repo_name,
                      output_dir,
                      token=None,
                      private=True,
                      commit_message='',
                      tag=None,
                      source_repo='',
                      ignore_file_pattern=None,
                      revision=DEFAULT_REPOSITORY_REVISION):
    """Validate the arguments, then submit one upload attempt to the module's process pool.

    Unlike ``push_to_hub`` there is no retry: ``_api_push_to_hub`` is submitted exactly once.

    Args:
        repo_name: The repo name for the modelhub repo.
        output_dir: The local output_dir for the checkpoint. It must be an existing
            directory that directly contains ``configuration.json``,
            ``configuration.yaml`` or ``configuration.yml``.
        token: The user api token. If None, the ``MODELSCOPE_API_TOKEN`` environment
            variable is used.
        private: Whether the repo is private, default True.
        commit_message: The commit message.
        tag: The tag of this commit.
        source_repo: The source repo (model id) which this model comes from.
        ignore_file_pattern: The file pattern to be ignored in uploading. If None, the
            ``UPLOAD_IGNORE_FILE_PATTERN`` environment variable is used.
        revision: The branch to commit to.

    Returns:
        The future returned by ``_executor.submit``. Its result is True if the upload
        succeeded and False otherwise.

    Raises:
        AssertionError: Raised synchronously, before anything is submitted, if
            ``repo_name`` is None, no token is available, or ``output_dir`` is not a
            directory containing a configuration file.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
    assert repo_name is not None
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(output_dir)
    # List the directory once instead of once per accepted configuration file name.
    entries = os.listdir(output_dir)
    assert any(name in entries for name in ('configuration.json', 'configuration.yaml', 'configuration.yml'))
    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    return _executor.submit(_api_push_to_hub, repo_name, output_dir, token, private, commit_message, tag, source_repo,
                            ignore_file_pattern, revision)
def submit_task(q, b):
while True:
b.value = False
item = q.get()
logger.info(item)
b.value = True
if not item.pop('done', False):
delete_dir = item.pop('delete_dir', False)
output_dir = item.get('output_dir')
try:
push_to_hub(**item)
if delete_dir and os.path.exists(output_dir):
shutil.rmtree(output_dir)
except Exception as e:
logger.error(e)
else:
break
class UploadStrategy:
    """What ``push_to_hub_in_queue`` does when the queue's worker is busy uploading.

    cancel: drop the new request and log an error.
    wait: enqueue the request behind the running upload.
    """
    cancel, wait = 'cancel', 'wait'
def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
    """Hand an upload request to the background worker that owns ``queue_name``.

    The first call for a given name lazily creates the shared ``Manager`` and, for that
    name, a queue, a busy flag and a ``Process`` running ``submit_task`` on them.

    Args:
        queue_name: Non-empty name identifying the worker and its queue.
        strategy: With ``UploadStrategy.cancel``, the request is dropped (an error is
            logged) if the worker is busy right now. Any other value always enqueues it.
        **kwargs: The task dict sent to the worker.
            - ``done=True`` asks the worker to exit and is always enqueued.
            - ``delete_dir`` is consumed by the worker.
            - The remaining keys are passed to ``push_to_hub``.
    """
    assert queue_name is not None and len(queue_name) > 0, 'Please specify a valid queue name!'
    global _manager
    if _manager is None:
        _manager = Manager()
    if queue_name not in _queues:
        new_queue = _manager.Queue()
        busy_flag = Value('b', False)
        _queues[queue_name] = new_queue
        _flags[queue_name] = busy_flag
        worker = Process(target=submit_task, args=(new_queue, busy_flag))
        worker.start()
        _tasks[queue_name] = worker
    queue = _queues[queue_name]
    busy_flag = _flags[queue_name]
    is_shutdown = kwargs.get('done', False)
    # A shutdown request is never cancelled; otherwise honour the cancel strategy.
    if not is_shutdown and busy_flag.value and strategy == UploadStrategy.cancel:
        logger.error(f'Another uploading is running, '
                     f'this uploading with message {kwargs.get("commit_message")} will be canceled.')
        return
    queue.put(kwargs)
def wait_for_done(queue_name):
    """Block until the worker of ``queue_name`` exits, then forget that queue's state.

    Does nothing if no worker is registered under that name. The worker loop exits
    when it receives a ``done`` task, so send ``push_to_hub_in_queue(..., done=True)``
    first; otherwise this call is expected to block indefinitely.
    """
    worker = _tasks.pop(queue_name, None)
    if worker is not None:
        worker.join()
        _queues.pop(queue_name)
        _flags.pop(queue_name)