# ******************************************************************************
# pySimBlocks
# Copyright (c) 2026 Université de Lille & INRIA
# ******************************************************************************
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
# for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# ******************************************************************************
# Authors: see Authors.txt
# ******************************************************************************
from __future__ import annotations
import re
import unicodedata
import numpy as np
from numpy.typing import ArrayLike
from pySimBlocks.core.block import Block
class Gain(Block):
    """Static gain block.

    Applies a gain to the input signal according to one of three multiplication
    modes: element-wise, left matrix product (K @ u), or right matrix product
    (u @ K).

    Ports:
        in: Input signal. Must be a 2D array when the output is computed.
        out: The gain applied to ``in``.

    Attributes:
        gain: Gain coefficient(s) — scalar float, 1D vector, or 2D matrix.
        multiplication: Active multiplication mode string (one of
            ``MULT_ELEMENTWISE``, ``MULT_LEFT``, ``MULT_RIGHT``).
    """

    direct_feedthrough = True

    MULT_ELEMENTWISE = "Element wise (K * u)"
    MULT_LEFT = "Matrix (K @ u)"
    MULT_RIGHT = "Matrix (u @ K)"
    ALLOWED_MULTIPLICATIONS = {MULT_ELEMENTWISE, MULT_LEFT, MULT_RIGHT}

    def __init__(
        self,
        name: str,
        gain: ArrayLike = 1.0,
        multiplication: str = MULT_ELEMENTWISE,
        sample_time: float | None = None,
    ):
        """Initialize a Gain block.

        Args:
            name: Unique identifier for this block instance.
            gain: Gain coefficient(s). May be a scalar (including a 0-d
                NumPy array), a 1D vector, or a 2D matrix. The accepted
                shape depends on the chosen multiplication mode.
            multiplication: Multiplication mode string. Accepted values include
                ``'Element wise (K * u)'``, ``'Matrix (K @ u)'``, and
                ``'Matrix (u @ K)'`` as well as common aliases.
            sample_time: Sampling period in seconds, or None to use the global
                simulation dt.

        Raises:
            TypeError: If ``multiplication`` is not a string.
            ValueError: If ``multiplication`` is not a recognized mode, or if
                ``gain`` is not scalar, 1D, or 2D.
        """
        super().__init__(name, sample_time)
        self.multiplication = self._parse_multiplication(multiplication)

        # np.isscalar() returns False for 0-d arrays such as np.array(2.0);
        # accept those as scalars instead of rejecting them as ndim=0.
        if np.isscalar(gain) or (isinstance(gain, np.ndarray) and gain.ndim == 0):
            self.gain = float(gain)
            self._gain_kind = "scalar"
        else:
            g = np.asarray(gain, dtype=float)
            if g.ndim not in (1, 2):
                raise ValueError(
                    f"[{self.name}] 'gain' must be a scalar, 1D vector, or 2D matrix. "
                    f"Got ndim={g.ndim} with shape {g.shape}."
                )
            self.gain = g
            self._gain_kind = "vector" if g.ndim == 1 else "matrix"

        self.inputs["in"] = None
        self.outputs["out"] = None

    # --------------------------------------------------------------------------
    # Class methods
    # --------------------------------------------------------------------------
    @classmethod
    def _parse_multiplication(cls, multiplication: str) -> str:
        """Normalize and validate a multiplication mode string.

        Exact aliases are checked first. Otherwise, any normalized string that
        contains ``"k@u"`` or ``"u@k"`` selects the left or right product.

        Args:
            multiplication: User-provided mode string or alias.

        Returns:
            One of ``MULT_ELEMENTWISE``, ``MULT_LEFT`` or ``MULT_RIGHT``.

        Raises:
            TypeError: If ``multiplication`` is not a string.
            ValueError: If no mode can be recognized.
        """
        if not isinstance(multiplication, str):
            raise TypeError(f"[{cls.__name__}] 'multiplication' must be a str.")
        m = cls._normalize_user_string(multiplication)
        if m in {
            "elementwise(k*u)", "elementwise", "elem", "k*u", "*", "k×u", "kxu"
        }:
            return cls.MULT_ELEMENTWISE
        if m in {
            "matrix(k@u)", "k@u", "left", "matleft", "@left"
        }:
            return cls.MULT_LEFT
        if m in {
            "matrix(u@k)", "u@k", "right", "matright", "@right"
        }:
            return cls.MULT_RIGHT
        # Fallback: accept any spelling that still contains the product form.
        if "k@u" in m:
            return cls.MULT_LEFT
        if "u@k" in m:
            return cls.MULT_RIGHT
        raise ValueError(
            f"[{cls.__name__}] Invalid 'multiplication'='{multiplication}'. "
            f"Examples: '{cls.MULT_ELEMENTWISE}', '{cls.MULT_RIGHT}', '{cls.MULT_LEFT}'."
        )

    # --------------------------------------------------------------------------
    # Public methods
    # --------------------------------------------------------------------------
    def initialize(self, t0: float) -> None:
        """Compute the initial output from the initial input if available.

        If the gain is not a scalar and the initial input holds a single
        element (shape ``(1,)`` or ``(1, 1)``), that value is broadcast to an
        input shape compatible with the gain before computing the output.

        Args:
            t0: Initial simulation time in seconds.
        """
        u = self.inputs["in"]
        if u is None:
            self.outputs["out"] = None
            return
        # Convert first so list/float inputs don't fail on `.ndim`.
        u = np.asarray(u, dtype=float)
        if self._gain_kind != "scalar" and u.ndim in (1, 2) and u.size == 1:
            u = self._resolve_initialize(u)
        self.outputs["out"] = self._compute(u)

    def output_update(self, t: float, dt: float) -> None:
        """Apply the gain to the input and write the result to the output port.

        Args:
            t: Current simulation time in seconds.
            dt: Current time step in seconds.

        Raises:
            RuntimeError: If input ``'in'`` is not connected.
            ValueError: If the input is not 2D or dimensions are incompatible
                with the gain and multiplication mode.
        """
        u = self.inputs["in"]
        if u is None:
            raise RuntimeError(f"[{self.name}] Input 'in' is not connected or not set.")
        self.outputs["out"] = self._compute(u)

    def state_update(self, t: float, dt: float) -> None:
        """No-op: Gain is a stateless block.

        Args:
            t: Current simulation time in seconds.
            dt: Current time step in seconds.
        """
        return

    # --------------------------------------------------------------------------
    # Private methods
    # --------------------------------------------------------------------------
    @staticmethod
    def _normalize_user_string(s: str) -> str:
        """Normalize a user-provided string for alias matching.

        Applies Unicode NFKC normalization, lowercases, and removes all
        whitespace. Non-ASCII characters (e.g. ``'×'``) are kept.
        """
        s = unicodedata.normalize("NFKC", s)
        s = s.strip().lower()
        s = s.replace("\u00A0", " ")
        s = re.sub(r"\s+", "", s, flags=re.UNICODE)
        return s

    def _resolve_initialize(self, u: np.ndarray) -> np.ndarray:
        """Broadcast a single-element placeholder input to a gain-compatible shape.

        Args:
            u: Initial input containing exactly one element.

        Returns:
            An array filled with that element, shaped for the active mode. If
            the mode/gain combination is invalid (a matrix mode with a
            non-matrix gain), ``u`` is returned unchanged so that ``_compute``
            raises its descriptive error.
        """
        value = np.asarray(u, dtype=float).flat[0]

        if self.multiplication == self.MULT_ELEMENTWISE:
            if self._gain_kind == "vector":
                return np.full((self.gain.shape[0], 1), value, dtype=float)
            if self._gain_kind == "matrix":
                return np.full(self.gain.shape, value, dtype=float)
            return u

        if self._gain_kind != "matrix":
            # Matrix modes need a 2D gain. Indexing gain.shape[1] here would
            # raise IndexError instead of the meaningful ValueError.
            return u

        if self.multiplication == self.MULT_LEFT:
            # K @ u requires u to have K.shape[1] rows.
            rows = self.gain.shape[1]
        else:
            # u @ K on a column vector is computed as (u.T @ K).T, which
            # requires u to have K.shape[0] rows.
            rows = self.gain.shape[0]
        return np.full((rows, 1), value, dtype=float)

    def _compute(self, u) -> np.ndarray:
        """Validate input and dispatch to the active multiplication method.

        Raises:
            ValueError: If ``u`` is not 2D or is incompatible with the gain.
            RuntimeError: If ``self.multiplication`` holds an unknown mode.
        """
        u = np.asarray(u, dtype=float)
        if u.ndim != 2:
            raise ValueError(
                f"[{self.name}] Input 'in' must be a 2D array. Got ndim={u.ndim} with shape {u.shape}."
            )
        if self.multiplication == self.MULT_ELEMENTWISE:
            return self._elementwise(u)
        if self.multiplication == self.MULT_LEFT:
            return self._left_multiply(u)
        if self.multiplication == self.MULT_RIGHT:
            return self._right_multiply(u)
        raise RuntimeError(f"[{self.name}] Unhandled multiplication mode: {self.multiplication}")

    def _elementwise(self, u: np.ndarray) -> np.ndarray:
        """Apply element-wise multiplication K * u.

        A vector gain scales the rows of ``u``. A matrix gain must match
        ``u.shape`` exactly unless it is ``(1, 1)``, in which case it
        broadcasts.
        """
        if self._gain_kind == "scalar":
            return self.gain * u
        g = self.gain
        if self._gain_kind == "vector":
            if g.shape[0] != 1 and u.shape[0] != g.shape[0]:
                raise ValueError(
                    f"[{self.name}] Element-wise mode requires u.shape[0] == len(gain). "
                    f"Got u.shape={u.shape}, gain.shape={g.shape}."
                )
            # Column-shape the gain so it scales each row of u.
            return g.reshape(-1, 1) * u
        # A (1, 1) matrix gain broadcasts like a scalar.
        if g.shape != (1, 1) and u.shape != g.shape:
            raise ValueError(
                f"[{self.name}] Element-wise mode with matrix gain requires u.shape == gain.shape. "
                f"Got u.shape={u.shape}, gain.shape={g.shape}."
            )
        return g * u

    def _left_multiply(self, u: np.ndarray) -> np.ndarray:
        """Apply left matrix multiplication K @ u.

        Raises:
            ValueError: If the gain is not 2D or ``u.shape[0] != gain.shape[1]``.
        """
        if self._gain_kind != "matrix":
            raise ValueError(
                f"[{self.name}] Multiplication mode '{self.MULT_LEFT}' requires a 2D matrix gain. "
                f"Got gain kind '{self._gain_kind}'."
            )
        K = self.gain
        if u.shape[0] != K.shape[1]:
            raise ValueError(
                f"[{self.name}] Left matrix product requires u.shape[0] == gain.shape[1]. "
                f"Got u.shape={u.shape}, gain.shape={K.shape}."
            )
        return K @ u

    def _right_multiply(self, u: np.ndarray) -> np.ndarray:
        """Apply right matrix multiplication u @ K.

        A column vector ``u`` (shape ``(m, 1)``) is treated as a row vector and
        the result is returned as a column: ``(u.T @ K).T``.

        Raises:
            ValueError: If the gain is not 2D or the dimensions are incompatible.
        """
        if self._gain_kind != "matrix":
            raise ValueError(
                f"[{self.name}] Multiplication mode '{self.MULT_RIGHT}' requires a 2D matrix gain. "
                f"Got gain kind '{self._gain_kind}'."
            )
        K = self.gain
        m = K.shape[0]
        if u.shape[1] == 1:
            if u.shape[0] != m:
                raise ValueError(
                    f"[{self.name}] Right matrix product with vector requires u.shape[0] == gain.shape[0]. "
                    f"Got u.shape={u.shape}, gain.shape={K.shape}."
                )
            return (u.T @ K).T
        if u.shape[1] != m:
            raise ValueError(
                f"[{self.name}] Right matrix product requires u.shape[1] == gain.shape[0]. "
                f"Got u.shape={u.shape}, gain.shape={K.shape}."
            )
        return u @ K