forked from ashen-sensored/sd_webui_SAG
-
Notifications
You must be signed in to change notification settings - Fork 1
/
automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
118 lines (101 loc) · 4.8 KB
/
automatic1111-CFGDenoiser-and-script_callbacks-mod-for-SAG.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
From cb129e420612813d74043aa4a6a49575b53e9c14 Mon Sep 17 00:00:00 2001
From: Ashen <[email protected]>
Date: Fri, 21 Apr 2023 09:40:59 -0700
Subject: [PATCH] CFGDenoiser and script_callbacks mod for SAG
---
modules/script_callbacks.py | 34 +++++++++++++++++++++++++++++++
modules/sd_samplers_kdiffusion.py | 7 ++++++-
2 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 07911876..d3d3df14 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -53,6 +53,21 @@ class CFGDenoiserParams:
class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.inner_model = inner_model
+ """Inner model reference that is being used for denoising"""
+
+
+class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +78,8 @@ class CFGDenoisedParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+ self.output_altered = False
+ """A flag for CFGDenoiser that indicates whether the output has been altered by the callback"""
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
@@ -87,6 +104,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -177,6 +195,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -318,6 +344,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed.
+ The callback is called with one argument:
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..6ff55ba6 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -9,6 +9,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -146,7 +147,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -164,6 +165,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
self.step += 1
return denoised
--
2.40.0