Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/zenml-io/zenml/llms.txt

Use this file to discover all available pages before exploring further.

Steps are the building blocks of ZenML pipelines. Each step is a discrete unit of work that takes inputs, performs a computation, and produces outputs. Well-designed steps make your pipelines modular, testable, and reusable.

Basic Step Structure

A step is created using the @step decorator:
from zenml import step

@step
def my_step(input_data: str) -> str:
    """Return an upper-cased copy of ``input_data``."""
    return input_data.upper()

Creating Your First Step

Step 1: Import the Step Decorator
from zenml import step
from typing import Annotated
Step 2: Define the Step Function

Create a function with clear inputs and outputs:
@step
def load_data(data_path: str) -> Annotated[dict, "dataset"]:
    """Load a dataset and publish it as the ``dataset`` artifact.

    Args:
        data_path: Location of the data file. This example does not read it;
            replace the placeholder below with loading logic that uses it.

    Returns:
        The dataset as a dictionary holding a ``"values"`` list.
    """
    # Placeholder data standing in for real file-loading logic.
    return {"values": [1, 2, 3, 4, 5]}
Step 3: Add Type Hints

Always use type hints for inputs and outputs:
@step
def process_data(
    dataset: dict,
    multiplier: int = 2
) -> Annotated[list, "processed_data"]:
    """Scale every entry of ``dataset["values"]`` by ``multiplier``.

    Args:
        dataset: Dictionary whose ``"values"`` key holds the numbers to scale.
        multiplier: Factor applied to each value.

    Returns:
        A new list with each value multiplied, in the original order.
    """
    return [value * multiplier for value in dataset["values"]]
Step 4: Use Steps in a Pipeline

Connect steps in your pipeline:
from zenml import pipeline

@pipeline
def data_pipeline(data_path: str):
    """Load the data at ``data_path`` and triple each of its values.

    Returns the output of ``process_data`` so it is exposed as the
    pipeline's result.
    """
    return process_data(load_data(data_path), multiplier=3)

Step Inputs and Outputs

Type Annotations

Use Annotated to give artifacts meaningful names:
from typing import Annotated

@step
def train_model(
    train_data: dict,
    learning_rate: float
) -> Annotated[object, "trained_model"]:
    """Train a placeholder model and publish it as ``trained_model``.

    Wrapping the return type in ``Annotated[..., "trained_model"]`` gives the
    output artifact that name, which makes it easy to find in the ZenML
    dashboard.
    """
    # Stand-in for real training: record the hyperparameter and mark as trained.
    return {"lr": learning_rate, "trained": True}

Multiple Outputs

Return multiple artifacts from a step:
from typing import Tuple, Annotated

@step
def split_data(
    dataset: dict,
    test_size: float = 0.2
) -> Tuple[
    Annotated[dict, "train_data"],
    Annotated[dict, "test_data"]
]:
    """Split dataset into train and test sets.

    The first ``(1 - test_size)`` fraction of ``dataset["values"]`` (rounded
    down) becomes the training set; the remainder becomes the test set.

    Args:
        dataset: Full dataset to split; must contain a ``"values"`` sequence.
        test_size: Fraction of data to use for testing, between 0 and 1.

    Returns:
        Tuple of (train_data, test_data), each a dict with a ``"values"`` key.

    Raises:
        ValueError: If ``test_size`` is not between 0 and 1 (inclusive).
    """
    # Outside [0, 1] the split index becomes negative or exceeds the length,
    # which silently yields a nonsensical split instead of an error.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")

    values = dataset["values"]
    split_idx = int(len(values) * (1 - test_size))

    train_data = {"values": values[:split_idx]}
    test_data = {"values": values[split_idx:]}
    return train_data, test_data
Use multiple outputs in a pipeline:
@pipeline
def ml_pipeline():
    """Load data, split it, then train and evaluate a model.

    ``split_data`` produces two artifacts; they are unpacked into separate
    variables and routed to different downstream steps.
    """
    train_data, test_data = split_data(load_data("data.csv"), test_size=0.3)
    model = train_model(train_data, learning_rate=0.001)
    metrics = evaluate_model(model, test_data)

Optional Outputs

Use `Optional` for outputs that may be `None`:
from typing import Annotated, Optional, Tuple

@step
def validate_and_process(
    data: dict,
    strict: bool = False
) -> Tuple[
    Annotated[dict, "processed_data"],
    Annotated[Optional[dict], "validation_errors"]
]:
    """Validate and process data, optionally returning errors.

    Args:
        data: Input data; must contain a ``"values"`` sequence.
        strict: When True, flag datasets with fewer than 10 values.

    Returns:
        Tuple of (processed_data, validation_errors). ``validation_errors``
        is None when validation was skipped or found no problems.
    """
    errors = None

    if strict and len(data["values"]) < 10:
        errors = {"error": "Insufficient data"}

    # Data is processed whether or not validation found errors; both results
    # are returned so downstream steps can decide how to react.
    processed = {"values": data["values"], "validated": True}

    return processed, errors

Step Parameters

Make steps configurable with parameters:
@step
def preprocess_text(
    text: str,
    lowercase: bool = True,
    remove_punctuation: bool = True,
    max_length: int = 1000
) -> Annotated[str, "processed_text"]:
    """Normalize ``text`` according to the enabled options.

    Transformations run in this order: lowercasing, punctuation removal,
    then truncation to at most ``max_length`` characters.

    Args:
        text: Raw text to normalize.
        lowercase: If True, convert the text to lowercase.
        remove_punctuation: If True, drop every character that is neither
            alphanumeric nor whitespace.
        max_length: Number of leading characters to keep.

    Returns:
        The normalized text.
    """
    result = text.lower() if lowercase else text

    if remove_punctuation:
        # Keep only letters, digits and whitespace.
        result = "".join(ch for ch in result if ch.isalnum() or ch.isspace())

    return result[:max_length]

Step Configuration

Resource Settings

Specify compute resources for a step:
from zenml import step
from zenml.config import ResourceSettings

@step(settings={"resources": ResourceSettings(cpu_count=4, memory="8GB")})
def train_large_model(data: dict) -> object:
    """Train a model that requires significant resources.

    The ``resources`` setting requests 4 CPUs and 8GB of memory for this step.

    Args:
        data: Training data.

    Returns:
        The trained model (a placeholder dictionary in this example).
    """
    # Placeholder training logic; replace with real training code. The model
    # must be bound before it is returned.
    model = {"trained": True}
    return model

Disabling Cache

Disable caching for specific steps:
@step(enable_cache=False)
def fetch_latest_data() -> Annotated[dict, "fresh_data"]:
    """Fetch current data from the API on every pipeline run.

    With ``enable_cache=False`` this step always executes, even when its
    inputs are unchanged, rather than reusing a cached output.
    """
    fresh = fetch_from_api()  # helper not shown in this example
    return fresh

Step Operators

Run steps on different infrastructure:
@step(step_operator="gpu_operator")
def train_model_on_gpu(data: dict) -> object:
    """Train model using GPU resources.

    The step is dispatched to the step operator registered as
    ``"gpu_operator"`` instead of running on the default orchestrator.

    Args:
        data: Training data.

    Returns:
        The trained model (a placeholder dictionary in this example).
    """
    # Placeholder for GPU training logic; the model must be bound before it
    # is returned.
    model = {"trained": True}
    return model

Best Practices for Steps

Keep Steps Focused

Each step should have a single, clear purpose:
@step
def load_data(path: str) -> dict:
    """Read the raw data stored at ``path``."""
    data = load_from_file(path)
    return data

@step
def clean_data(data: dict) -> dict:
    """Strip null entries from freshly loaded data."""
    cleaned = remove_nulls(data)
    return cleaned

@step
def transform_data(data: dict) -> dict:
    """Apply feature transformations to already-cleaned data."""
    transformed = apply_transformations(data)
    return transformed

Use Meaningful Names

Choose descriptive names for steps and artifacts:
# Good names
@step
def calculate_feature_importance(
    model: object,
    test_data: dict
) -> Annotated[dict, "feature_importance_scores"]:
    """Compute an importance score for each feature the model uses.

    The implementation is omitted; this example illustrates descriptive
    step and artifact names.
    """
    ...

# Avoid vague names like:
# def process(), def step1(), def do_stuff()

Add Comprehensive Docstrings

@step
def evaluate_model(
    model: object,
    test_data: dict,
    threshold: float = 0.8
) -> Annotated[dict, "evaluation_metrics"]:
    """Measure how well ``model`` performs on held-out data.

    Meant to compute accuracy, precision, recall, and F1 score, and to log
    a warning when accuracy falls below ``threshold``.

    Args:
        model: The trained model to score.
        test_data: Held-out dataset as a dictionary with ``'X'`` (features)
            and ``'y'`` (labels) keys.
        threshold: Lowest acceptable accuracy. Defaults to 0.8.

    Returns:
        A dictionary mapping each metric name to its value.

    Raises:
        ValueError: If ``test_data`` is empty or malformed.
    """
    ...  # implementation omitted in this example

Handle Errors Gracefully

@step
def process_user_input(
    input_data: str
) -> Annotated[dict, "validated_data"]:
    """Process and validate user input.

    Args:
        input_data: Raw input string from user.

    Returns:
        Dictionary with the parsed result under ``"data"`` and
        ``"valid": True``.

    Raises:
        ValueError: If ``input_data`` is empty or whitespace-only, or if
            parsing it fails.
    """
    if not input_data or not input_data.strip():
        raise ValueError("Input data cannot be empty")

    # Keep the try body to the one call that can fail.
    try:
        processed = parse_input(input_data)
    except Exception as e:
        # Chain the original exception so its traceback is preserved.
        raise ValueError(f"Failed to process input: {e}") from e

    return {"data": processed, "valid": True}

Common Step Patterns

Data Loading Step

import pandas as pd
from typing import Annotated

@step
def data_loader(
    data_path: str,
    file_format: str = "csv"
) -> Annotated[pd.DataFrame, "dataset"]:
    """Read a tabular file into a DataFrame using the matching pandas reader.

    Args:
        data_path: Path of the file to read.
        file_format: One of ``"csv"``, ``"parquet"`` or ``"json"``.

    Returns:
        The file's contents as a DataFrame.

    Raises:
        ValueError: If ``file_format`` is not one of the supported formats.
    """
    readers = {
        "csv": pd.read_csv,
        "parquet": pd.read_parquet,
        "json": pd.read_json,
    }
    reader = readers.get(file_format)
    if reader is None:
        raise ValueError(f"Unsupported format: {file_format}")

    df = reader(data_path)
    print(f"Loaded {len(df)} records from {data_path}")
    return df

Model Training Step

from sklearn.linear_model import LogisticRegression
import pandas as pd

@step
def model_trainer(
    train_data: pd.DataFrame,
    target_column: str,
    learning_rate: float = 0.01,
    max_iter: int = 100,
    inverse_regularization: float = 1.0,
) -> Annotated[LogisticRegression, "trained_model"]:
    """Train a logistic regression model.

    Args:
        train_data: Training dataset containing the feature columns plus
            ``target_column``.
        target_column: Name of the target variable column.
        learning_rate: Accepted for backward compatibility but ignored.
            scikit-learn's ``LogisticRegression`` has no learning-rate
            hyperparameter, and passing one to it raises ``TypeError``.
        max_iter: Maximum number of solver iterations.
        inverse_regularization: Forwarded as ``C``, the inverse of the
            regularization strength; smaller values regularize more strongly.

    Returns:
        The fitted model.
    """
    X = train_data.drop(columns=[target_column])
    y = train_data[target_column]

    model = LogisticRegression(C=inverse_regularization, max_iter=max_iter)
    model.fit(X, y)

    print(f"Model trained with {len(X)} samples")
    return model

Model Evaluation Step

import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score

@step
def model_evaluator(
    model: object,
    test_data: pd.DataFrame,
    target_column: str
) -> Annotated[dict, "metrics"]:
    """Score a trained classifier on a held-out DataFrame.

    Args:
        model: Fitted estimator exposing ``predict``.
        test_data: Held-out rows containing features plus ``target_column``.
        target_column: Name of the label column.

    Returns:
        Dict with ``accuracy`` plus class-weighted ``precision`` and
        ``recall``.
    """
    features = test_data.drop(columns=[target_column])
    labels = test_data[target_column]
    predictions = model.predict(features)

    accuracy = accuracy_score(labels, predictions)
    metrics = {
        "accuracy": accuracy,
        "precision": precision_score(labels, predictions, average='weighted'),
        "recall": recall_score(labels, predictions, average='weighted'),
    }

    print(f"Model Accuracy: {accuracy:.3f}")
    return metrics

Next Steps

Step Context

Access runtime information and metadata within steps

Artifact Management

Learn how ZenML tracks and manages step artifacts

Creating Pipelines

Connect steps together in pipelines