ensure dtype cast
This commit is contained in:
@@ -95,6 +95,7 @@ def apply_conditioning(
|
||||
Updated LatentState with conditioning applied
|
||||
"""
|
||||
state = state.clone()
|
||||
dtype = state.latent.dtype
|
||||
b, c, f, h, w = state.latent.shape
|
||||
|
||||
for cond in conditionings:
|
||||
@@ -132,7 +133,7 @@ def apply_conditioning(
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength))
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
@@ -161,7 +162,8 @@ def apply_denoise_mask(
|
||||
Returns:
|
||||
Blended latent
|
||||
"""
|
||||
return denoised * denoise_mask + clean * (1.0 - denoise_mask)
|
||||
one = mx.array(1.0, dtype=denoised.dtype)
|
||||
return denoised * denoise_mask + clean * (one - denoise_mask)
|
||||
|
||||
|
||||
def add_noise_with_state(
|
||||
@@ -191,6 +193,7 @@ def add_noise_with_state(
|
||||
# But we scale sigma by the mask for conditioned regions
|
||||
|
||||
effective_scale = noise_scale * state.denoise_mask
|
||||
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale)
|
||||
one = mx.array(1.0, dtype=state.latent.dtype)
|
||||
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
|
||||
|
||||
return state
|
||||
|
||||
Reference in New Issue
Block a user