ensure dtype cast

This commit is contained in:
Prince Canuma
2026-01-17 13:03:48 +01:00
parent e4cdbb7eab
commit 883c6b0ad8
6 changed files with 52 additions and 32 deletions

View File

@@ -95,6 +95,7 @@ def apply_conditioning(
Updated LatentState with conditioning applied
"""
state = state.clone()
dtype = state.latent.dtype
b, c, f, h, w = state.latent.shape
for cond in conditionings:
@@ -132,7 +133,7 @@ def apply_conditioning(
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
# Set mask: 1.0 - strength means less denoising for conditioned frames
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength))
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else:
# Keep original
latent_list.append(state.latent[:, :, i:i+1])
@@ -161,7 +162,8 @@ def apply_denoise_mask(
Returns:
Blended latent
"""
return denoised * denoise_mask + clean * (1.0 - denoise_mask)
one = mx.array(1.0, dtype=denoised.dtype)
return denoised * denoise_mask + clean * (one - denoise_mask)
def add_noise_with_state(
@@ -191,6 +193,7 @@ def add_noise_with_state(
# But we scale sigma by the mask for conditioned regions
effective_scale = noise_scale * state.denoise_mask
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale)
one = mx.array(1.0, dtype=state.latent.dtype)
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
return state