Kako uporabiti metodo “torch.argmax()” v PyTorchu?

Kako Uporabiti Metodo Torch Argmax V Pytorchu



V PyTorchu je » torch.argmax() ” je vgrajena funkcija, ki vrne indekse največjih vrednosti določenega tenzorja v dani dimenziji. Uporabniki uporabljajo to funkcijo, ko delajo s tenzorji in želijo najti indeks največje vrednosti vzdolž dane dimenzije tenzorja. Poleg tega je ta metoda lahko uporabna tudi za razvrščanje, kjer uporabniki želijo vedeti, kateri razred ima največjo verjetnost.

Ta spletni dnevnik bo ponazoril metodo za uporabo metode »torch.argmax()« v PyTorchu.

Kako uporabiti metodo “torch.argmax()” v PyTorchu?

Metoda »torch.argmax()« vzame kateri koli 1D ali 2D tenzor kot vhod in vrne tenzor, ki vsebuje indekse/indekse največjih vrednosti vzdolž dane dimenzije.







Sintaksa metode »torch.argmax()« je podana spodaj:



svetilka. argmax ( < vhodni_tenzor > )

Če želite uporabiti to metodo v PyTorchu, preglejte naslednje primere za boljše razumevanje:



1. primer: uporabite metodo »torch.argmax()« z 1D tenzorjem

V prvem primeru bomo ustvarili 1D tenzor in z njim uporabili metodo »torch.argmax()«. Sledimo spodnjemu postopku po korakih:





1. korak: uvozite knjižnico PyTorch

Najprej uvozite » svetilka ” za uporabo metode “torch.argmax()”:

uvoz svetilka

2. korak: Ustvarite 1D tenzor

Nato ustvarite 1D tenzor in natisnite njegove elemente. Tukaj ustvarjamo naslednje ' desetice1 ' tenzor s seznama z uporabo ' torch.tensor() ” funkcija:



desetice1 = svetilka. tenzor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

tiskanje ( desetice1 )

To je ustvarilo 1D tenzor, kot je prikazano spodaj:

3. korak: Poiščite indekse največje vrednosti

Zdaj pa uporabite » torch.argmax() ' za iskanje indeksa/indeksov največje vrednosti v ' desetice1 ” tenzor:

T1_ind = svetilka. argmax ( desetice1 )

4. korak: Natisnite indeks največje vrednosti

Na koncu prikažite indeks največje vrednosti v vhodnem tenzorju:

tiskanje ( 'Indeksi:' , T1_ind )

Spodnji rezultat prikazuje indeks največje vrednosti v ' desetice1 ” tenzor, tj. 4. To pomeni, da je najvišja vrednost tenzorja pri 4. indeksu, ki je “ 9 ”:

2. primer: uporabite metodo »torch.argmax()« z 2D tenzorjem

V drugem primeru bomo ustvarili 2D tenzor in z njim uporabili metodo »torch.argmax()«. Sledimo navedenim korakom:

1. korak: uvozite knjižnico PyTorch

Najprej uvozite » svetilka ” za uporabo metode “torch.argmax()”:

uvoz svetilka

2. korak: Ustvarite 2D tenzor

Nato uporabite » torch.tensor() ” ustvariti 2D tenzor in natisniti njegove elemente. Tukaj ustvarjamo naslednje ' desetice2 '2D tenzor:

desetice2 = svetilka. tenzor ( [ [ 4 , 1 , - 7 ] , [ petnajst , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

tiskanje ( desetice2 )

To je ustvarilo 2D tenzor, kot je prikazano spodaj:

3. korak: Poiščite indekse največje vrednosti

Zdaj poiščite indeks največje vrednosti v ' desetice2 ' tenzor z uporabo ' torch.argmax() ” funkcija:

T2_ind = svetilka. argmax ( desetice2 )

4. korak: Natisnite indeks največje vrednosti

Končno prikaži indeks največje vrednosti v vhodnem tenzorju:

tiskanje ( 'Indeksi:' , T2_ind )

Glede na spodnji rezultat je indeks največje vrednosti v ' desetice2 ” tenzor je “3”. To pomeni, da je najvišja vrednost tenzorja pri 3. indeksu, ki je ' petnajst ”:

5. korak: Poiščite indekse največje vrednosti vzdolž stolpcev

Poleg tega lahko uporabniki najdejo tudi indekse/indekse največjih vrednosti vzdolž vsakega stolpca tenzorja. Na primer, lahko uporabimo ' dim=0 « s funkcijo »torch.argmax()«. Najde indekse največjih vrednosti vzdolž stolpcev v ' desetice2 ” tenzor in nato natisne te indekse:

col_index = svetilka. argmax ( desetice2 , dim = 0 )

tiskanje ( 'Indeksi v stolpcih:' , col_index )

Spodnji rezultat prikazuje indekse največjih vrednosti vzdolž vsakega stolpca tenzorja:

6. korak: Poiščite indekse največje vrednosti vzdolž vrstic

Podobno lahko uporabniki najdejo tudi indekse/indekse največjih vrednosti vzdolž vsake vrstice tenzorja. Uporabite na primer » dim=1 « s funkcijo »torch.argmax()«, da poiščete indekse največjih vrednosti vzdolž vrstic v tenzorju »Tens2« in nato natisnete te indekse:

indeks_vrstice = svetilka. argmax ( desetice2 , dim = 1 )

tiskanje ( 'Indeksi v vrsticah:' , indeks_vrstice )

Indekse največje vrednosti vzdolž vsake vrstice tenzorja 'Tens2' lahko vidite spodaj:

Učinkovito smo razložili metodo za uporabo metode »torch.argmax()« v PyTorchu.

Opomba : Do našega zvezka Google Colab lahko dostopate tukaj povezava .

Zaključek

Če želite uporabiti metodo »torch.argmax()« v PyTorchu, najprej uvozite » svetilka ' knjižnica. Nato ustvarite želeni 1D ali 2D tenzor in si oglejte njegove elemente. Nato uporabite » torch.argmax() ” za iskanje/izračun indeksov/indeksov največjih vrednosti v tenzorju. Poleg tega lahko uporabniki najdejo tudi indekse največje vrednosti vzdolž vsake vrstice ali stolpca v tenzorju z uporabo ' dim ' prepir. Na koncu prikažite indeks največje vrednosti v vhodnem tenzorju. Ta blog je ponazoril metodo za uporabo metode »torch.argmax()« v PyTorchu.