【Hackathon 8th No.22】Add TurbDiff generative turbulence model implementation #1162

Open · wants to merge 1 commit into `develop`
36 changes: 36 additions & 0 deletions examples/turbdiff/README.md
@@ -0,0 +1,36 @@
# TurbDiff: Generative Modeling for 3D Flow Simulation

This example is a PaddlePaddle implementation of the TurbDiff model from the paper "From Zero to Turbulence: Generative Modeling for 3D Flow Simulation" (ICLR 2024) by Marten Lienen, David Lüdke, Jan Hansen-Palmus, and Stephan Günnemann.

## Overview

TurbDiff is a denoising diffusion probabilistic model (DDPM) designed for generating realistic 3D turbulent flow fields. Unlike traditional autoregressive approaches, TurbDiff directly learns the manifold of all possible turbulent flow states without relying on any initial flow state.
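
For background, the DDPM framework that TurbDiff builds on defines a fixed forward process that gradually corrupts a flow state $x_0$ with Gaussian noise over $T$ steps, and trains a network to reverse it:

```math
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right),
\qquad
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\,\mathbf{I}\right)
```

where $\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$. Training minimizes the usual noise-prediction objective $\mathbb{E}\,\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2$ (the `l2` loss in `config.yaml`), and sampling starts from pure noise, which is why no initial flow state is needed.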

## Model Architecture

The model architecture consists of:

1. A 3D U-Net backbone with attention mechanisms
2. Specialized conditioning for boundary conditions and geometry
3. A diffusion process based on the DDPM framework (see the timestep-embedding sketch after this list)
4. Custom components for handling turbulent flow characteristics
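
For reference, component 3 typically requires a timestep embedding that is fed to every U-Net block. The sketch below is the standard sinusoidal construction written for Paddle; it is illustrative background, not code taken from this PR:

```python
import math

import paddle


def timestep_embedding(t, dim):
    """Standard sinusoidal embedding of diffusion timesteps.

    Args:
        t: int64 tensor of timesteps, shape [B]
        dim: embedding dimension (assumed even here)

    Returns:
        float32 tensor of shape [B, dim]
    """
    half = dim // 2
    # Geometric ladder of frequencies from 1 down to ~1/10000
    freqs = paddle.exp(-math.log(10000.0) * paddle.arange(half, dtype="float32") / half)
    args = t.astype("float32").unsqueeze(-1) * freqs.unsqueeze(0)  # [B, half]
    return paddle.concat([paddle.sin(args), paddle.cos(args)], axis=-1)
```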

## Usage

Please refer to the `train.py` and `infer.py` scripts for training and inference examples.

## Citation

```bibtex
@inproceedings{lienen2024zero,
title = {From {{Zero}} to {{Turbulence}}: {{Generative Modeling}} for {{3D Flow Simulation}}},
author = {Lienen, Marten and L{\"u}dke, David and {Hansen-Palmus}, Jan and G{\"u}nnemann, Stephan},
booktitle = {International {{Conference}} on {{Learning Representations}}},
year = {2024},
}
```

## References

- Original implementation: [https://github.com/martenlienen/generative-turbulence](https://github.com/martenlienen/generative-turbulence)
- Paper: [https://arxiv.org/abs/2306.01776](https://arxiv.org/abs/2306.01776)
168 changes: 168 additions & 0 deletions examples/turbdiff/conditioning.py
@@ -0,0 +1,168 @@
"""
Conditioning module for TurbDiff model.
"""

from enum import Enum, auto
import paddle
import paddle.nn as nn


class ConditioningType(Enum):
    """Types of conditioning supported by the model."""
    LOCAL = auto()  # Local conditioning like boundary conditions
    GLOBAL = auto()  # Global conditioning like domain parameters


class Conditioning:
    """Handles the conditioning for the TurbDiff model."""

    def __init__(self, cell_type_embedding=None, use_cell_pos=False):
        """
        Initialize the conditioning.

        Args:
            cell_type_embedding: Module to embed cell types
            use_cell_pos: Whether to use cell positions as features
        """
        self.cell_type_embedding = cell_type_embedding
        self.use_cell_pos = use_cell_pos

        # Calculate conditioning dimensions
        self.local_conditioning_dim = 0
        if cell_type_embedding is not None:
            self.local_conditioning_dim += cell_type_embedding.embedding_dim
        if use_cell_pos:
            self.local_conditioning_dim += 3  # x, y, z positions

        self.global_conditioning_dim = 0  # Will be set by specific implementations

    def prepare_local_conditioning(self, data):
        """
        Prepare local conditioning from data.

        Args:
            data: Input data containing cell types and possibly other information

        Returns:
            Tensor of local conditioning features
        """
        conditioning_elements = []

        # Add cell type embeddings if available
        if self.cell_type_embedding is not None and hasattr(data, 'cell_type'):
            cell_type_emb = self.cell_type_embedding(data.cell_type)
            conditioning_elements.append(cell_type_emb)

        # Add cell positions if requested
        if self.use_cell_pos and hasattr(data, 'cell_pos'):
            # Normalize cell positions to [-1, 1] range
            pos_min = paddle.min(data.cell_pos, axis=[0, 2, 3, 4], keepdim=True)
            pos_max = paddle.max(data.cell_pos, axis=[0, 2, 3, 4], keepdim=True)
            pos_norm = 2 * (data.cell_pos - pos_min) / (pos_max - pos_min + 1e-8) - 1
            conditioning_elements.append(pos_norm)

        # Combine all conditioning elements
        if conditioning_elements:
            return paddle.concat(conditioning_elements, axis=1)
        else:
            return None

    def prepare_global_conditioning(self, data):
        """
        Prepare global conditioning from data.

        Args:
            data: Input data containing global parameters

        Returns:
            Tensor of global conditioning features or None
        """
        # This should be implemented by specific model extensions
        # For base implementation, return None
        return None

    def prepare_conditioning(self, data):
        """
        Prepare all conditioning from data.

        Args:
            data: Input data

        Returns:
            Dictionary of conditioning tensors by type
        """
        conditioning = {}

        local_cond = self.prepare_local_conditioning(data)
        if local_cond is not None:
            conditioning[ConditioningType.LOCAL] = local_cond

        global_cond = self.prepare_global_conditioning(data)
        if global_cond is not None:
            conditioning[ConditioningType.GLOBAL] = global_cond

        return conditioning


class CellTypeEmbedding(nn.Layer):
    """Embedding for different cell types (fluid, solid, boundary, etc.)."""

    def __init__(self, num_types, embedding_dim):
        """
        Initialize the cell type embedding.

        Args:
            num_types: Number of different cell types
            embedding_dim: Dimension of the embedding
        """
        super().__init__()
        self.embedding = nn.Embedding(num_types, embedding_dim)
        self.embedding_dim = embedding_dim

    def forward(self, cell_types):
        """
        Get embeddings for cell types.

        Args:
            cell_types: Tensor of cell type indices [B, 1, H, W, D]

        Returns:
            Embedded tensor [B, embedding_dim, H, W, D]
        """
        # Reshape for embedding lookup
        shape = cell_types.shape
        flat_types = paddle.reshape(cell_types, [shape[0], -1])

        # Lookup embeddings
        embeddings = self.embedding(flat_types)

        # Reshape back to original shape with embedding dimension
        embeddings = paddle.reshape(
            embeddings, [shape[0], shape[2], shape[3], shape[4], self.embedding_dim]
        )
        embeddings = paddle.transpose(embeddings, [0, 4, 1, 2, 3])

        return embeddings

    @classmethod
    def create(cls, embedding_type, embedding_dim, num_types=5):
        """
        Create a cell type embedding of the specified type.

        Args:
            embedding_type: Type of embedding ('learned', 'fixed', etc.)
            embedding_dim: Dimension of the embedding
            num_types: Number of different cell types

        Returns:
            CellTypeEmbedding instance
        """
        if embedding_type == 'learned':
            return cls(num_types, embedding_dim)
        elif embedding_type == 'fixed':
            embedding = cls(num_types, embedding_dim)
            # Freeze the parameters so the randomly initialized embedding stays fixed
            for param in embedding.parameters():
                param.stop_gradient = True
            return embedding
        else:
            raise ValueError(f"Unknown embedding type: {embedding_type}")
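
A minimal usage sketch for the module above (the `data` fields and shapes are assumptions inferred from the `hasattr` checks and docstrings, not from the PR's data pipeline):

```python
import types

import paddle

from conditioning import CellTypeEmbedding, Conditioning, ConditioningType

emb = CellTypeEmbedding.create("learned", embedding_dim=8, num_types=5)
cond = Conditioning(cell_type_embedding=emb, use_cell_pos=True)

B, H, W, D = 2, 8, 8, 8
data = types.SimpleNamespace(
    cell_type=paddle.randint(0, 5, [B, 1, H, W, D]),  # per-cell type indices
    cell_pos=paddle.rand([B, 3, H, W, D]),  # x, y, z coordinates
)

out = cond.prepare_conditioning(data)
print(out[ConditioningType.LOCAL].shape)  # [2, 11, 8, 8, 8]: 8 embedding + 3 position channels
```
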
57 changes: 57 additions & 0 deletions examples/turbdiff/config.yaml
@@ -0,0 +1,57 @@
# TurbDiff model configuration

# Variables to model
variables:
  - name: U
    dims: 3 # Velocity vector field (u, v, w)
  - name: p
    dims: 1 # Pressure scalar field

# Model architecture configuration
model:
  dim: 64 # Base dimension for feature maps
  u_net_levels: 4 # Number of U-Net downsampling/upsampling levels
  actfn: SiLU # Activation function
  norm_type: instance # Normalization type: instance, batch, layer, or none
  with_geometry_embedding: true # Whether to use geometry embedding
  cell_type_features: true # Whether to use cell type conditioning
  cell_type_embedding_type: learned # Type of cell embedding: learned or fixed
  cell_type_embedding_dim: 8 # Dimension of cell type embedding
  cell_pos_features: true # Whether to use cell position features
  num_cell_types: 5 # Number of different cell types

# Diffusion process configuration
diffusion:
  timesteps: 1000 # Number of diffusion timesteps
  loss_type: l2 # Loss type: l1, l2, or huber
  beta_schedule: sigmoid # Schedule for noise variance: linear, cosine, or sigmoid
  clip_denoised: false # Whether to clip denoised values to [-1, 1]
  noise_bcs: true # Whether to add noise to boundary conditions
  learned_variances: false # Whether to predict variance
  elbo_weight: null # Weight for evidence lower bound term
  detach_elbo_mean: true # Whether to detach mean for ELBO calculation

# Data configuration
data:
  normalization_mode: mean-std # Normalization mode: mean-std, min-max, or none
  num_workers: 4 # Number of data loader workers

# Training configuration
training:
  learning_rate: 1.0e-4 # Base learning rate
  min_learning_rate: 1.0e-5 # Minimum learning rate for scheduler
  warmup_steps: 1000 # Number of warmup steps for learning rate
  weight_decay: 1.0e-4 # Weight decay for regularization
  beta1: 0.9 # Adam beta1
  beta2: 0.999 # Adam beta2
  gradient_clip_val: 1.0 # Gradient clipping value
  save_interval: 10 # Save checkpoint every N epochs
  val_interval: 5 # Validate every N epochs

# Inference configuration
inference:
  num_timesteps: 100 # Number of timesteps for fast sampling
  metrics:
    - mse # Mean squared error
    - psnr # Peak signal-to-noise ratio
    - ssim # Structural similarity
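
A quick sanity check for this file, as a sketch assuming PyYAML is installed (the example's own scripts may load it through different utilities):

```python
import yaml

with open("examples/turbdiff/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["dim"])  # 64
print(cfg["diffusion"]["beta_schedule"])  # sigmoid
print([v["name"] for v in cfg["variables"]])  # ['U', 'p']
```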