Problems using polars shift function with over

In the code below I try to use shift(1).over("id") to shift (lag) the values of a series by one row within each id. Without over it seems to work fine (see the lag_val column), but when I add over (which I need because the id column contains more than one id), the values appear to be lagged by 2 rows instead of 1 (see lag_1_val). If I use shift(2), I get a lag of 3 instead of 2 (see lag_2_val). Could someone please point out what I am doing wrong?

df: shape: (12, 5)

┌─────┬─────┬─────────┬───────────┬───────────┐
│ id  ┆ val ┆ lag_val ┆ lag_1_val ┆ lag_2_val │
│ --- ┆ --- ┆ ---     ┆ ---       ┆ ---       │
│ str ┆ i32 ┆ i32     ┆ i32       ┆ i32       │
╞═════╪═════╪═════════╪═══════════╪═══════════╡
│ a   ┆ 1   ┆ null    ┆ null      ┆ null      │
│ a   ┆ 2   ┆ 1       ┆ null      ┆ null      │
│ a   ┆ 3   ┆ 2       ┆ 2         ┆ null      │
│ a   ┆ 4   ┆ 3       ┆ 3         ┆ 2         │
│ …   ┆ …   ┆ …       ┆ …         ┆ …         │
│ c   ┆ 1   ┆ 14      ┆ 13        ┆ 11        │
│ c   ┆ 2   ┆ 1       ┆ 2         ┆ 13        │
│ c   ┆ 3   ┆ 2       ┆ null      ┆ null      │
│ c   ┆ 4   ┆ 3       ┆ null      ┆ null      │
└─────┴─────┴─────────┴───────────┴───────────┘
/*
[dependencies]
polars = { version = "*", features = ["lazy"] }
*/

use polars::prelude::*;

fn main() {

    let df = df!(
        "id" => &["a", "a", "a", "a", "b", "b", "b", "b", "c", "c", "c", "c"],
        "val" => &[1, 2, 3, 4, 11, 12, 13, 14, 1, 2, 3, 4]);
    let d = df.unwrap().clone().lazy()
        .with_columns([
            col("val").shift(1).prefix("lag_"),
            col("val").shift(1).over("id").prefix("lag_1_"),
            col("val").shift(2).over("id").prefix("lag_2_"),
        ])
        .collect()
        .unwrap();

    // Print the resulting DataFrame.
    println!("df: {d}");
}

Rust Explorer
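To pin down the intended semantics, here is a plain-Rust sketch (std only, no polars) of what the lag columns are expected to contain. The helpers shift_global and shift_over are hypothetical names for illustration, not polars APIs. shift_global models shift(n) over the whole column, so values leak across ids. shift_over models a per-id lag: the first n rows of each id become None (null), and every result stays at its original row position, even when rows of the same id are not contiguous.

```rust
use std::collections::{HashMap, VecDeque};

/// Hypothetical helper: lag by `n` rows over the whole column (no grouping).
fn shift_global(vals: &[i32], n: usize) -> Vec<Option<i32>> {
    (0..vals.len())
        .map(|i| i.checked_sub(n).map(|j| vals[j]))
        .collect()
}

/// Hypothetical helper: lag by `n` rows within each id.
/// Each output value is written at the row it belongs to.
fn shift_over(ids: &[&str], vals: &[i32], n: usize) -> Vec<Option<i32>> {
    assert_eq!(ids.len(), vals.len(), "ids and vals must have the same length");
    // Per id, keep the current value plus up to `n` previous values.
    let mut history: HashMap<&str, VecDeque<i32>> = HashMap::new();
    let mut out = Vec::with_capacity(vals.len());
    for (&id, &v) in ids.iter().zip(vals) {
        let seen = history.entry(id).or_default();
        seen.push_back(v);
        if seen.len() > n + 1 {
            seen.pop_front();
        }
        // A value exactly `n` rows back (within this id) exists only once
        // the buffer holds `n + 1` values; otherwise the result is null.
        out.push(if seen.len() == n + 1 { seen.front().copied() } else { None });
    }
    out
}

fn main() {
    let ids = ["a", "a", "a", "a", "b", "b", "b", "b", "c", "c", "c", "c"];
    let vals = [1, 2, 3, 4, 11, 12, 13, 14, 1, 2, 3, 4];
    let lag = shift_global(&vals, 1);
    let lag_1 = shift_over(&ids, &vals, 1);
    let lag_2 = shift_over(&ids, &vals, 2);
    for i in 0..vals.len() {
        println!("{} {:>3} {:?} {:?} {:?}", ids[i], vals[i], lag[i], lag_1[i], lag_2[i]);
    }
}
```

With this model, the first row of id c gets 14 from the global lag (carried over from id b) but None from the per-id lag, and shift by 2 produces exactly two leading nulls per id.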

Does this do what you want? I'm having trouble understanding part of your description, and what your actual goal is.

let d = df.unwrap().clone().lazy()
    .with_columns([
        col("val").shift(1).alias("lag_val"),
        col("val").shift(1).over([col("id")]).alias("lag_1_val"),
        col("val").shift(2).over([col("id")]).alias("lag_2_val"),
    ])
    .collect()
    .unwrap();

Rust Explorer.

Output:


shape: (12, 5)
┌─────┬─────┬─────────┬───────────┬───────────┐
│ id  ┆ val ┆ lag_val ┆ lag_1_val ┆ lag_2_val │
│ --- ┆ --- ┆ ---     ┆ ---       ┆ ---       │
│ str ┆ i32 ┆ i32     ┆ i32       ┆ i32       │
╞═════╪═════╪═════════╪═══════════╪═══════════╡
│ a   ┆ 1   ┆ null    ┆ null      ┆ null      │
│ a   ┆ 2   ┆ 1       ┆ 1         ┆ null      │
│ a   ┆ 3   ┆ 2       ┆ 2         ┆ 1         │
│ a   ┆ 4   ┆ 3       ┆ 3         ┆ 2         │
│ …   ┆ …   ┆ …       ┆ …         ┆ …         │
│ c   ┆ 1   ┆ 14      ┆ null      ┆ null      │
│ c   ┆ 2   ┆ 1       ┆ 1         ┆ null      │
│ c   ┆ 3   ┆ 2       ┆ 2         ┆ 1         │
│ c   ┆ 4   ┆ 3       ┆ 3         ┆ 2         │
└─────┴─────┴─────────┴───────────┴───────────┘

Note that all I did was change the argument of the over calls from "id" to [col("id")], which is closer to the example in the documentation. I also couldn't find the prefix method, so I used alias instead.


Perfect. Thanks @jofas
