SwiftOutput

class swift.tuners.utils.SwiftOutput(model=None, config=None, state_dict_callback=None, save_callback=None, mark_trainable_callback=None, optimizer_group_callback=None, load_state_dict_callback=None, load_callback=None)

The output object returned by every tuner. It bundles the wrapped model, its config, and the callbacks listed below.

Parameters:
  • model (torch.nn.Module) -- The wrapped model.

  • config (SwiftConfig) -- The swift config instance.

  • state_dict_callback (FunctionType) --

    A callback returned by the tuner that extracts the tuner's own entries from the model's full state dict. It receives a state dict and returns a newly created state dict containing only those entries.

    Example:

    >>> def state_dict_callback(state_dict, adapter_name):
    ...     return {
    ...         key: value
    ...         for key, value in state_dict.items() if adapter_name in key
    ...     }
    

  • save_callback (FunctionType) -- A callback used to save the trained model.

  • mark_trainable_callback (FunctionType) --

    A callback returned by the tuner that marks the parameters of the tuner's adapter as trainable. It receives a model instance and returns nothing.

    Example:

    >>> def mark_trainable_callback(model):
    ...     mark_lora_as_trainable(model, config.bias)
    

  • optimizer_group_callback (FunctionType) -- A callback that returns the optimizer parameter groups the tuner is responsible for.

  • load_state_dict_callback (FunctionType) -- A callback invoked before the tuner's load_state_dict runs.

  • load_callback (FunctionType) -- A callback used to load the trained model.
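The state_dict_callback example above selects entries by plain substring matching on the key. The following is a minimal runnable sketch of that filter, plus a stricter variant (strict_state_dict_callback, a hypothetical helper that is not part of SWIFT) that requires the adapter name to be a whole dot-separated component of the key. Plain floats stand in for tensors, and the key names are illustrative assumptions rather than keys produced by a real tuner.

```python
from typing import Any, Dict


def state_dict_callback(state_dict: Dict[str, Any], adapter_name: str) -> Dict[str, Any]:
    """Substring filter, mirroring the example: keep every key that contains adapter_name."""
    return {key: value for key, value in state_dict.items() if adapter_name in key}


def strict_state_dict_callback(state_dict: Dict[str, Any], adapter_name: str) -> Dict[str, Any]:
    """Hypothetical stricter variant: adapter_name must be a whole dot-separated key component."""
    return {key: value for key, value in state_dict.items() if adapter_name in key.split(".")}


if __name__ == "__main__":
    # Illustrative keys; floats stand in for tensors since only the keys are inspected.
    full_state = {
        "layer0.weight": 1.0,
        "layer0.lora_A.default.weight": 0.1,
        "layer0.lora_B.default.weight": 0.2,
        "layer0.lora_A.default2.weight": 0.3,
    }
    print(sorted(state_dict_callback(full_state, "default")))
    # ['layer0.lora_A.default.weight', 'layer0.lora_A.default2.weight', 'layer0.lora_B.default.weight']
    print(sorted(strict_state_dict_callback(full_state, "default")))
    # ['layer0.lora_A.default.weight', 'layer0.lora_B.default.weight']
```

The substring rule also keeps the default2 entry when asked for default, which is why a model hosting several adapters with overlapping names may want the stricter check.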
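A mark_trainable_callback can be written as a closure: the tuner builds it with its own settings captured, much as the example above refers to config.bias from its enclosing scope. The following sketch assumes a freeze-everything-except-the-adapter policy keyed on a parameter-name substring ("lora_" here, a hypothetical choice); make_mark_trainable_callback is an illustrative helper, not a SWIFT API. FakeParam and FakeModel are stand-ins so the sketch runs without PyTorch. The callback itself only relies on named_parameters() and requires_grad, which torch.nn.Module and torch.nn.Parameter also provide.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterator, Tuple


@dataclass
class FakeParam:
    """Stand-in for torch.nn.Parameter; only the requires_grad flag matters here."""
    requires_grad: bool = True


@dataclass
class FakeModel:
    """Stand-in exposing named_parameters(), the same method torch.nn.Module provides."""
    params: Dict[str, FakeParam] = field(default_factory=dict)

    def named_parameters(self) -> Iterator[Tuple[str, FakeParam]]:
        return iter(self.params.items())


def make_mark_trainable_callback(adapter_keyword: str) -> Callable[[object], None]:
    """Hypothetical factory: build a callback that freezes everything except adapter params."""

    def mark_trainable_callback(model) -> None:
        for name, param in model.named_parameters():
            # Only parameters whose name contains the keyword stay trainable.
            param.requires_grad = adapter_keyword in name
        # Returns nothing, as documented; the model is modified in place.

    return mark_trainable_callback


if __name__ == "__main__":
    model = FakeModel({"linear.weight": FakeParam(), "linear.lora_A.weight": FakeParam()})
    make_mark_trainable_callback("lora_")(model)
    print({name: p.requires_grad for name, p in model.named_parameters()})
    # {'linear.weight': False, 'linear.lora_A.weight': True}
```

Capturing the keyword in a closure keeps the callback's signature down to a single model argument, matching the documented contract.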