SwiftOutput
- class swift.tuners.utils.SwiftOutput(model=None, config=None, state_dict_callback=None, save_callback=None, mark_trainable_callback=None, optimizer_group_callback=None, load_state_dict_callback=None, load_callback=None)
The output class returned by all tuners.
- Parameters:
model (torch.nn.Module) -- The wrapped model.
config (SwiftConfig) -- The swift config instance.
state_dict_callback (FunctionType) --
A callback returned by the tuner that extracts the tuner's own entries from the model's full state dict. It receives a state dict (the example below also takes the adapter name) and returns a newly created state dict.
Example:
>>> def state_dict_callback(state_dict, adapter_name):
...     return {
...         key: value
...         for key, value in state_dict.items() if adapter_name in key
...     }
save_callback (FunctionType) -- A callback used to save the trained model.
mark_trainable_callback (FunctionType) --
A callback returned by the tuner that marks the parameters of the tuner's adapter as trainable. It receives a model instance and returns nothing.
Example:
>>> def mark_trainable_callback(model):
...     mark_lora_as_trainable(model, config.bias)
optimizer_group_callback (FunctionType) -- A callback returned by the tuner that returns the optimizer parameter groups the tuner is responsible for.
load_state_dict_callback (FunctionType) -- A callback invoked before the tuner's load_state_dict is called.
load_callback (FunctionType) -- A callback used to load the trained model.
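The documented state_dict_callback example can be exercised without a real model. The sketch below runs the same filtering rule on a plain dict; the parameter names are hypothetical, and strings stand in for tensors so that no torch install is needed.

```python
from typing import Any, Dict


def state_dict_callback(state_dict: Dict[str, Any], adapter_name: str) -> Dict[str, Any]:
    # Keep only the entries whose key contains the adapter name
    # (the same filtering rule as the documented example).
    return {
        key: value
        for key, value in state_dict.items()
        if adapter_name in key
    }


# Hypothetical parameter names; the string values stand in for tensors.
full_state_dict = {
    "encoder.layer.0.query.weight": "W_q",
    "encoder.layer.0.query.lora_A.default.weight": "A",
    "encoder.layer.0.query.lora_B.default.weight": "B",
    "encoder.layer.0.query.bias": "b_q",
}

tuner_state_dict = state_dict_callback(full_state_dict, "default")
print(sorted(tuner_state_dict))
```

Because the match is a plain substring test, a short or generic adapter name can pull in unrelated keys: an adapter called `query` would match all four entries above, including the frozen base weight. Matching against a dotted path component is a stricter alternative.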
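The core decision inside a mark_trainable_callback is which parameters get unfrozen. The pure-Python sketch below isolates that selection step on parameter names. It is not swift's `mark_lora_as_trainable`: the helper `select_trainable`, the assumption that LoRA parameters carry `lora_` in their names, and the assumption that the bias setting takes values such as `"none"` and `"all"` are all illustrative and may differ from what `config.bias` actually accepts.

```python
from typing import Iterable, Set


def select_trainable(param_names: Iterable[str], bias: str = "none") -> Set[str]:
    """Hypothetical helper: return the parameter names that should be trainable."""
    names = list(param_names)
    # Assumption: LoRA matrices are identified by "lora_" in their names.
    selected = {n for n in names if "lora_" in n}
    if bias == "all":
        # Additionally unfreeze every bias term in the model.
        selected |= {n for n in names if n.endswith("bias")}
    elif bias != "none":
        raise ValueError(f"unsupported bias mode: {bias!r}")
    return selected


names = [
    "encoder.layer.0.query.weight",
    "encoder.layer.0.query.bias",
    "encoder.layer.0.query.lora_A.default.weight",
    "encoder.layer.0.query.lora_B.default.weight",
]
print(sorted(select_trainable(names)))
print(sorted(select_trainable(names, bias="all")))
```

To apply such a selection to a real torch.nn.Module, a callback would iterate over `model.named_parameters()` and call `param.requires_grad_(name in selected)` on each parameter, which matches the documented contract of receiving a model and returning nothing.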
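The signature of optimizer_group_callback is not spelled out here, so the following is only a sketch of the kind of value it could produce. PyTorch optimizers accept a list of dicts, each holding a `"params"` key plus per-group options such as `"lr"` and `"weight_decay"`. The helper `build_param_groups`, its arguments, and the "no weight decay on bias terms" rule are hypothetical choices for illustration; strings again stand in for tensors.

```python
from typing import Any, Dict, List, Tuple


def build_param_groups(
    named_params: List[Tuple[str, Any]],
    adapter_name: str,
    lr: float,
    weight_decay: float = 0.01,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: group only the tuner's parameters for an optimizer."""
    decay: List[Any] = []
    no_decay: List[Any] = []
    for name, param in named_params:
        if adapter_name not in name:
            continue  # parameter is not owned by this tuner
        # Common convention: exclude bias terms from weight decay.
        if name.endswith("bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    groups = [{"params": decay, "lr": lr, "weight_decay": weight_decay}]
    if no_decay:
        groups.append({"params": no_decay, "lr": lr, "weight_decay": 0.0})
    return groups


# Hypothetical (name, parameter) pairs, as model.named_parameters() would yield.
named = [
    ("encoder.layer.0.query.weight", "W_q"),
    ("encoder.layer.0.query.lora_A.default.weight", "A"),
    ("encoder.layer.0.query.lora_B.default.weight", "B"),
    ("encoder.layer.0.query.lora_B.default.bias", "b_B"),
]
groups = build_param_groups(named, "default", lr=1e-4)
print(groups)
```

With real tensors, the returned list could be passed directly to an optimizer constructor such as `torch.optim.AdamW(groups)`; keeping the frozen base weights out of every group means the optimizer never tracks state for them.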