Description
I am adding type hints to the entire code base for my own education. I have found that the current implementation raises a number of type-hinting issues (checking with Ruff and Cursor). One of the main issues is the use of function arguments that can be either None or some type (e.g., `Tensor | None`). This is fine in class constructors, but the constructor should typically replace the None option with an empty object of the correct type before it returns. For example, an empty Sequential (`Sequential()`) is better than either `Sequential(Identity())` or a bare `Module` (see the `Sequential` function in neural_memory).
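As a minimal sketch of the pattern I am suggesting (the class and parameter names below are hypothetical, not taken from the library), the None can be replaced by an empty Sequential as soon as the constructor runs, so no later code has to branch on None:

```python
import torch
from torch import nn

class Block(nn.Module):
    # hypothetical module: the caller may pass post_norm as None
    def __init__(self, dim: int, post_norm: nn.Module | None = None):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # replace None with an empty Sequential immediately; an empty Sequential
        # acts as an identity, so the attribute is always a Module
        self.post_norm: nn.Module = post_norm if post_norm is not None else nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # no None check is needed here
        return self.post_norm(self.proj(x))
```

With this pattern a type checker sees `self.post_norm` as a `Module` everywhere, and the forward pass needs no None check.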
Here is the constructor of NeuralMemory with type hints (which are important for reducing errors and improving code structure):
```python
# imports added so the excerpt stands alone; default_loss_fn is defined
# elsewhere in neural_memory.py
from typing import Callable

from torch import Tensor
from torch.nn import Module

class NeuralMemory(Module):
    def __init__(
        self,
        dim: int,
        chunk_size: int | tuple[int, int] = 1,
        batch_size: int | None = None,
        dim_head: int | None = None,
        heads: int | None = 1,
        model: Module | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        adaptive_step_transform: Callable | None = None,
        default_step_transform_max_lr: float = 1.0,
        per_parameter_lr_modulation: bool = False, # allow outer network to control learning rate per weight matrix of memory network
        max_mem_layer_modulation: float = 1.0, # max of 10.
        per_head_learned_parameters: bool = True,
        attn_pool_chunks: bool = False,
        momentum: bool = True,
        momentum_order: int = 1,
        learned_momentum_combine: bool = False,
        learned_combine_include_zeroth: bool = False,
        num_kv_per_token: int = 1, # whether a single token can do multiple updates to the memory model
        qkv_receives_diff_views: bool = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
        pre_rmsnorm: bool = True,
        post_rmsnorm: bool = False,
        qk_rmsnorm: bool = False,
        max_grad_norm: float | None = None,
        use_accelerated_scan: bool = False,
        activation: Module | None = None,
        init_adaptive_step_bias: Tensor | None = None,
        init_momentum_bias: Tensor | None = None,
        init_decay_bias: Tensor | None = None,
        accept_weight_residual: bool = False,
        gated_transition: bool = False,
        mem_model_norm_add_residual: bool = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
        default_model_kwargs: dict = dict(depth=2, expansion_factor=4.0),
    ):
        ...  # body omitted
```

An argument such as `activation: Module | None = None` is fine on its own, but the None should be replaced by a concrete value by the time the constructor exits, to avoid confusion in later parts of the code. This should be a requirement for robust code development.
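For example (a hypothetical rewrite, not the existing implementation), Optional arguments like `dim_head` and `activation` could be narrowed to concrete values at the top of `__init__`, so the rest of the class never sees a None:

```python
from torch import nn

class Memory(nn.Module):
    # hypothetical constructor illustrating the narrowing; GELU is only a
    # stand-in default, not what the library actually uses
    def __init__(
        self,
        dim: int,
        heads: int = 1,
        dim_head: int | None = None,
        activation: nn.Module | None = None,
    ):
        super().__init__()
        # derive a concrete dim_head when the caller did not supply one
        dim_head = dim_head if dim_head is not None else dim // heads
        # fall back to a concrete Module so the Module | None union does not
        # leak into the rest of the class
        self.activation: nn.Module = activation if activation is not None else nn.GELU()
        self.to_queries = nn.Linear(dim, dim_head * heads)
```

After those two assignments neither value is Optional, so later uses type-check cleanly.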